From 9bd595e40c693ced0dc8ce4f33d65048424189e5 Mon Sep 17 00:00:00 2001 From: MhLiao Date: Thu, 17 Sep 2020 19:52:40 -0400 Subject: [PATCH] fix magic number --- maskrcnn_benchmark/data/datasets/icdar.py | 2 +- maskrcnn_benchmark/data/datasets/scut.py | 2 +- maskrcnn_benchmark/data/datasets/synthtext.py | 2 +- maskrcnn_benchmark/data/datasets/tdtr.py | 2 +- .../data/datasets/total_text.py | 2 +- .../structures/segmentation_mask.py | 63 +++++++------------ 6 files changed, 27 insertions(+), 46 deletions(-) diff --git a/maskrcnn_benchmark/data/datasets/icdar.py b/maskrcnn_benchmark/data/datasets/icdar.py index caa6250..e8c4538 100644 --- a/maskrcnn_benchmark/data/datasets/icdar.py +++ b/maskrcnn_benchmark/data/datasets/icdar.py @@ -81,7 +81,7 @@ def __getitem__(self, item): if not self.use_charann: use_char_ann = False char_masks = SegmentationCharMask( - charsbbs, words=words, use_char_ann=use_char_ann, size=img.size + charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes) ) target.add_field("char_masks", char_masks) else: diff --git a/maskrcnn_benchmark/data/datasets/scut.py b/maskrcnn_benchmark/data/datasets/scut.py index 75fa2f1..74bfa28 100644 --- a/maskrcnn_benchmark/data/datasets/scut.py +++ b/maskrcnn_benchmark/data/datasets/scut.py @@ -79,7 +79,7 @@ def __getitem__(self, item): masks = SegmentationMask(segmentations, img.size) target.add_field("masks", masks) char_masks = SegmentationCharMask( - charsbbs, words=words, use_char_ann=use_char_ann, size=img.size + charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes) ) target.add_field("char_masks", char_masks) if self.transforms is not None: diff --git a/maskrcnn_benchmark/data/datasets/synthtext.py b/maskrcnn_benchmark/data/datasets/synthtext.py index e2f7123..9d8fa1c 100644 --- a/maskrcnn_benchmark/data/datasets/synthtext.py +++ b/maskrcnn_benchmark/data/datasets/synthtext.py @@ -61,7 +61,7 @@ def __getitem__(self, item): if not self.use_charann: use_char_ann = False char_masks = SegmentationCharMask( - charsbbs, words=words, use_char_ann=use_char_ann, size=img.size + charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes) ) target.add_field("char_masks", char_masks) if self.transforms is not None: diff --git a/maskrcnn_benchmark/data/datasets/tdtr.py b/maskrcnn_benchmark/data/datasets/tdtr.py index bfaff2f..4103c08 100644 --- a/maskrcnn_benchmark/data/datasets/tdtr.py +++ b/maskrcnn_benchmark/data/datasets/tdtr.py @@ -79,7 +79,7 @@ def __getitem__(self, item): masks = SegmentationMask(segmentations, img.size) target.add_field("masks", masks) char_masks = SegmentationCharMask( - charsbbs, words=words, use_char_ann=use_char_ann, size=img.size + charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes) ) target.add_field("char_masks", char_masks) else: diff --git a/maskrcnn_benchmark/data/datasets/total_text.py b/maskrcnn_benchmark/data/datasets/total_text.py index 5b98018..390d01c 100644 --- a/maskrcnn_benchmark/data/datasets/total_text.py +++ b/maskrcnn_benchmark/data/datasets/total_text.py @@ -79,7 +79,7 @@ def __getitem__(self, item): masks = SegmentationMask(segmentations, img.size) target.add_field("masks", masks) char_masks = SegmentationCharMask( - charsbbs, words=words, use_char_ann=use_char_ann, size=img.size + charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes) ) target.add_field("char_masks", char_masks) else: diff --git a/maskrcnn_benchmark/structures/segmentation_mask.py b/maskrcnn_benchmark/structures/segmentation_mask.py index 61274b1..0d1b072 100644 --- a/maskrcnn_benchmark/structures/segmentation_mask.py +++ b/maskrcnn_benchmark/structures/segmentation_mask.py @@ -332,6 +332,7 @@ def __init__( char_classes=None, size=None, mode=None, + char_num_classes=37, ): if isinstance(char_boxes, CharPolygons): if char_classes is None: @@ -352,6 +353,7 @@ def __init__( self.size = size self.mode = mode self.use_char_ann = use_char_ann + self.char_num_classes = char_num_classes def transpose(self, method): if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): @@ -381,6 +383,7 @@ def transpose(self, method): char_classes=self.char_classes, size=self.size, mode=self.mode, + char_num_classes=self.char_num_classes, ) def crop(self, box): @@ -403,6 +406,7 @@ def crop(self, box): char_classes=self.char_classes, size=(w, h), mode=self.mode, + char_num_classes=self.char_num_classes, ) def rotate(self, angle, r_c, start_h, start_w): @@ -423,6 +427,7 @@ def rotate(self, angle, r_c, start_h, start_w): char_classes=self.char_classes, size=(r_c[0] * 2, r_c[1] * 2), mode=self.mode, + char_num_classes=self.char_num_classes, ) def resize(self, size, *args, **kwargs): @@ -437,6 +442,7 @@ def resize(self, size, *args, **kwargs): char_classes=self.char_classes, size=size, mode=self.mode, + char_num_classes=self.char_num_classes, ) ratio_w, ratio_h = ratios @@ -454,6 +460,7 @@ def resize(self, size, *args, **kwargs): char_classes=self.char_classes, size=size, mode=self.mode, + char_num_classes=self.char_num_classes, ) def set_size(self, size): @@ -464,31 +471,16 @@ def convert(self, mode): if mode == "char_mask": if not self.use_char_ann: char_map = -np.ones((height, width)) - char_map_weight = np.zeros((37,)) + char_map_weight = np.zeros((self.char_num_classes,)) else: char_map = np.zeros((height, width)) - char_map_weight = np.ones((37,)) + char_map_weight = np.ones((self.char_num_classes,)) for i, p in enumerate(self.char_boxes): poly = p.numpy().reshape(4, 2) - # x_center = np.mean(poly[:,0], axis = 0).astype(np.int32) - # y_center = np.mean(poly[:,1], axis = 0).astype(np.int32) poly = shrink_poly(poly, 0.25) cv2.fillPoly( char_map, [poly.astype(np.int32)], int(self.char_classes[i]) ) - # if is_poly_inbox(poly, height, width) and x_center>=0 and x_center=0 and y_center 0) pos_num = pos_index[0].size if pos_num > 0: @@ -496,27 +488,24 @@ def convert(self, mode): char_map_weight[1:] = pos_weight return torch.from_numpy(char_map), torch.from_numpy(char_map_weight) elif mode == "seq_char_mask": - decoder_target = (38 - 1) * np.ones((32,)) + decoder_target = self.char_num_classes * np.ones((32,)) word_target = -np.ones((32,)) if not self.use_char_ann: char_map = -np.ones((height, width)) - char_map_weight = np.zeros((37,)) + char_map_weight = np.zeros((self.char_num_classes,)) for i, char in enumerate(self.word): if i > 31: break decoder_target[i] = char2num(char) word_target[i] = char2num(char) end_point = min(max(1, len(self.word)), 31) - word_target[end_point] = 37 + word_target[end_point] = self.char_num_classes else: char_map = np.zeros((height, width)) - char_map_weight = np.ones((37,)) + char_map_weight = np.ones((self.char_num_classes,)) word_length = 0 for i, p in enumerate(self.char_boxes): poly = p.numpy().reshape(4, 2) - # x_center = np.mean(poly[:,0], axis = 0).astype(np.int32) - # y_center = np.mean(poly[:,1], axis = 0).astype(np.int32) - # if is_poly_inbox(poly, height, width): if i < 32: decoder_target[i] = int(self.char_classes[i]) word_target[i] = int(self.char_classes[i]) @@ -525,21 +514,8 @@ def convert(self, mode): cv2.fillPoly( char_map, [poly.astype(np.int32)], int(self.char_classes[i]) ) - # if x_center>=0 and x_center=0 and y_center 0) pos_num = pos_index[0].size if pos_num > 0: @@ -576,7 +552,7 @@ def __repr__(self): class SegmentationCharMask(object): def __init__( - self, chars_boxes, words=None, use_char_ann=True, size=None, mode=None + self, chars_boxes, words=None, use_char_ann=True, size=None, mode=None, char_num_classes=37 ): # self.chars_boxes=[CharPolygons(char_boxes, word=word, use_char_ann=use_char_ann, size=size, mode=mode) for char_boxes, word in zip(chars_boxes, words)] if words is None: @@ -587,6 +563,7 @@ def __init__( use_char_ann=use_char_ann, size=size, mode=mode, + char_num_classes=char_num_classes, ) for char_boxes in chars_boxes ] @@ -598,12 +575,14 @@ def __init__( use_char_ann=use_char_ann, size=size, mode=mode, + char_num_classes=char_num_classes, ) for i, char_boxes in enumerate(chars_boxes) ] self.size = size self.mode = mode self.use_char_ann = use_char_ann + self.char_num_classes = char_num_classes def transpose(self, method): if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): @@ -615,7 +594,7 @@ def transpose(self, method): for char_boxes in self.chars_boxes: flipped.append(char_boxes.transpose(method)) return SegmentationCharMask( - flipped, use_char_ann=self.use_char_ann, size=self.size, mode=self.mode + flipped, use_char_ann=self.use_char_ann, size=self.size, mode=self.mode, char_num_classes=self.char_num_classes ) def crop(self, box, keep_ind): @@ -635,7 +614,7 @@ def resize(self, size, *args, **kwargs): for char_boxes in self.chars_boxes: scaled.append(char_boxes.resize(size, *args, **kwargs)) return SegmentationCharMask( - scaled, use_char_ann=self.use_char_ann, size=size, mode=self.mode + scaled, use_char_ann=self.use_char_ann, size=size, mode=self.mode, char_num_classes=self.char_num_classes ) def set_size(self, size): @@ -652,6 +631,7 @@ def rotate(self, angle, r_c, start_h, start_w): use_char_ann=self.use_char_ann, size=(r_c[0] * 2, r_c[1] * 2), mode=self.mode, + char_num_classes=self.char_num_classes, ) def __iter__(self): @@ -678,6 +658,7 @@ def __getitem__(self, item): use_char_ann=self.use_char_ann, size=self.size, mode=self.mode, + char_num_classes=self.char_num_classes, ) def __repr__(self):