-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathtorch_dataset.py
More file actions
77 lines (64 loc) · 3.01 KB
/
torch_dataset.py
File metadata and controls
77 lines (64 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""SHIFT dataset example using PyTorch."""
import os
import sys
import torch
from torch.utils.data import DataLoader
# Add the root directory of the project to the path. Remove the following two lines
# if you have installed shift_dev as a package.
root_dir = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
sys.path.append(root_dir)
from shift_dev import SHIFTDataset
from shift_dev.types import Keys
from shift_dev.utils.backend import ZipBackend
def main():
    """Load the SHIFT dataset and print the tensor shape of the first batch."""
    # Assemble the validation split of the front camera view. Each entry in
    # keys_to_load selects one annotation stream to decode per sample.
    dataset = SHIFTDataset(
        data_root="./SHIFT_dataset/",
        split="val",
        keys_to_load=[
            Keys.images,             # RGB frames, shape (1, 3, H, W), uint8
            Keys.intrinsics,         # camera intrinsics, shape (3, 3)
            Keys.boxes2d,            # 2D boxes in image coords, (x1, y1, x2, y2)
            Keys.boxes2d_classes,    # class indices, shape (num_boxes,)
            Keys.boxes2d_track_ids,  # object ids, shape (num_ins,)
            Keys.boxes3d,            # 3D boxes in camera coords, (x, y, z, dim_x, dim_y, dim_z, rot_x, rot_y, rot_z)
            Keys.boxes3d_classes,    # class indices, shape (num_boxes,); same as 'boxes2d_classes'
            Keys.boxes3d_track_ids,  # object ids, shape (num_ins,); same as 'boxes2d_track_ids'
            Keys.segmentation_masks,  # semantic masks, shape (1, H, W), long
            Keys.masks,              # instance masks, shape (num_ins, H, W), binary
            Keys.depth_maps,         # depth maps, shape (1, H, W), float (meters)
        ],
        views_to_load=["front"],
        shift_type="discrete",  # also supports "continuous/1x", "continuous/10x", "continuous/100x"
        backend=ZipBackend(),   # also supports HDF5Backend(), FileBackend()
        verbose=True,
    )

    # Wrap in a vanilla DataLoader; no shuffling so the output is reproducible.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    # Report the dataset size.
    print(f"Total number of samples: {len(dataset)}.")

    # Inspect only the first batch: pull it directly instead of looping.
    print('\n')
    i, batch = 0, next(iter(dataloader))
    print(f"Batch {i}:\n")
    print(f"{'Item':20} {'Shape':35} {'Min':10} {'Max':10}")
    print("-" * 80)
    for k, data in batch["front"].items():
        # Tensors get a shape/min/max summary; anything else is printed raw.
        if isinstance(data, torch.Tensor):
            print(f"{k:20} {str(data.shape):35} {data.min():10.2f} {data.max():10.2f}")
        else:
            print(f"{k:20} {data}")

    # Show the frame indices of one video sequence. video_to_indices groups
    # sample indices by their source video, which is useful for video training.
    print('\n')
    video_to_indices = dataset.video_to_indices
    video, indices = next(iter(video_to_indices.items()))
    print(f"Video name: {video}")
    print(f"Sample indices within a video: {indices}")


if __name__ == "__main__":
    main()