Skip to content

Commit

Permalink
[Fix] Fix img_shape in data pipeline (#9966)
Browse files Browse the repository at this point in the history
  • Loading branch information
RangiLyu authored Mar 17, 2023
1 parent 06d338c commit 0f5cd10
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 14 deletions.
20 changes: 10 additions & 10 deletions mmdet/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def _crop_data(self, results: dict, crop_size: Tuple[int, int],
img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
img_shape = img.shape
results['img'] = img
results['img_shape'] = img_shape
results['img_shape'] = img_shape[:2]

# crop bboxes accordingly and clip to the image boundary
if results.get('gt_bboxes', None) is not None:
Expand Down Expand Up @@ -1510,7 +1510,7 @@ def transform(self, results: dict) -> Union[dict, None]:
return None
# back to the original format
results = self.mapper(results, self.keymap_back)
results['img_shape'] = results['img'].shape
results['img_shape'] = results['img'].shape[:2]
return results

def _preprocess_results(self, results: dict) -> tuple:
Expand Down Expand Up @@ -1861,7 +1861,7 @@ def _train_aug(self, results):

if len(gt_bboxes) == 0:
results['img'] = cropped_img
results['img_shape'] = cropped_img.shape
results['img_shape'] = cropped_img.shape[:2]
return results

# if image do not have valid bbox, any crop patch is valid.
Expand All @@ -1870,7 +1870,7 @@ def _train_aug(self, results):
continue

results['img'] = cropped_img
results['img_shape'] = cropped_img.shape
results['img_shape'] = cropped_img.shape[:2]

x0, y0, x1, y1 = patch

Expand Down Expand Up @@ -1936,7 +1936,7 @@ def _test_aug(self, results):
cropped_img, border, _ = self._crop_image_and_paste(
img, [h // 2, w // 2], [target_h, target_w])
results['img'] = cropped_img
results['img_shape'] = cropped_img.shape
results['img_shape'] = cropped_img.shape[:2]
results['border'] = border
return results

Expand Down Expand Up @@ -2240,7 +2240,7 @@ def transform(self, results: dict) -> dict:
mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]

results['img'] = mosaic_img
results['img_shape'] = mosaic_img.shape
results['img_shape'] = mosaic_img.shape[:2]
results['gt_bboxes'] = mosaic_bboxes
results['gt_bboxes_labels'] = mosaic_bboxes_labels
results['gt_ignore_flags'] = mosaic_ignore_flags
Expand Down Expand Up @@ -2522,7 +2522,7 @@ def transform(self, results: dict) -> dict:
mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds]

results['img'] = mixup_img.astype(np.uint8)
results['img_shape'] = mixup_img.shape
results['img_shape'] = mixup_img.shape[:2]
results['gt_bboxes'] = mixup_gt_bboxes
results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
results['gt_ignore_flags'] = mixup_gt_ignore_flags
Expand Down Expand Up @@ -2645,7 +2645,7 @@ def transform(self, results: dict) -> dict:
dsize=(width, height),
borderValue=self.border_val)
results['img'] = img
results['img_shape'] = img.shape
results['img_shape'] = img.shape[:2]

bboxes = results['gt_bboxes']
num_bboxes = len(bboxes)
Expand Down Expand Up @@ -3334,7 +3334,7 @@ def transform(self, results: dict) -> dict:
mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]

results['img'] = mosaic_img
results['img_shape'] = mosaic_img.shape
results['img_shape'] = mosaic_img.shape[:2]
results['gt_bboxes'] = mosaic_bboxes
results['gt_bboxes_labels'] = mosaic_bboxes_labels
results['gt_ignore_flags'] = mosaic_ignore_flags
Expand Down Expand Up @@ -3614,7 +3614,7 @@ def transform(self, results: dict) -> dict:
mixup_gt_masks = mixup_gt_masks[inside_inds]

results['img'] = mixup_img.astype(np.uint8)
results['img_shape'] = mixup_img.shape
results['img_shape'] = mixup_img.shape[:2]
results['gt_bboxes'] = mixup_gt_bboxes
results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
results['gt_ignore_flags'] = mixup_gt_ignore_flags
Expand Down
8 changes: 4 additions & 4 deletions mmdet/structures/det_data_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class DetDataSample(BaseDataElement):
>>> from mmdet.structures import DetDataSample
>>> data_sample = DetDataSample()
>>> img_meta = dict(img_shape=(800, 1196, 3),
... pad_shape=(800, 1216, 3))
>>> img_meta = dict(img_shape=(800, 1196),
... pad_shape=(800, 1216))
>>> gt_instances = InstanceData(metainfo=img_meta)
>>> gt_instances.bboxes = torch.rand((5, 4))
>>> gt_instances.labels = torch.rand((5,))
Expand All @@ -48,8 +48,8 @@ class DetDataSample(BaseDataElement):
gt_instances: <InstanceData(
META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
pad_shape: (800, 1216)
img_shape: (800, 1196)
DATA FIELDS
labels: tensor([0.8533, 0.1550, 0.5433, 0.7294, 0.5098])
Expand Down
10 changes: 10 additions & 0 deletions tests/test_datasets/test_transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def test_transform(self):
results['gt_bboxes'].shape[0])
self.assertEqual(results['gt_labels'].dtype, np.int64)
self.assertEqual(results['gt_bboxes'].dtype, np.float32)
self.assertEqual(results['img_shape'], results['img'].shape[:2])

patch = np.array(
[0, 0, results['img_shape'][1], results['img_shape'][0]])
Expand Down Expand Up @@ -514,6 +515,7 @@ def setUp(self):
def test_transform(self):
transform = Expand()
results = transform.transform(copy.deepcopy(self.results))
self.assertEqual(results['img_shape'], results['img'].shape[:2])
self.assertEqual(
results['img_shape'],
(results['gt_masks'].height, results['gt_masks'].width))
Expand Down Expand Up @@ -609,13 +611,15 @@ def test_transform(self):
h, w = results['img'].shape
self.assertTrue(10 <= w <= 20)
self.assertTrue(10 <= h <= 20)
self.assertEqual(results['img_shape'], results['img'].shape[:2])
# test relative_range crop
transform = RandomCrop(
crop_size=(0.5, 0.5), crop_type='relative_range')
results = transform(copy.deepcopy(src_results))
h, w = results['img'].shape
self.assertTrue(16 <= w <= 32)
self.assertTrue(12 <= h <= 24)
self.assertEqual(results['img_shape'], results['img'].shape[:2])

# test with gt_bboxes, gt_bboxes_labels, gt_ignore_flags,
# gt_masks, gt_seg_map
Expand Down Expand Up @@ -649,6 +653,7 @@ def test_transform(self):
self.assertEqual(results['gt_bboxes_labels'].shape[0], 2)
self.assertEqual(results['gt_ignore_flags'].shape[0], 2)
self.assertTupleEqual(results['gt_seg_map'].shape[:2], (5, 7))
self.assertEqual(results['img_shape'], results['img'].shape[:2])

# test geometric transformation with homography matrix
bboxes = copy.deepcopy(src_results['gt_bboxes'])
Expand Down Expand Up @@ -918,6 +923,7 @@ def test_transform(self):
self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
self.assertTrue(results['gt_bboxes'].dtype == np.float32)
self.assertTrue(results['gt_ignore_flags'].dtype == bool)
self.assertEqual(results['img_shape'], results['img'].shape[:2])

def test_transform_with_no_gt(self):
self.results['gt_bboxes'] = np.empty((0, 4), dtype=np.float32)
Expand Down Expand Up @@ -1003,6 +1009,7 @@ def test_transform(self):
self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
self.assertTrue(results['gt_bboxes'].dtype == np.float32)
self.assertTrue(results['gt_ignore_flags'].dtype == bool)
self.assertEqual(results['img_shape'], results['img'].shape[:2])

def test_transform_use_box_type(self):
results = copy.deepcopy(self.results)
Expand Down Expand Up @@ -1074,6 +1081,7 @@ def test_transform(self):
self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
self.assertTrue(results['gt_bboxes'].dtype == np.float32)
self.assertTrue(results['gt_ignore_flags'].dtype == bool)
self.assertEqual(results['img_shape'], results['img'].shape[:2])

def test_transform_use_box_type(self):
results = copy.deepcopy(self.results)
Expand Down Expand Up @@ -1249,6 +1257,7 @@ def test_transform(self):
assert train_results['img_shape'][:2] == (h - 20, w - 20)
assert train_results['gt_bboxes'].shape[0] == 4
assert train_results['gt_bboxes'].dtype == np.float32
self.assertEqual(results['img_shape'], results['img'].shape[:2])

crop_module = RandomCenterCropPad(
crop_size=None,
Expand Down Expand Up @@ -1534,6 +1543,7 @@ def test_transform(self):
self.assertEqual(results['gt_bboxes'].dtype, np.float32)
self.assertEqual(results['gt_ignore_flags'].dtype, bool)
self.assertEqual(results['gt_bboxes_labels'].dtype, np.int64)
self.assertEqual(results['img_shape'], results['img'].shape[:2])

@unittest.skipIf(albumentations is None, 'albumentations is not installed')
def test_repr(self):
Expand Down

0 comments on commit 0f5cd10

Please sign in to comment.