-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathmmdet_dataset.py
More file actions
186 lines (156 loc) · 6.37 KB
/
mmdet_dataset.py
File metadata and controls
186 lines (156 loc) · 6.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""SHIFT dataset for mmdet.
This is reference code for an mmdet-style dataset class for the SHIFT dataset. Note that
only single-view 2D detection, instance segmentation, and tracking are supported.
Please refer to the PyTorch version of the dataloader for multi-view, multi-task use cases.
The code has been tested with mmdet 2.20.0.
Example
-------
Below is a snippet showing how to use the SHIFTDataset class in an mmdet config file.
>>> dict(
>>> type='SHIFTDataset',
>>> data_root='./SHIFT_dataset/discrete/images',
>>> ann_file='train/front/det_2d.json',
>>> img_prefix='train/front/img.zip',
>>> backend_type='zip',
>>> pipeline=[
>>> ...
>>> ]
>>> )
Notes
-----
1. Please copy this file to `mmdet/datasets/` and update the `mmdet/datasets/__init__.py`
so that the `SHIFTDataset` class is imported. You can refer to their official tutorial at
https://mmdetection.readthedocs.io/en/latest/tutorials/customize_dataset.html.
2. The `backend_type` must be one of ['file', 'zip', 'hdf5'], and the `img_prefix`
must match it (e.g., a `.zip` archive such as `train/front/img.zip` when `backend_type='zip'`).
3. Images are loaded with the selected backend before the pipeline runs, so do not
add a `LoadImageFromFile` step to the pipeline.
4. For instance segmentation please use the `det_insseg_2d.json` for the `ann_file`,
and add a `LoadAnnotations(with_mask=True)` module in the pipeline.
"""
import json
import os
import sys
import mmcv
import numpy as np
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset
from mmdet.datasets.pipelines import LoadAnnotations
# Add the root directory of the project to the path. Remove the following two lines
# if you have installed shift_dev as a package.
root_dir = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
sys.path.append(root_dir)
from shift_dev.utils.backend import HDF5Backend, ZipBackend
@DATASETS.register_module()
class SHIFTDataset(CustomDataset):
    """mmdet ``CustomDataset`` for SHIFT single-view 2D annotations.

    Images are decoded inside this class with the backend selected by
    ``backend_type`` ('file', 'zip' or 'hdf5') before the pipeline runs, so
    the pipeline should not contain ``LoadImageFromFile``. The keys that
    transform would normally add are filled in by ``_build_results``.
    """

    # Category names; a label's class id is its index in this tuple.
    CLASSES = ("pedestrian", "car", "truck", "bus", "motorcycle", "bicycle")
    # Fixed frame size reported for every image; the image itself is not
    # inspected to obtain it.
    WIDTH = 1280
    HEIGHT = 800

    def __init__(self, *args, backend_type: str = "file", **kwargs):
        """Initialize the SHIFT dataset.

        Args:
            *args: Positional arguments forwarded to ``CustomDataset``.
            backend_type (str, optional): The type of the backend used to read
                images. Must be one of ['file', 'zip', 'hdf5'].
                Defaults to "file".
            **kwargs: Keyword arguments forwarded to ``CustomDataset``
                (e.g. ``data_root``, ``ann_file``, ``img_prefix``, ``pipeline``).

        Raises:
            ValueError: If ``backend_type`` is not one of the supported values.
        """
        # Resolve the backend before ``super().__init__`` (which is expected to
        # call ``load_annotations`` and parse the whole JSON), so an invalid
        # ``backend_type`` fails immediately instead of after a long load.
        if backend_type == "file":
            backend = None
        elif backend_type == "zip":
            backend = ZipBackend()
        elif backend_type == "hdf5":
            backend = HDF5Backend()
        else:
            raise ValueError(
                f"Unknown backend type: {backend_type}! "
                "Must be one of ['file', 'zip', 'hdf5']"
            )
        super().__init__(*args, **kwargs)
        self.backend_type = backend_type
        self.backend = backend

    def load_annotations(self, ann_file):
        """Parse a SHIFT JSON annotation file into a list of data infos.

        Args:
            ann_file (str): Path to the annotation JSON. It must contain a
                ``"frames"`` list whose entries have ``videoName``, ``name``
                and ``labels``.

        Returns:
            list[dict]: One dict per frame with ``filename``, ``width``,
            ``height`` and ``ann``. ``ann`` holds ``bboxes`` (N, 4) float32,
            ``labels`` (N,) int64, ``track_ids`` (N,) int64 and ``masks``
            (list of RLE entries, or None when the frame has none).
        """
        print("Loading annotations...")
        with open(ann_file, "r") as f:
            data = json.load(f)
        return [self._parse_frame(frame) for frame in data["frames"]]

    def _parse_frame(self, frame):
        """Convert one SHIFT frame record into a data-info dict."""
        img_filename = os.path.join(
            self.img_prefix, frame["videoName"], frame["name"]
        )
        bboxes, labels, track_ids, masks = [], [], [], []
        for label in frame["labels"]:
            box = label["box2d"]
            bboxes.append((box["x1"], box["y1"], box["x2"], box["y2"]))
            labels.append(self.CLASSES.index(label["category"]))
            track_ids.append(label["id"])
            # NOTE(review): only labels carrying an RLE contribute a mask, so
            # ``masks`` can be shorter than ``bboxes`` if a file mixes labels
            # with and without ``rle`` - confirm this never happens.
            rle = label.get("rle")
            if rle is not None:
                masks.append(rle)
        return dict(
            filename=img_filename,
            width=self.WIDTH,
            height=self.HEIGHT,
            ann=dict(
                # reshape keeps the (N, 4) layout even when N == 0; a plain
                # ``np.array([])`` would have shape (0,) and break code that
                # indexes boxes column-wise.
                bboxes=np.array(bboxes).astype(np.float32).reshape(-1, 4),
                labels=np.array(labels).astype(np.int64),
                track_ids=np.array(track_ids).astype(np.int64),
                masks=masks if masks else None,
            ),
        )

    def get_img(self, idx):
        """Decode image ``idx`` with the configured backend.

        Returns:
            The array returned by ``mmcv.imread`` (file backend) or
            ``mmcv.imfrombytes`` (zip / hdf5 backends).
        """
        filename = self.data_infos[idx]["filename"]
        if self.backend is None:
            return mmcv.imread(filename)
        # zip and hdf5 backends both return encoded bytes for a path.
        return mmcv.imfrombytes(self.backend.get(filename))

    def get_img_info(self, idx):
        """Return ``filename``, ``width`` and ``height`` of image ``idx``."""
        info = self.data_infos[idx]
        return dict(
            filename=info["filename"],
            width=info["width"],
            height=info["height"],
        )

    def get_ann_info(self, idx):
        """Return the ``ann`` dict built by ``load_annotations`` for ``idx``."""
        return self.data_infos[idx]["ann"]

    def _build_results(self, idx, ann_info=None):
        """Decode image ``idx`` and assemble the pipeline input dict.

        Raises:
            OSError: If the image could not be decoded.
        """
        img = self.get_img(idx)
        img_info = self.get_img_info(idx)
        if img is None:
            raise OSError(f"Failed to decode image: {img_info['filename']}")
        results = dict(img=img, img_info=img_info)
        if ann_info is not None:
            results["ann_info"] = ann_info
        # These keys are normally added by mmdet's LoadImageFromFile, which
        # this dataset replaces; later transforms and Collect read them.
        results["filename"] = img_info["filename"]
        results["ori_filename"] = img_info["filename"]
        results["img_shape"] = img.shape
        results["ori_shape"] = img.shape
        results["img_fields"] = ["img"]
        self.pre_pipeline(results)
        return results

    def prepare_train_img(self, idx):
        """Return the pipeline output for training sample ``idx``.

        Returns None for frames without boxes (the caller is presumably
        mmdet's ``CustomDataset.__getitem__``, which then picks another index
        - verify for your mmdet version).
        """
        ann_info = self.get_ann_info(idx)
        # Filter out images without annotations during training. Checked
        # before decoding so the image is not read only to be discarded.
        if len(ann_info["bboxes"]) == 0:
            return None
        return self.pipeline(self._build_results(idx, ann_info))

    def prepare_test_img(self, idx):
        """Return the pipeline output for test sample ``idx``."""
        return self.pipeline(self._build_results(idx))
if __name__ == "__main__":
    # Example: load the SHIFT dataset for instance segmentation and print the
    # shapes of the first sample.
    dataset = SHIFTDataset(
        data_root="./SHIFT_dataset/discrete/images",
        ann_file="train/front/det_insseg_2d.json",
        img_prefix="train/front/img.zip",
        backend_type="zip",
        pipeline=[LoadAnnotations(with_mask=True)],
    )

    # Print the dataset size
    print(f"Total number of samples: {len(dataset)}.")

    # Print the array shapes of the first sample (if any).
    if len(dataset) > 0:
        data = dataset[0]
        print("Sample 0:")
        print("img:", data["img"].shape)
        print("ann_info.bboxes:", data["ann_info"]["bboxes"].shape)
        print("ann_info.labels:", data["ann_info"]["labels"].shape)
        print("ann_info.track_ids:", data["ann_info"]["track_ids"].shape)
        if "gt_masks" in data:
            print("gt_masks:", data["gt_masks"])