Skip to content

Commit

Permalink
fix magic number
Browse files Browse the repository at this point in the history
  • Loading branch information
MhLiao committed Sep 17, 2020
1 parent b7114ef commit 9bd595e
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 46 deletions.
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/data/datasets/icdar.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __getitem__(self, item):
if not self.use_charann:
use_char_ann = False
char_masks = SegmentationCharMask(
charsbbs, words=words, use_char_ann=use_char_ann, size=img.size
charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes)
)
target.add_field("char_masks", char_masks)
else:
Expand Down
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/data/datasets/scut.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __getitem__(self, item):
masks = SegmentationMask(segmentations, img.size)
target.add_field("masks", masks)
char_masks = SegmentationCharMask(
charsbbs, words=words, use_char_ann=use_char_ann, size=img.size
charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes)
)
target.add_field("char_masks", char_masks)
if self.transforms is not None:
Expand Down
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/data/datasets/synthtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __getitem__(self, item):
if not self.use_charann:
use_char_ann = False
char_masks = SegmentationCharMask(
charsbbs, words=words, use_char_ann=use_char_ann, size=img.size
charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes)
)
target.add_field("char_masks", char_masks)
if self.transforms is not None:
Expand Down
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/data/datasets/tdtr.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __getitem__(self, item):
masks = SegmentationMask(segmentations, img.size)
target.add_field("masks", masks)
char_masks = SegmentationCharMask(
charsbbs, words=words, use_char_ann=use_char_ann, size=img.size
charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes)
)
target.add_field("char_masks", char_masks)
else:
Expand Down
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/data/datasets/total_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __getitem__(self, item):
masks = SegmentationMask(segmentations, img.size)
target.add_field("masks", masks)
char_masks = SegmentationCharMask(
charsbbs, words=words, use_char_ann=use_char_ann, size=img.size
charsbbs, words=words, use_char_ann=use_char_ann, size=img.size, char_num_classes=len(self.char_classes)
)
target.add_field("char_masks", char_masks)
else:
Expand Down
63 changes: 22 additions & 41 deletions maskrcnn_benchmark/structures/segmentation_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def __init__(
char_classes=None,
size=None,
mode=None,
char_num_classes=37,
):
if isinstance(char_boxes, CharPolygons):
if char_classes is None:
Expand All @@ -352,6 +353,7 @@ def __init__(
self.size = size
self.mode = mode
self.use_char_ann = use_char_ann
self.char_num_classes = char_num_classes

def transpose(self, method):
if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
Expand Down Expand Up @@ -381,6 +383,7 @@ def transpose(self, method):
char_classes=self.char_classes,
size=self.size,
mode=self.mode,
char_num_classes=self.char_num_classes,
)

def crop(self, box):
Expand All @@ -403,6 +406,7 @@ def crop(self, box):
char_classes=self.char_classes,
size=(w, h),
mode=self.mode,
char_num_classes=self.char_num_classes,
)

def rotate(self, angle, r_c, start_h, start_w):
Expand All @@ -423,6 +427,7 @@ def rotate(self, angle, r_c, start_h, start_w):
char_classes=self.char_classes,
size=(r_c[0] * 2, r_c[1] * 2),
mode=self.mode,
char_num_classes=self.char_num_classes,
)

def resize(self, size, *args, **kwargs):
Expand All @@ -437,6 +442,7 @@ def resize(self, size, *args, **kwargs):
char_classes=self.char_classes,
size=size,
mode=self.mode,
char_num_classes=self.char_num_classes,
)

ratio_w, ratio_h = ratios
Expand All @@ -454,6 +460,7 @@ def resize(self, size, *args, **kwargs):
char_classes=self.char_classes,
size=size,
mode=self.mode,
char_num_classes=self.char_num_classes,
)

def set_size(self, size):
Expand All @@ -464,59 +471,41 @@ def convert(self, mode):
if mode == "char_mask":
if not self.use_char_ann:
char_map = -np.ones((height, width))
char_map_weight = np.zeros((37,))
char_map_weight = np.zeros((self.char_num_classes,))
else:
char_map = np.zeros((height, width))
char_map_weight = np.ones((37,))
char_map_weight = np.ones((self.char_num_classes,))
for i, p in enumerate(self.char_boxes):
poly = p.numpy().reshape(4, 2)
# x_center = np.mean(poly[:,0], axis = 0).astype(np.int32)
# y_center = np.mean(poly[:,1], axis = 0).astype(np.int32)
poly = shrink_poly(poly, 0.25)
cv2.fillPoly(
char_map, [poly.astype(np.int32)], int(self.char_classes[i])
)
# if is_poly_inbox(poly, height, width) and x_center>=0 and x_center<width and y_center>=0 and y_center<height:
# spoly=shrink_rect(poly,0.25)
# spoly = spoly.astype(np.int32)
# sbox_xmin_shrink = max(0, min(spoly[:,0]))
# sbox_xmax_shrink = min(width - 1, max(spoly[:,0]))
# sbox_ymin_shrink = max(0, min(spoly[:,1]))
# sbox_ymax_shrink = min(height - 1, max(spoly[:,1]))
# ## very small char box
# if sbox_xmax_shrink == sbox_xmin_shrink:
# sbox_xmax_shrink = sbox_xmin_shrink + 1
# if sbox_ymax_shrink == sbox_ymin_shrink:
# sbox_ymax_shrink = sbox_ymin_shrink + 1
# char_map[sbox_ymin_shrink:sbox_ymax_shrink, sbox_xmin_shrink:sbox_xmax_shrink] = int(self.char_classes[i])
pos_index = np.where(char_map > 0)
pos_num = pos_index[0].size
if pos_num > 0:
pos_weight = 1.0 * (height * width - pos_num) / pos_num
char_map_weight[1:] = pos_weight
return torch.from_numpy(char_map), torch.from_numpy(char_map_weight)
elif mode == "seq_char_mask":
decoder_target = (38 - 1) * np.ones((32,))
decoder_target = self.char_num_classes * np.ones((32,))
word_target = -np.ones((32,))
if not self.use_char_ann:
char_map = -np.ones((height, width))
char_map_weight = np.zeros((37,))
char_map_weight = np.zeros((self.char_num_classes,))
for i, char in enumerate(self.word):
if i > 31:
break
decoder_target[i] = char2num(char)
word_target[i] = char2num(char)
end_point = min(max(1, len(self.word)), 31)
word_target[end_point] = 37
word_target[end_point] = self.char_num_classes
else:
char_map = np.zeros((height, width))
char_map_weight = np.ones((37,))
char_map_weight = np.ones((self.char_num_classes,))
word_length = 0
for i, p in enumerate(self.char_boxes):
poly = p.numpy().reshape(4, 2)
# x_center = np.mean(poly[:,0], axis = 0).astype(np.int32)
# y_center = np.mean(poly[:,1], axis = 0).astype(np.int32)
# if is_poly_inbox(poly, height, width):
if i < 32:
decoder_target[i] = int(self.char_classes[i])
word_target[i] = int(self.char_classes[i])
Expand All @@ -525,21 +514,8 @@ def convert(self, mode):
cv2.fillPoly(
char_map, [poly.astype(np.int32)], int(self.char_classes[i])
)
# if x_center>=0 and x_center<width and y_center>=0 and y_center<height:
# spoly=shrink_rect(poly,0.25)
# spoly = spoly.astype(np.int32)
# sbox_xmin_shrink = max(0, min(spoly[:,0]))
# sbox_xmax_shrink = min(width - 1, max(spoly[:,0]))
# sbox_ymin_shrink = max(0, min(spoly[:,1]))
# sbox_ymax_shrink = min(height - 1, max(spoly[:,1]))
# ## very small char box
# if sbox_xmax_shrink == sbox_xmin_shrink:
# sbox_xmax_shrink = sbox_xmin_shrink + 1
# if sbox_ymax_shrink == sbox_ymin_shrink:
# sbox_ymax_shrink = sbox_ymin_shrink + 1
# char_map[sbox_ymin_shrink:sbox_ymax_shrink, sbox_xmin_shrink:sbox_xmax_shrink] = int(self.char_classes[i])
end_point = min(max(1, word_length), 31)
word_target[end_point] = 37
word_target[end_point] = self.char_num_classes
pos_index = np.where(char_map > 0)
pos_num = pos_index[0].size
if pos_num > 0:
Expand Down Expand Up @@ -576,7 +552,7 @@ def __repr__(self):

class SegmentationCharMask(object):
def __init__(
self, chars_boxes, words=None, use_char_ann=True, size=None, mode=None
self, chars_boxes, words=None, use_char_ann=True, size=None, mode=None, char_num_classes=37
):
# self.chars_boxes=[CharPolygons(char_boxes, word=word, use_char_ann=use_char_ann, size=size, mode=mode) for char_boxes, word in zip(chars_boxes, words)]
if words is None:
Expand All @@ -587,6 +563,7 @@ def __init__(
use_char_ann=use_char_ann,
size=size,
mode=mode,
char_num_classes=char_num_classes,
)
for char_boxes in chars_boxes
]
Expand All @@ -598,12 +575,14 @@ def __init__(
use_char_ann=use_char_ann,
size=size,
mode=mode,
char_num_classes=char_num_classes,
)
for i, char_boxes in enumerate(chars_boxes)
]
self.size = size
self.mode = mode
self.use_char_ann = use_char_ann
self.char_num_classes = char_num_classes

def transpose(self, method):
if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
Expand All @@ -615,7 +594,7 @@ def transpose(self, method):
for char_boxes in self.chars_boxes:
flipped.append(char_boxes.transpose(method))
return SegmentationCharMask(
flipped, use_char_ann=self.use_char_ann, size=self.size, mode=self.mode
flipped, use_char_ann=self.use_char_ann, size=self.size, mode=self.mode, char_num_classes=self.char_num_classes
)

def crop(self, box, keep_ind):
Expand All @@ -635,7 +614,7 @@ def resize(self, size, *args, **kwargs):
for char_boxes in self.chars_boxes:
scaled.append(char_boxes.resize(size, *args, **kwargs))
return SegmentationCharMask(
scaled, use_char_ann=self.use_char_ann, size=size, mode=self.mode
scaled, use_char_ann=self.use_char_ann, size=size, mode=self.mode, char_num_classes=self.char_num_classes
)

def set_size(self, size):
Expand All @@ -652,6 +631,7 @@ def rotate(self, angle, r_c, start_h, start_w):
use_char_ann=self.use_char_ann,
size=(r_c[0] * 2, r_c[1] * 2),
mode=self.mode,
char_num_classes=self.char_num_classes,
)

def __iter__(self):
Expand All @@ -678,6 +658,7 @@ def __getitem__(self, item):
use_char_ann=self.use_char_ann,
size=self.size,
mode=self.mode,
char_num_classes=self.char_num_classes,
)

def __repr__(self):
Expand Down

0 comments on commit 9bd595e

Please sign in to comment.