diff --git a/README.md b/README.md index 23115abb27..fd5b9acc9b 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Supported tasks: + Image Classification ([ResNet](examples/resnet), [SENet](examples/senet), [VGG](examples/vgg)) + Object Detection ([tutorial](http://chainercv.readthedocs.io/en/latest/tutorial/detection.html), [Faster R-CNN](examples/faster_rcnn), [FPN](examples/fpn), [SSD](examples/ssd), [YOLO](examples/yolo)) + Semantic Segmentation ([SegNet](examples/segnet), [PSPNet](examples/pspnet)) -+ Instance Segmentation ([FCIS](examples/fcis),) ++ Instance Segmentation ([FCIS](examples/fcis), [Mask R-CNN](examples/fpn)) # Guiding Principles ChainerCV is developed under the following three guiding principles. diff --git a/chainercv/datasets/__init__.py b/chainercv/datasets/__init__.py index 6ecea75ca5..bb6ed650dc 100644 --- a/chainercv/datasets/__init__.py +++ b/chainercv/datasets/__init__.py @@ -12,9 +12,11 @@ from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_semantic_segmentation_label_names # NOQA from chainercv.datasets.coco.coco_bbox_dataset import COCOBboxDataset # NOQA from chainercv.datasets.coco.coco_instance_segmentation_dataset import COCOInstanceSegmentationDataset # NOQA +from chainercv.datasets.coco.coco_keypoint_dataset import COCOKeypointDataset # NOQA from chainercv.datasets.coco.coco_semantic_segmentation_dataset import COCOSemanticSegmentationDataset # NOQA from chainercv.datasets.coco.coco_utils import coco_bbox_label_names # NOQA from chainercv.datasets.coco.coco_utils import coco_instance_segmentation_label_names # NOQA +from chainercv.datasets.coco.coco_utils import coco_keypoint_names # NOQA from chainercv.datasets.coco.coco_utils import coco_semantic_segmentation_label_colors # NOQA from chainercv.datasets.coco.coco_utils import coco_semantic_segmentation_label_names # NOQA from chainercv.datasets.cub.cub_keypoint_dataset import CUBKeypointDataset # NOQA diff --git a/chainercv/datasets/coco/coco_keypoint_dataset.py b/chainercv/datasets/coco/coco_keypoint_dataset.py new file mode 100644 index 0000000000..234d7e0942 --- /dev/null +++ b/chainercv/datasets/coco/coco_keypoint_dataset.py @@ -0,0 +1,166 @@ +from collections import defaultdict +import json +import numpy as np +import os + +from chainercv.chainer_experimental.datasets.sliceable import GetterDataset +from chainercv.datasets.coco.coco_utils import get_coco +from chainercv import utils + + +class COCOKeypointDataset(GetterDataset): + + """Keypoint dataset for `MS COCO`_. + + This only returns annotation for objects categorized to the "person" + category. + + .. _`MS COCO`: http://cocodataset.org/#home + + Args: + data_dir (string): Path to the root of the training data. If this is + :obj:`auto`, this class will automatically download data for you + under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/coco`. + split ({'train', 'val'}): Select a split of the dataset. + year ({'2014', '2017'}): Use a dataset released in :obj:`year`. + use_crowded (bool): If true, use bounding boxes that are labeled as + crowded in the original annotation. The default value is + :obj:`False`. + return_area (bool): If true, this dataset returns areas of masks + around objects. The default value is :obj:`False`. + return_crowded (bool): If true, this dataset returns a boolean array + that indicates whether bounding boxes are labeled as crowded + or not. The default value is :obj:`False`. + + This dataset returns the following data. + + .. 
csv-table:: + :header: name, shape, dtype, format + + :obj:`img`, ":math:`(3, H, W)`", :obj:`float32`, \ + "RGB, :math:`[0, 255]`" + :obj:`point` [#coco_point_1]_, ":math:`(R, K, 2)`", :obj:`float32`, \ + ":math:`(y, x)`" + :obj:`visible` [#coco_point_1]_, ":math:`(R, K)`", :obj:`bool`, \ + "true when a keypoint is visible." + :obj:`label` [#coco_point_1]_, ":math:`(R,)`", :obj:`int32`, \ + ":math:`[0, \#fg\_class - 1]`" + :obj:`bbox` [#coco_point_1]_, ":math:`(R, 4)`", :obj:`float32`, \ + ":math:`(y_{min}, x_{min}, y_{max}, x_{max})`" + :obj:`area` [#coco_point_1]_ [#coco_point_2]_, ":math:`(R,)`", \ + :obj:`float32`, -- + :obj:`crowded` [#coco_point_3]_, ":math:`(R,)`", :obj:`bool`, -- + + .. [#coco_point_1] If :obj:`use_crowded = True`, :obj:`point`, \ + :obj:`visible`, :obj:`bbox`, \ + :obj:`label` and :obj:`area` contain crowded instances. + .. [#coco_point_2] :obj:`area` is available \ + if :obj:`return_area = True`. + .. [#coco_point_3] :obj:`crowded` is available \ + if :obj:`return_crowded = True`. + + """ + + def __init__(self, data_dir='auto', split='train', year='2017', + use_crowded=False, + return_area=False, return_crowded=False): + if split not in ['train', 'val']: + raise ValueError('Unsupported split is given.') + super(COCOKeypointDataset, self).__init__() + self.use_crowded = use_crowded + if data_dir == 'auto': + data_dir = get_coco(split, split, year, 'instances') + + self.img_root = os.path.join( + data_dir, 'images', '{}{}'.format(split, year)) + self.data_dir = data_dir + + point_anno_path = os.path.join( + self.data_dir, 'annotations', 'person_keypoints_{}{}.json'.format( + split, year)) + annos = json.load(open(point_anno_path, 'r')) + + self.id_to_prop = {} + for prop in annos['images']: + self.id_to_prop[prop['id']] = prop + self.ids = sorted(list(self.id_to_prop.keys())) + + self.cat_ids = [cat['id'] for cat in annos['categories']] + + self.id_to_anno = defaultdict(list) + for anno in annos['annotations']: + self.id_to_anno[anno['image_id']].append(anno) + + self.add_getter('img', self._get_image) + self.add_getter( + ['point', 'visible', 'bbox', 'label', 'area', 'crowded'], + self._get_annotations) + keys = ('img', 'point', 'visible', 'label', 'bbox') + if return_area: + keys += ('area',) + if return_crowded: + keys += ('crowded',) + self.keys = keys + + def __len__(self): + return len(self.ids) + + def _get_image(self, i): + img_path = os.path.join( + self.img_root, self.id_to_prop[self.ids[i]]['file_name']) + img = utils.read_image(img_path, dtype=np.float32, color=True) + return img + + def _get_annotations(self, i): + # List[{'segmentation', 'area', 'iscrowd', + # 'image_id', 'bbox', 'category_id', 'id'}] + annotation = self.id_to_anno[self.ids[i]] + bbox = np.array([ann['bbox'] for ann in annotation], + dtype=np.float32) + if len(bbox) == 0: + bbox = np.zeros((0, 4), dtype=np.float32) + # (x, y, width, height) -> (x_min, y_min, x_max, y_max) + bbox[:, 2] = bbox[:, 0] + bbox[:, 2] + bbox[:, 3] = bbox[:, 1] + bbox[:, 3] + # (x_min, y_min, x_max, y_max) -> (y_min, x_min, y_max, x_max) + bbox = bbox[:, [1, 0, 3, 2]] + + label = np.array([self.cat_ids.index(ann['category_id']) + for ann in annotation], dtype=np.int32) + + area = np.array([ann['area'] + for ann in annotation], dtype=np.float32) + + crowded = np.array([ann['iscrowd'] + for ann in annotation], dtype=np.bool) + + point = np.array( + [anno['keypoints'] for anno in annotation], dtype=np.float32) + if len(point) > 0: + x = point[:, 0::3] + y = point[:, 1::3] + # 0: not labeled; 1: labeled, not 
inside mask; + # 2: labeled and inside mask + v = point[:, 2::3] + visible = v > 0 + point = np.stack((y, x), axis=2) + else: + point = np.empty((0, 0, 2), dtype=np.float32) + visible = np.empty((0, 0), dtype=np.bool) + + # Remove invisible boxes + bbox_area = np.prod(bbox[:, 2:] - bbox[:, :2], axis=1) + keep_mask = np.logical_and(bbox[:, 0] <= bbox[:, 2], + bbox[:, 1] <= bbox[:, 3]) + keep_mask = np.logical_and(keep_mask, bbox_area > 0) + + if not self.use_crowded: + keep_mask = np.logical_and(keep_mask, np.logical_not(crowded)) + + point = point[keep_mask] + visible = visible[keep_mask] + bbox = bbox[keep_mask] + label = label[keep_mask] + area = area[keep_mask] + crowded = crowded[keep_mask] + return point, visible, bbox, label, area, crowded diff --git a/chainercv/datasets/coco/coco_utils.py b/chainercv/datasets/coco/coco_utils.py index 7fe9d0f2f0..501d8329ba 100644 --- a/chainercv/datasets/coco/coco_utils.py +++ b/chainercv/datasets/coco/coco_utils.py @@ -19,10 +19,10 @@ } instances_anno_urls = { '2014': { - 'train': 'http://msvocds.blob.core.windows.net/annotations-1-0-3/' - 'instances_train-val2014.zip', - 'val': 'http://msvocds.blob.core.windows.net/annotations-1-0-3/' - 'instances_train-val2014.zip', + 'train': 'http://images.cocodataset.org/annotations/' + 'annotations_trainval2014.zip', + 'val': 'http://images.cocodataset.org/annotations/' + 'annotations_trainval2014.zip', 'valminusminival': 'https://dl.dropboxusercontent.com/s/' 's3tw5zcg7395368/instances_valminusminival2014.json.zip', 'minival': 'https://dl.dropboxusercontent.com/s/o43o90bna78omob/' @@ -442,3 +442,26 @@ def get_coco(split, img_split, year, mode): coco_instance_segmentation_label_names = coco_bbox_label_names + + +coco_keypoint_names = { + 0: [ + 'nose', + 'left_eye', + 'right_eye', + 'left_ear', + 'right_ear', + 'left_shoulder', + 'right_shoulder', + 'left_elbow', + 'right_elbow', + 'left_wrist', + 'right_wrist', + 'left_hip', + 'right_hip', + 'left_knee', + 'right_knee', + 'left_ankle', + 'right_ankle' + ] +} diff --git a/chainercv/evaluations/__init__.py b/chainercv/evaluations/__init__.py index 1f12332cdb..53017c6bb1 100644 --- a/chainercv/evaluations/__init__.py +++ b/chainercv/evaluations/__init__.py @@ -5,6 +5,7 @@ from chainercv.evaluations.eval_instance_segmentation_coco import eval_instance_segmentation_coco # NOQA from chainercv.evaluations.eval_instance_segmentation_voc import calc_instance_segmentation_voc_prec_rec # NOQA from chainercv.evaluations.eval_instance_segmentation_voc import eval_instance_segmentation_voc # NOQA +from chainercv.evaluations.eval_keypoint_detection_coco import eval_keypoint_detection_coco # NOQA from chainercv.evaluations.eval_semantic_segmentation import calc_semantic_segmentation_confusion # NOQA from chainercv.evaluations.eval_semantic_segmentation import calc_semantic_segmentation_iou # NOQA from chainercv.evaluations.eval_semantic_segmentation import eval_semantic_segmentation # NOQA diff --git a/chainercv/evaluations/eval_keypoint_detection_coco.py b/chainercv/evaluations/eval_keypoint_detection_coco.py new file mode 100644 index 0000000000..97dfc75b6f --- /dev/null +++ b/chainercv/evaluations/eval_keypoint_detection_coco.py @@ -0,0 +1,308 @@ +import itertools +import numpy as np +import os +import six + +from chainercv.evaluations.eval_detection_coco import _redirect_stdout +from chainercv.evaluations.eval_detection_coco import _summarize + +try: + import pycocotools.coco + import pycocotools.cocoeval + _available = True +except ImportError: + _available = False + + 
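
Stepping back to the dataset code above: `_get_annotations` converts COCO's `(x, y, width, height)` boxes and flat `(x, y, v)` keypoint triplets into ChainerCV's `(y_min, x_min, y_max, x_max)` and `(y, x)` conventions. A minimal, self-contained sketch of that conversion, with made-up annotation values:

```python
import numpy as np

# A made-up COCO-style annotation entry, for illustration only.
ann = {
    'bbox': [20.0, 30.0, 40.0, 50.0],           # (x, y, width, height)
    'keypoints': [25.0, 35.0, 2, 0.0, 0.0, 0],  # flat (x, y, v) triplets
}

# (x, y, width, height) -> (y_min, x_min, y_max, x_max)
x, y, w, h = ann['bbox']
bbox = np.array([y, x, y + h, x + w], dtype=np.float32)

# Split the flat triplets; v > 0 means the keypoint is labeled.
kp = np.array(ann['keypoints'], dtype=np.float32).reshape(-1, 3)
point = kp[:, [1, 0]]    # (y, x) ordering used throughout ChainerCV
visible = kp[:, 2] > 0

print(bbox)     # [30. 20. 80. 60.]
print(point)    # [[35. 25.], [0. 0.]]
print(visible)  # [ True False]
```
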
+def eval_keypoint_detection_coco(
+        pred_points, pred_labels, pred_scores,
+        gt_points, gt_visibles, gt_labels=None, gt_bboxes=None,
+        gt_areas=None, gt_crowdeds=None):
+    """Evaluate keypoint detection based on evaluation code of MS COCO.
+
+    This function evaluates predicted keypoints by computing average
+    precision for each class.
+    The code is based on the evaluation code used in MS COCO.
+
+    Args:
+        pred_points (iterable of numpy.ndarray): See the table below.
+        pred_labels (iterable of numpy.ndarray): See the table below.
+        pred_scores (iterable of numpy.ndarray): See the table below.
+            This is used to rank instances. Note that this is not
+            the confidence for each keypoint.
+        gt_points (iterable of numpy.ndarray): See the table below.
+        gt_visibles (iterable of numpy.ndarray): See the table below.
+        gt_labels (iterable of numpy.ndarray): See the table below.
+        gt_bboxes (iterable of numpy.ndarray): See the table below.
+            This is optional. If this is :obj:`None`, the ground truth
+            bounding boxes are estimated from the ground truth
+            keypoints.
+        gt_areas (iterable of numpy.ndarray): See the table below. If
+            :obj:`None`, some scores are not returned.
+        gt_crowdeds (iterable of numpy.ndarray): See the table below.
+
+    .. csv-table::
+        :header: name, shape, dtype, format
+
+        :obj:`pred_points`, ":math:`[(R, K, 2)]`", :obj:`float32`, \
+        ":math:`(y, x)`"
+        :obj:`pred_labels`, ":math:`[(R,)]`", :obj:`int32`, \
+        ":math:`[0, \#fg\_class - 1]`"
+        :obj:`pred_scores`, ":math:`[(R,)]`", :obj:`float32`, \
+        --
+        :obj:`gt_points`, ":math:`[(R, K, 2)]`", :obj:`float32`, \
+        ":math:`(y, x)`"
+        :obj:`gt_visibles`, ":math:`[(R, K)]`", :obj:`bool`, --
+        :obj:`gt_labels`, ":math:`[(R,)]`", :obj:`int32`, \
+        ":math:`[0, \#fg\_class - 1]`"
+        :obj:`gt_bboxes`, ":math:`[(R, 4)]`", :obj:`float32`, \
+        ":math:`(y_{min}, x_{min}, y_{max}, x_{max})`"
+        :obj:`gt_areas`, ":math:`[(R,)]`", \
+        :obj:`float32`, --
+        :obj:`gt_crowdeds`, ":math:`[(R,)]`", :obj:`bool`, --
+
+
+    Returns:
+        dict:
+
+        The keys, value-types and the description of the values are listed
+        below. The APs and ARs are calculated with different IoU
+        thresholds, sizes of objects, and numbers of detections
+        per image. For more details on the 12 patterns of evaluation metrics,
+        please refer to COCO's official `evaluation page`_.
+
+    .. 
csv-table::
+        :header: key, type, description
+
+        ap/iou=0.50:0.95/area=all/max_dets=20, *numpy.ndarray*, \
+            [#coco_kp_eval_1]_
+        ap/iou=0.50/area=all/max_dets=20, *numpy.ndarray*, \
+            [#coco_kp_eval_1]_
+        ap/iou=0.75/area=all/max_dets=20, *numpy.ndarray*, \
+            [#coco_kp_eval_1]_
+        ap/iou=0.50:0.95/area=medium/max_dets=20, *numpy.ndarray*, \
+            [#coco_kp_eval_1]_ [#coco_kp_eval_5]_
+        ap/iou=0.50:0.95/area=large/max_dets=20, *numpy.ndarray*, \
+            [#coco_kp_eval_1]_ [#coco_kp_eval_5]_
+        ar/iou=0.50:0.95/area=all/max_dets=20, *numpy.ndarray*, \
+            [#coco_kp_eval_2]_
+        ar/iou=0.50/area=all/max_dets=20, *numpy.ndarray*, \
+            [#coco_kp_eval_2]_
+        ar/iou=0.75/area=all/max_dets=20, *numpy.ndarray*, \
+            [#coco_kp_eval_2]_
+        ar/iou=0.50:0.95/area=medium/max_dets=20, *numpy.ndarray*, \
+            [#coco_kp_eval_2]_ [#coco_kp_eval_5]_
+        ar/iou=0.50:0.95/area=large/max_dets=20, *numpy.ndarray*, \
+            [#coco_kp_eval_2]_ [#coco_kp_eval_5]_
+        map/iou=0.50:0.95/area=all/max_dets=20, *float*, \
+            [#coco_kp_eval_3]_
+        map/iou=0.50/area=all/max_dets=20, *float*, \
+            [#coco_kp_eval_3]_
+        map/iou=0.75/area=all/max_dets=20, *float*, \
+            [#coco_kp_eval_3]_
+        map/iou=0.50:0.95/area=medium/max_dets=20, *float*, \
+            [#coco_kp_eval_3]_ [#coco_kp_eval_5]_
+        map/iou=0.50:0.95/area=large/max_dets=20, *float*, \
+            [#coco_kp_eval_3]_ [#coco_kp_eval_5]_
+        mar/iou=0.50:0.95/area=all/max_dets=20, *float*, \
+            [#coco_kp_eval_4]_
+        mar/iou=0.50/area=all/max_dets=20, *float*, \
+            [#coco_kp_eval_4]_
+        mar/iou=0.75/area=all/max_dets=20, *float*, \
+            [#coco_kp_eval_4]_
+        mar/iou=0.50:0.95/area=medium/max_dets=20, *float*, \
+            [#coco_kp_eval_4]_ [#coco_kp_eval_5]_
+        mar/iou=0.50:0.95/area=large/max_dets=20, *float*, \
+            [#coco_kp_eval_4]_ [#coco_kp_eval_5]_
+        coco_eval, *pycocotools.cocoeval.COCOeval*, \
+            result from :obj:`pycocotools`
+        existent_labels, *numpy.ndarray*, \
+            used labels
+
+    .. [#coco_kp_eval_1] An array of average precisions. \
+        The :math:`l`-th value corresponds to the average precision \
+        for class :math:`l`. If class :math:`l` does not exist in \
+        either :obj:`pred_labels` or :obj:`gt_labels`, the corresponding \
+        value is set to :obj:`numpy.nan`.
+    .. [#coco_kp_eval_2] An array of average recalls. \
+        The :math:`l`-th value corresponds to the average recall \
+        for class :math:`l`. If class :math:`l` does not exist in \
+        either :obj:`pred_labels` or :obj:`gt_labels`, the corresponding \
+        value is set to :obj:`numpy.nan`.
+    .. [#coco_kp_eval_3] The average of average precisions over classes.
+    .. [#coco_kp_eval_4] The average of average recalls over classes.
+    .. [#coco_kp_eval_5] Skip if :obj:`gt_areas` is :obj:`None`. 
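
For orientation, the snippet below sketches how this evaluator might be called on toy data; the arrays are invented, `pycocotools` must be installed, and with `gt_areas=None` only the `area=all` keys listed above are populated:

```python
import numpy as np
from chainercv.evaluations import eval_keypoint_detection_coco

K = 17  # COCO defines 17 person keypoints
# Two invented images with one instance each; class 0 is "person".
pred_points = [np.random.uniform(0, 200, (1, K, 2)).astype(np.float32)
               for _ in range(2)]
pred_labels = [np.zeros((1,), dtype=np.int32) for _ in range(2)]
pred_scores = [np.array([0.9], dtype=np.float32) for _ in range(2)]
gt_points = [p.copy() for p in pred_points]  # perfect predictions
gt_visibles = [np.ones((1, K), dtype=bool) for _ in range(2)]
gt_labels = [np.zeros((1,), dtype=np.int32) for _ in range(2)]

result = eval_keypoint_detection_coco(
    pred_points, pred_labels, pred_scores,
    gt_points, gt_visibles, gt_labels)
print(result['map/iou=0.50:0.95/area=all/max_dets=20'])
```
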
+ + """ + if not _available: + raise ValueError( + 'Please install pycocotools \n' + 'pip install -e \'git+https://github.com/cocodataset/coco.git' + '#egg=pycocotools&subdirectory=PythonAPI\'') + + gt_coco = pycocotools.coco.COCO() + pred_coco = pycocotools.coco.COCO() + + pred_points = iter(pred_points) + pred_labels = iter(pred_labels) + pred_scores = iter(pred_scores) + gt_points = iter(gt_points) + gt_visibles = iter(gt_visibles) + gt_labels = iter(gt_labels) + gt_bboxes = (iter(gt_bboxes) if gt_bboxes is not None + else itertools.repeat(None)) + if gt_areas is None: + compute_area_dependent_metrics = False + gt_areas = itertools.repeat(None) + else: + compute_area_dependent_metrics = True + gt_areas = iter(gt_areas) + gt_crowdeds = (iter(gt_crowdeds) if gt_crowdeds is not None + else itertools.repeat(None)) + + ids = [] + pred_annos = [] + gt_annos = [] + existent_labels = {} + for i, (pred_point, pred_label, pred_score, gt_point, gt_visible, + gt_label, gt_bbox, + gt_area, gt_crowded) in enumerate(six.moves.zip( + pred_points, pred_labels, pred_scores, + gt_points, gt_visibles, gt_labels, gt_bboxes, + gt_areas, gt_crowdeds)): + if gt_bbox is None: + gt_bbox = itertools.repeat(None) + if gt_area is None: + gt_area = itertools.repeat(None) + if gt_crowded is None: + gt_crowded = itertools.repeat(None) + # Starting ids from 1 is important when using COCO. + img_id = i + 1 + + for pred_pnt, pred_lb, pred_sc in zip(pred_point, pred_label, + pred_score): + # http://cocodataset.org/#format-results + # Visibility flag is currently not used for evaluation + v = np.ones(len(pred_pnt)) + pred_annos.append( + _create_anno(pred_pnt, v, + pred_lb, pred_sc, None, + img_id=img_id, anno_id=len(pred_annos) + 1, + ar=None, crw=0)) + existent_labels[pred_lb] = True + + for gt_pnt, gt_v, gt_lb, gt_bb, gt_ar, gt_crw in zip( + gt_point, gt_visible, gt_label, gt_bbox, gt_area, gt_crowded): + gt_annos.append( + _create_anno(gt_pnt, gt_v, gt_lb, None, gt_bb, + img_id=img_id, anno_id=len(gt_annos) + 1, + ar=gt_ar, crw=gt_crw)) + ids.append({'id': img_id}) + existent_labels = sorted(existent_labels.keys()) + + pred_coco.dataset['categories'] = [{'id': i} for i in existent_labels] + gt_coco.dataset['categories'] = [{'id': i} for i in existent_labels] + pred_coco.dataset['annotations'] = pred_annos + gt_coco.dataset['annotations'] = gt_annos + pred_coco.dataset['images'] = ids + gt_coco.dataset['images'] = ids + + with _redirect_stdout(open(os.devnull, 'w')): + pred_coco.createIndex() + gt_coco.createIndex() + coco_eval = pycocotools.cocoeval.COCOeval( + gt_coco, pred_coco, 'keypoints') + coco_eval.evaluate() + coco_eval.accumulate() + + results = {'coco_eval': coco_eval} + p = coco_eval.params + common_kwargs = { + 'prec': coco_eval.eval['precision'], + 'rec': coco_eval.eval['recall'], + 'iou_threshs': p.iouThrs, + 'area_ranges': p.areaRngLbl, + 'max_detection_list': p.maxDets, + } + all_kwargs = { + 'ap/iou=0.50:0.95/area=all/max_dets=20': { + 'ap': True, 'iou_thresh': None, 'area_range': 'all', + 'max_detection': 20}, + 'ap/iou=0.50/area=all/max_dets=20': { + 'ap': True, 'iou_thresh': 0.5, 'area_range': 'all', + 'max_detection': 20}, + 'ap/iou=0.75/area=all/max_dets=20': { + 'ap': True, 'iou_thresh': 0.75, 'area_range': 'all', + 'max_detection': 20}, + 'ar/iou=0.50:0.95/area=all/max_dets=20': { + 'ap': False, 'iou_thresh': None, 'area_range': 'all', + 'max_detection': 20}, + 'ar/iou=0.50/area=all/max_dets=20': { + 'ap': False, 'iou_thresh': 0.5, 'area_range': 'all', + 'max_detection': 20}, + 
'ar/iou=0.75/area=all/max_dets=20': { + 'ap': False, 'iou_thresh': 0.75, 'area_range': 'all', + 'max_detection': 20}, + } + if compute_area_dependent_metrics: + all_kwargs.update({ + 'ap/iou=0.50:0.95/area=medium/max_dets=20': { + 'ap': True, 'iou_thresh': None, 'area_range': 'medium', + 'max_detection': 20}, + 'ap/iou=0.50:0.95/area=large/max_dets=20': { + 'ap': True, 'iou_thresh': None, 'area_range': 'large', + 'max_detection': 20}, + 'ar/iou=0.50:0.95/area=medium/max_dets=20': { + 'ap': False, 'iou_thresh': None, 'area_range': 'medium', + 'max_detection': 20}, + 'ar/iou=0.50:0.95/area=large/max_dets=20': { + 'ap': False, 'iou_thresh': None, 'area_range': 'large', + 'max_detection': 20}, + }) + + for key, kwargs in all_kwargs.items(): + kwargs.update(common_kwargs) + metrics, mean_metric = _summarize(**kwargs) + + # pycocotools ignores classes that are not included in + # either gt or prediction, but lies between 0 and + # the maximum label id. + # We set values for these classes to np.nan. + results[key] = np.nan * np.ones(np.max(existent_labels) + 1) + results[key][existent_labels] = metrics + results['m' + key] = mean_metric + + results['existent_labels'] = existent_labels + return results + + +def _create_anno(pnt, v, lb, sc, bb, img_id, anno_id, ar=None, crw=None): + # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L342 + y_min = np.min(pnt[:, 0]) + x_min = np.min(pnt[:, 1]) + y_max = np.max(pnt[:, 0]) + x_max = np.max(pnt[:, 1]) + if ar is None: + ar = (y_max - y_min) * (x_max - x_min) + + if crw is None: + crw = False + # Rounding is done to make the result consistent with COCO. + + if bb is None: + bb_xywh = [x_min, y_min, x_max - x_min, y_max - y_min] + else: + bb_xywh = [bb[1], bb[0], bb[3] - bb[1], bb[2] - bb[0]] + pnt = np.concatenate((pnt[:, [1, 0]], v[:, None]), axis=1) + anno = { + 'image_id': img_id, 'category_id': lb, + 'keypoints': pnt.reshape((-1)).tolist(), + 'area': ar, + 'bbox': bb_xywh, + 'id': anno_id, + 'iscrowd': crw, + 'num_keypoints': (pnt[:, 0] > 0).sum() + } + if sc is not None: + anno.update({'score': sc}) + return anno diff --git a/chainercv/links/__init__.py b/chainercv/links/__init__.py index 642cc906e1..aa91f30b77 100644 --- a/chainercv/links/__init__.py +++ b/chainercv/links/__init__.py @@ -11,6 +11,10 @@ from chainercv.links.model.faster_rcnn.faster_rcnn_vgg import FasterRCNNVGG16 # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet101 # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet50 # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import KeypointRCNNFPNResNet101 # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import KeypointRCNNFPNResNet50 # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet101 # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet50 # NOQA from chainercv.links.model.resnet import ResNet101 # NOQA from chainercv.links.model.resnet import ResNet152 # NOQA from chainercv.links.model.resnet import ResNet50 # NOQA diff --git a/chainercv/links/model/fpn/__init__.py b/chainercv/links/model/fpn/__init__.py index 0ceacd4fe5..e4ba9c853c 100644 --- a/chainercv/links/model/fpn/__init__.py +++ b/chainercv/links/model/fpn/__init__.py @@ -1,9 +1,22 @@ +from chainercv.links.model.fpn.bbox_head import bbox_loss_post # NOQA +from chainercv.links.model.fpn.bbox_head import bbox_loss_pre # NOQA +from chainercv.links.model.fpn.bbox_head import BboxHead # 
NOQA from chainercv.links.model.fpn.faster_rcnn import FasterRCNN # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet101 # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet50 # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import KeypointRCNNFPNResNet101 # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import KeypointRCNNFPNResNet50 # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet101 # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet50 # NOQA from chainercv.links.model.fpn.fpn import FPN # NOQA -from chainercv.links.model.fpn.head import Head # NOQA -from chainercv.links.model.fpn.head import head_loss_post # NOQA -from chainercv.links.model.fpn.head import head_loss_pre # NOQA +from chainercv.links.model.fpn.keypoint_head import keypoint_loss_post # NOQA +from chainercv.links.model.fpn.keypoint_head import keypoint_loss_pre # NOQA +from chainercv.links.model.fpn.keypoint_head import KeypointHead # NOQA +from chainercv.links.model.fpn.mask_head import mask_loss_post # NOQA +from chainercv.links.model.fpn.mask_head import mask_loss_pre # NOQA +from chainercv.links.model.fpn.mask_head import MaskHead # NOQA +from chainercv.links.model.fpn.mask_utils import mask_to_segm # NOQA +from chainercv.links.model.fpn.mask_utils import segm_to_mask # NOQA from chainercv.links.model.fpn.rpn import RPN # NOQA from chainercv.links.model.fpn.rpn import rpn_loss # NOQA diff --git a/chainercv/links/model/fpn/head.py b/chainercv/links/model/fpn/bbox_head.py similarity index 92% rename from chainercv/links/model/fpn/head.py rename to chainercv/links/model/fpn/bbox_head.py index f0c0fc7b63..1daa8897b0 100644 --- a/chainercv/links/model/fpn/head.py +++ b/chainercv/links/model/fpn/bbox_head.py @@ -7,14 +7,14 @@ import chainer.links as L from chainercv.links.model.fpn.misc import argsort -from chainercv.links.model.fpn.misc import choice +from chainercv.links.model.fpn.misc import balanced_sampling from chainercv.links.model.fpn.misc import exp_clip from chainercv.links.model.fpn.misc import smooth_l1 from chainercv import utils -class Head(chainer.Chain): - """Head network of Feature Pyramid Networks. +class BboxHead(chainer.Chain): + """Bounding box head network of Feature Pyramid Networks. Args: n_class (int): The number of classes including background. @@ -28,7 +28,7 @@ class Head(chainer.Chain): std = (0.1, 0.2) def __init__(self, n_class, scales): - super(Head, self).__init__() + super(BboxHead, self).__init__() fc_init = { 'initialW': Caffe2FCUniform(), @@ -210,10 +210,10 @@ def decode(self, rois, roi_indices, locs, confs, return bboxes, labels, scores -def head_loss_pre(rois, roi_indices, std, bboxes, labels): +def bbox_loss_pre(rois, roi_indices, std, bboxes, labels): """Loss function for Head (pre). - This function processes RoIs for :func:`head_loss_post`. + This function processes RoIs for :func:`bbox_head_loss_post`. 
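
An aside on the rename being applied in this file: the former `Head`, `head_loss_pre`, and `head_loss_post` are now exported as `BboxHead`, `bbox_loss_pre`, and `bbox_loss_post` (see the `__init__.py` hunk above), so downstream imports would change along these lines:

```python
# Before this change:
# from chainercv.links.model.fpn import Head
# from chainercv.links.model.fpn import head_loss_pre, head_loss_post

# After this change:
from chainercv.links.model.fpn import BboxHead
from chainercv.links.model.fpn import bbox_loss_post
from chainercv.links.model.fpn import bbox_loss_pre
```
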
Args: rois (iterable of arrays): An iterable of arrays of @@ -285,25 +285,16 @@ def head_loss_pre(rois, roi_indices, std, bboxes, labels): else: gt_label = xp.zeros(int(mask.sum()), dtype=np.int32) - fg_index = xp.where(gt_label > 0)[0] - n_fg = int(batchsize_per_image * fg_ratio) - if len(fg_index) > n_fg: - gt_label[choice(fg_index, size=len(fg_index) - n_fg)] = -1 - - bg_index = xp.where(gt_label == 0)[0] - n_bg = batchsize_per_image - int((gt_label > 0).sum()) - if len(bg_index) > n_bg: - gt_label[choice(bg_index, size=len(bg_index) - n_bg)] = -1 - gt_locs[mask] = gt_loc - gt_labels[mask] = gt_label + gt_labels[mask] = balanced_sampling( + gt_label, batchsize_per_image, fg_ratio) - mask = gt_labels >= 0 - rois = rois[mask] - roi_indices = roi_indices[mask] - roi_levels = roi_levels[mask] - gt_locs = gt_locs[mask] - gt_labels = gt_labels[mask] + is_sampled = gt_labels >= 0 + rois = rois[is_sampled] + roi_indices = roi_indices[is_sampled] + roi_levels = roi_levels[is_sampled] + gt_locs = gt_locs[is_sampled] + gt_labels = gt_labels[is_sampled] masks = [roi_levels == l for l in range(n_level)] rois = [rois[m] for m in masks] @@ -314,7 +305,7 @@ def head_loss_pre(rois, roi_indices, std, bboxes, labels): return rois, roi_indices, gt_locs, gt_labels -def head_loss_post(locs, confs, roi_indices, gt_locs, gt_labels, batchsize): +def bbox_loss_post(locs, confs, roi_indices, gt_locs, gt_labels, batchsize): """Loss function for Head (post). Args: @@ -323,11 +314,11 @@ def head_loss_post(locs, confs, roi_indices, gt_locs, gt_labels, batchsize): confs (array): An iterable of arrays whose shape is :math:`(R, n\_class)`. roi_indices (list of arrays): A list of arrays returned by - :func:`head_locs_pre`. + :func:`bbox_head_locs_pre`. gt_locs (list of arrays): A list of arrays returned by - :func:`head_locs_pre`. + :func:`bbox_head_locs_pre`. gt_labels (list of arrays): A list of arrays returned by - :func:`head_locs_pre`. + :func:`bbox_head_locs_pre`. batchsize (int): The size of batch. Returns: diff --git a/chainercv/links/model/fpn/faster_rcnn.py b/chainercv/links/model/fpn/faster_rcnn.py index c64a563db2..c37fe30c08 100644 --- a/chainercv/links/model/fpn/faster_rcnn.py +++ b/chainercv/links/model/fpn/faster_rcnn.py @@ -4,17 +4,15 @@ import chainer from chainer.backends import cuda +import chainer.functions as F -from chainercv import transforms +from chainercv.links.model.fpn.misc import scale_img class FasterRCNN(chainer.Chain): - """Base class of Feature Pyramid Networks. + """Base class of Faster R-CNN with FPN. - This is a base class of Feature Pyramid Networks [#]_. - - .. [#] Tsung-Yi Lin et al. - Feature Pyramid Networks for Object Detection. CVPR 2017 + This is a base class of Faster R-CNN with FPN. Args: extractor (Link): A link that extracts feature maps. @@ -23,9 +21,14 @@ class FasterRCNN(chainer.Chain): rpn (Link): A link that has the same interface as :class:`~chainercv.links.model.fpn.RPN`. Please refer to the documentation found there. - head (Link): A link that has the same interface as - :class:`~chainercv.links.model.fpn.Head`. + bbox_head (Link): A link that has the same interface as + :class:`~chainercv.links.model.fpn.BboxHead`. + Please refer to the documentation found there. + mask_head (Link): A link that has the same interface as + :class:`~chainercv.links.model.fpn.MaskHead`. Please refer to the documentation found there. + return_values (list of strings): Determines the values + returned by :meth:`predict`. min_size (int): A preprocessing paramter for :meth:`prepare`. 
Please refer to a docstring found for :meth:`prepare`. max_size (int): A preprocessing paramter for :meth:`prepare`. Note @@ -45,18 +48,40 @@ class FasterRCNN(chainer.Chain): """ - _stride = 32 + stride = 32 + _accepted_return_values = ('rois', 'bboxes', 'labels', 'scores', + 'masks', 'points', 'point_scores') - def __init__(self, extractor, rpn, head, + def __init__(self, extractor, rpn, bbox_head, + mask_head, keypoint_head, return_values, min_size=800, max_size=1333): + for value_name in return_values: + if value_name not in self._accepted_return_values: + raise ValueError( + '{} is not included in accepted value names {}'.format( + value_name, self._accepted_return_values)) + self._return_values = return_values + + self._store_rpn_outputs = 'rois' in self._return_values + self._run_bbox = any([key in self._return_values + for key in ['bboxes', 'labels', 'scores', + 'masks', 'points', 'point_scores']]) + self._run_mask = 'masks' in self._return_values + self._run_keypoint = 'points' in self._return_values super(FasterRCNN, self).__init__() + with self.init_scope(): self.extractor = extractor self.rpn = rpn - self.head = head + if self._run_bbox: + self.bbox_head = bbox_head + if self._run_mask: + self.mask_head = mask_head + if self._run_keypoint: + self.keypoint_head = keypoint_head - self._min_size = min_size - self._max_size = max_size + self.min_size = min_size + self.max_size = max_size self.use_preset('visualize') @@ -94,52 +119,135 @@ def __call__(self, x): anchors = self.rpn.anchors(h.shape[2:] for h in hs) rois, roi_indices = self.rpn.decode( rpn_locs, rpn_confs, anchors, x.shape) - rois, roi_indices = self.head.distribute(rois, roi_indices) - head_locs, head_confs = self.head(hs, rois, roi_indices) - return rois, roi_indices, head_locs, head_confs + return hs, rois, roi_indices def predict(self, imgs): - """Detect objects from images. + """Conduct inference on the given images. - This method predicts objects for each image. + The value returned by this method is decided based on + the argument :obj:`return_values` of :meth:`__init__`. + + Examples: + + >>> from chainercv.links import FasterRCNNFPNResNet50 + >>> model = FasterRCNNFPNResNet50( + ... pretrained_model='coco', + ... return_values=['rois', 'bboxes', 'labels', 'scores']) + >>> rois, bboxes, labels, scores = model.predict(imgs) Args: - imgs (iterable of numpy.ndarray): Arrays holding images. - All images are in CHW and RGB format - and the range of their value is :math:`[0, 255]`. + imgs (iterable of numpy.ndarray): Inputs. Returns: - tuple of lists: - This method returns a tuple of three lists, - :obj:`(bboxes, labels, scores)`. - - * **bboxes**: A list of float arrays of shape :math:`(R, 4)`, \ - where :math:`R` is the number of bounding boxes in a image. \ - Each bounding box is organized by \ - :math:`(y_{min}, x_{min}, y_{max}, x_{max})` \ - in the second axis. - * **labels** : A list of integer arrays of shape :math:`(R,)`. \ - Each value indicates the class of the bounding box. \ - Values are in range :math:`[0, L - 1]`, where :math:`L` is the \ - number of the foreground classes. - * **scores** : A list of float arrays of shape :math:`(R,)`. \ - Each value indicates how confident the prediction is. + tuple of lists: + The table below shows the input and possible outputs. + + .. 
csv-table:: + :header: name, shape, dtype, format + + :obj:`imgs`, ":math:`[(3, H, W)]`", :obj:`float32`, \ + "RGB, :math:`[0, 255]`" + :obj:`rois`, ":math:`[(R', 4)]`", :obj:`float32`, \ + ":math:`(y_{min}, x_{min}, y_{max}, x_{max})`" + :obj:`bboxes`, ":math:`[(R, 4)]`", :obj:`float32`, \ + ":math:`(y_{min}, x_{min}, y_{max}, x_{max})`" + :obj:`scores`, ":math:`[(R,)]`", :obj:`float32`, \ + -- + :obj:`labels`, ":math:`[(R,)]`", :obj:`int32`, \ + ":math:`[0, \#fg\_class - 1]`" + :obj:`masks`, ":math:`[(R, H, W)]`", :obj:`bool`, -- """ + output = {} sizes = [img.shape[1:] for img in imgs] x, scales = self.prepare(imgs) with chainer.using_config('train', False), chainer.no_backprop_mode(): - rois, roi_indices, head_locs, head_confs = self(x) - bboxes, labels, scores = self.head.decode( - rois, roi_indices, head_locs, head_confs, - scales, sizes, self.nms_thresh, self.score_thresh) - - bboxes = [cuda.to_cpu(bbox) for bbox in bboxes] - labels = [cuda.to_cpu(label) for label in labels] - scores = [cuda.to_cpu(score) for score in scores] - return bboxes, labels, scores + hs, rpn_rois, rpn_roi_indices = self(x) + if self._store_rpn_outputs: + rpn_rois_cpu = [ + chainer.backends.cuda.to_cpu(rpn_roi) / scale + for rpn_roi, scale in + zip(_flat_to_list(rpn_rois, rpn_roi_indices, len(imgs)), + scales)] + output.update({'rois': rpn_rois_cpu}) + + if self._run_bbox: + bbox_rois, bbox_roi_indices = self.bbox_head.distribute( + rpn_rois, rpn_roi_indices) + with chainer.using_config( + 'train', False), chainer.no_backprop_mode(): + head_locs, head_confs = self.bbox_head( + hs, bbox_rois, bbox_roi_indices) + bboxes, labels, scores = self.bbox_head.decode( + bbox_rois, bbox_roi_indices, head_locs, head_confs, + scales, sizes, self.nms_thresh, self.score_thresh) + bboxes_cpu = [ + chainer.backends.cuda.to_cpu(bbox) for bbox in bboxes] + labels_cpu = [ + chainer.backends.cuda.to_cpu(label) for label in labels] + scores_cpu = [cuda.to_cpu(score) for score in scores] + output.update({'bboxes': bboxes_cpu, 'labels': labels_cpu, + 'scores': scores_cpu}) + rescaled_bboxes = [bbox * scale + for scale, bbox in zip(scales, bboxes)] + if self._run_mask: + # Change bboxes to RoI and RoI indices format + mask_rois_before_reordering, mask_roi_indices_before_reordering =\ + _list_to_flat(rescaled_bboxes) + mask_rois, mask_roi_indices, order = self.mask_head.distribute( + mask_rois_before_reordering, + mask_roi_indices_before_reordering) + with chainer.using_config( + 'train', False), chainer.no_backprop_mode(): + segms = F.sigmoid( + self.mask_head(hs, mask_rois, mask_roi_indices)).data + # Put the order of proposals back to the one used by bbox head. 
+ segms = segms[order] + segms = _flat_to_list( + segms, mask_roi_indices_before_reordering, len(imgs)) + segms = [segm if segm is not None else + self.xp.zeros( + (0, self.mask_head.segm_size, + self.mask_head.segm_size), dtype=np.float32) + for segm in segms] + segms = [chainer.backends.cuda.to_cpu(segm) for segm in segms] + # Currently MaskHead only supports numpy inputs + masks_cpu = self.mask_head.decode( + segms, bboxes_cpu, labels_cpu, sizes) + output.update({'masks': masks_cpu}) + + if self._run_keypoint: + (point_rois_before_reordering, + point_roi_indices_before_reordering) = _list_to_flat( + rescaled_bboxes) + point_rois, point_roi_indices, order =\ + self.keypoint_head.distribute( + point_rois_before_reordering, + point_roi_indices_before_reordering) + with chainer.using_config( + 'train', False), chainer.no_backprop_mode(): + point_maps = self.keypoint_head( + hs, point_rois, point_roi_indices).data + point_maps = point_maps[order] + point_maps = _flat_to_list( + point_maps, point_roi_indices_before_reordering, len(imgs)) + point_maps = [point_map if point_map is not None else + self.xp.zeros( + (0, self.keypoint_head.n_point, + self.keypoint_head.point_map_size, + self.keypoint_head.point_map_size), + dtype=np.float32) + for point_map in point_maps] + point_maps = [ + chainer.backends.cuda.to_cpu(point_map) + for point_map in point_maps] + points_cpu, point_scores_cpu = self.keypoint_head.decode( + point_maps, bboxes_cpu) + output.update( + {'points': points_cpu, 'point_scores': point_scores_cpu}) + return tuple([output[key] for key in self._return_values]) def prepare(self, imgs): """Preprocess images. @@ -154,26 +262,44 @@ def prepare(self, imgs): scales that were caluclated in prepocessing. """ - scales = [] resized_imgs = [] for img in imgs: - _, H, W = img.shape - scale = self._min_size / min(H, W) - if scale * max(H, W) > self._max_size: - scale = self._max_size / max(H, W) - scales.append(scale) - H, W = int(H * scale), int(W * scale) - img = transforms.resize(img, (H, W)) + img, scale = scale_img( + img, self.min_size, self.max_size) img -= self.extractor.mean + scales.append(scale) resized_imgs.append(img) - - size = np.array([im.shape[1:] for im in resized_imgs]).max(axis=0) - size = (np.ceil(size / self._stride) * self._stride).astype(int) - x = np.zeros((len(imgs), 3, size[0], size[1]), dtype=np.float32) - for i, img in enumerate(resized_imgs): - _, H, W = img.shape - x[i, :, :H, :W] = img - + pad_size = np.array( + [im.shape[1:] for im in resized_imgs]).max(axis=0) + pad_size = ( + np.ceil(pad_size / self.stride) * self.stride).astype(int) + x = np.zeros( + (len(imgs), 3, pad_size[0], pad_size[1]), dtype=np.float32) + for i, im in enumerate(resized_imgs): + _, H, W = im.shape + x[i, :, :H, :W] = im x = self.xp.array(x) + return x, scales + + +def _list_to_flat(array_list): + xp = chainer.backends.cuda.get_array_module(array_list[0]) + + indices = xp.concatenate( + [i * xp.ones((len(array),), dtype=np.int32) for + i, array in enumerate(array_list)], axis=0) + flat = xp.concatenate(array_list, axis=0) + return flat, indices + + +def _flat_to_list(flat, indices, B): + array_list = [] + for i in range(B): + array = flat[indices == i] + if len(array) > 0: + array_list.append(array) + else: + array_list.append(None) + return array_list diff --git a/chainercv/links/model/fpn/faster_rcnn_fpn_resnet.py b/chainercv/links/model/fpn/faster_rcnn_fpn_resnet.py index 4b86e0cf7e..72d1b7bccb 100644 --- a/chainercv/links/model/fpn/faster_rcnn_fpn_resnet.py +++ 
b/chainercv/links/model/fpn/faster_rcnn_fpn_resnet.py
@@ -4,9 +4,11 @@
 import chainer.functions as F
 import chainer.links as L
 
+from chainercv.links.model.fpn.bbox_head import BboxHead
 from chainercv.links.model.fpn.faster_rcnn import FasterRCNN
 from chainercv.links.model.fpn.fpn import FPN
-from chainercv.links.model.fpn.head import Head
+from chainercv.links.model.fpn.keypoint_head import KeypointHead
+from chainercv.links.model.fpn.mask_head import MaskHead
 from chainercv.links.model.fpn.rpn import RPN
 from chainercv.links.model.resnet import ResNet101
 from chainercv.links.model.resnet import ResNet50
@@ -14,15 +16,44 @@
 
 
 class FasterRCNNFPNResNet(FasterRCNN):
-    """Base class for FasterRCNNFPNResNet50 and FasterRCNNFPNResNet101.
+    """Base class for Faster R-CNN with a ResNet backbone and FPN.
 
     A subclass of this class should have :obj:`_base` and :obj:`_models`.
+
+    Args:
+        n_fg_class (int): The number of classes excluding the background.
+        pretrained_model (string): The weight file to be loaded.
+            This can take :obj:`'coco'`, :obj:`'imagenet'`, `filepath`
+            or :obj:`None`. The default value is :obj:`None`.
+
+            * :obj:`'coco'`: Load weights trained on train split of \
+                MS COCO 2017. \
+                The weight file is downloaded and cached automatically. \
+                :obj:`n_fg_class` must be :obj:`80` or :obj:`None`.
+            * :obj:`'imagenet'`: Load weights of the ResNet backbone \
+                trained on ImageNet. \
+                The weight file is downloaded and cached automatically. \
+                This option initializes weights partially and the rest are \
+                initialized randomly. In this case, :obj:`n_fg_class` \
+                can be set to any number.
+            * `filepath`: A path of npz file. In this case, :obj:`n_fg_class` \
+                must be specified properly.
+            * :obj:`None`: Do not load weights.
+        return_values (list of strings): Determines the values
+            returned by :meth:`predict`.
+        min_size (int): A preprocessing parameter for :meth:`prepare`. Please \
+            refer to :meth:`prepare`.
+        max_size (int): A preprocessing parameter for :meth:`prepare`.
+
+    """
 
     def __init__(self, n_fg_class=None, pretrained_model=None,
+                 n_point=None,
+                 return_values=['bboxes', 'labels', 'scores'],
                  min_size=800, max_size=1333):
         param, path = utils.prepare_pretrained_model(
-            {'n_fg_class': n_fg_class}, pretrained_model, self._models)
+            {'n_fg_class': n_fg_class, 'n_point': n_point},
+            pretrained_model, self._models, {'n_point': None})
 
         base = self._base(n_class=1, arch='he')
         base.pick = ('res2', 'res3', 'res4', 'res5')
@@ -32,10 +63,17 @@ def __init__(self, n_fg_class=None, pretrained_model=None,
         extractor = FPN(
             base, len(base.pick), (1 / 4, 1 / 8, 1 / 16, 1 / 32, 1 / 64))
 
+        if param['n_point'] is not None:
+            keypoint_head = KeypointHead(param['n_point'], extractor.scales)
+        else:
+            keypoint_head = None
         super(FasterRCNNFPNResNet, self).__init__(
             extractor=extractor,
             rpn=RPN(extractor.scales),
-            head=Head(param['n_fg_class'] + 1, extractor.scales),
+            bbox_head=BboxHead(param['n_fg_class'] + 1, extractor.scales),
+            mask_head=MaskHead(param['n_fg_class'] + 1, extractor.scales),
+            keypoint_head=keypoint_head,
+            return_values=return_values,
             min_size=min_size, max_size=max_size
         )
 
@@ -44,41 +82,45 @@ def __init__(self, n_fg_class=None, pretrained_model=None,
             self.extractor.base,
             self._base(pretrained_model='imagenet', arch='he'))
         elif path:
-            chainer.serializers.load_npz(path, self)
+            chainer.serializers.load_npz(path, self, strict=False)
 
 
-class FasterRCNNFPNResNet50(FasterRCNNFPNResNet):
-    """Feature Pyramid Networks with ResNet-50. 
+class MaskRCNNFPNResNet(FasterRCNNFPNResNet): + """Mask R-CNN with a ResNet backbone and FPN. - This is a model of Feature Pyramid Networks [#]_. - This model uses :class:`~chainercv.links.ResNet50` as - its base feature extractor. + Please refer to :class:`~chainercv.links.model.fpn.FasterRCNNFPNResNet`. - .. [#] Tsung-Yi Lin et al. - Feature Pyramid Networks for Object Detection. CVPR 2017 + """ - Args: - n_fg_class (int): The number of classes excluding the background. - pretrained_model (string): The weight file to be loaded. - This can take :obj:`'coco'`, `filepath` or :obj:`None`. - The default value is :obj:`None`. + def __init__(self, n_fg_class=None, pretrained_model=None, + return_values=['masks', 'labels', 'scores'], + min_size=800, max_size=1333): + super(MaskRCNNFPNResNet, self).__init__( + n_fg_class, pretrained_model, None, return_values, + min_size, max_size) - * :obj:`'coco'`: Load weights trained on train split of \ - MS COCO 2017. \ - The weight file is downloaded and cached automatically. \ - :obj:`n_fg_class` must be :obj:`80` or :obj:`None`. - * :obj:`'imagenet'`: Load weights of ResNet-50 trained on \ - ImageNet. \ - The weight file is downloaded and cached automatically. \ - This option initializes weights partially and the rests are \ - initialized randomly. In this case, :obj:`n_fg_class` \ - can be set to any number. - * `filepath`: A path of npz file. In this case, :obj:`n_fg_class` \ - must be specified properly. - * :obj:`None`: Do not load weights. - min_size (int): A preprocessing paramter for :meth:`prepare`. Please \ - refer to :meth:`prepare`. - max_size (int): A preprocessing paramter for :meth:`prepare`. + +class KeypointRCNNFPNResNet(FasterRCNNFPNResNet): + """Keypoint R-CNN with a ResNet backbone and FPN. + + Please refer to :class:`~chainercv.links.model.fpn.FasterRCNNFPNResNet`. + + """ + + def __init__(self, n_fg_class=None, pretrained_model=None, + n_point=None, + return_values=['points', 'labels', 'scores', + 'point_scores', 'bboxes'], + min_size=800, max_size=1333): + super(KeypointRCNNFPNResNet, self).__init__( + n_fg_class, pretrained_model, n_point, + return_values, min_size, max_size) + + +class FasterRCNNFPNResNet50(FasterRCNNFPNResNet): + """Faster R-CNN with ResNet-50 and FPN. + + Please refer to :class:`~chainercv.links.model.fpn.FasterRCNNFPNResNet`. """ @@ -87,44 +129,52 @@ class FasterRCNNFPNResNet50(FasterRCNNFPNResNet): 'coco': { 'param': {'n_fg_class': 80}, 'url': 'https://chainercv-models.preferred.jp/' - 'faster_rcnn_fpn_resnet50_coco_trained_2018_12_13.npz', + 'faster_rcnn_fpn_resnet50_coco_trained_2019_03_15.npz', 'cv2': True }, } class FasterRCNNFPNResNet101(FasterRCNNFPNResNet): - """Feature Pyramid Networks with ResNet-101. + """Faster R-CNN with ResNet-101 and FPN. - This is a model of Feature Pyramid Networks [#]_. - This model uses :class:`~chainercv.links.ResNet101` as - its base feature extractor. + Please refer to :class:`~chainercv.links.model.fpn.FasterRCNNFPNResNet`. - .. [#] Tsung-Yi Lin et al. - Feature Pyramid Networks for Object Detection. CVPR 2017 + """ - Args: - n_fg_class (int): The number of classes excluding the background. - pretrained_model (string): The weight file to be loaded. - This can take :obj:`'coco'`, `filepath` or :obj:`None`. - The default value is :obj:`None`. 
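
To make the `return_values` mechanism above concrete, here is a usage sketch of the two new wrappers; it assumes the COCO-pretrained weights referenced in the `_models` dictionaries below are available for download:

```python
import numpy as np
from chainercv.links import KeypointRCNNFPNResNet50
from chainercv.links import MaskRCNNFPNResNet50

img = np.zeros((3, 480, 640), dtype=np.float32)  # dummy CHW, RGB image

# MaskRCNNFPNResNet defaults to return_values=['masks', 'labels', 'scores'].
model = MaskRCNNFPNResNet50(n_fg_class=80, pretrained_model='coco')
masks, labels, scores = model.predict([img])

# KeypointRCNNFPNResNet defaults to
# ['points', 'labels', 'scores', 'point_scores', 'bboxes'];
# n_fg_class and n_point are filled in from the pretrained model spec.
kp_model = KeypointRCNNFPNResNet50(pretrained_model='coco')
points, labels, scores, point_scores, bboxes = kp_model.predict([img])
```
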
+ _base = ResNet101 + _models = { + 'coco': { + 'param': {'n_fg_class': 80}, + 'url': 'https://chainercv-models.preferred.jp/' + 'faster_rcnn_fpn_resnet101_coco_trained_2019_03_15.npz', + 'cv2': True + }, + } - * :obj:`'coco'`: Load weights trained on train split of \ - MS COCO 2017. \ - The weight file is downloaded and cached automatically. \ - :obj:`n_fg_class` must be :obj:`80` or :obj:`None`. - * :obj:`'imagenet'`: Load weights of ResNet-101 trained on \ - ImageNet. \ - The weight file is downloaded and cached automatically. \ - This option initializes weights partially and the rests are \ - initialized randomly. In this case, :obj:`n_fg_class` \ - can be set to any number. - * `filepath`: A path of npz file. In this case, :obj:`n_fg_class` \ - must be specified properly. - * :obj:`None`: Do not load weights. - min_size (int): A preprocessing paramter for :meth:`prepare`. Please \ - refer to :meth:`prepare`. - max_size (int): A preprocessing paramter for :meth:`prepare`. + +class MaskRCNNFPNResNet50(MaskRCNNFPNResNet): + """Mask R-CNN with ResNet-50 and FPN. + + Please refer to :class:`~chainercv.links.model.fpn.FasterRCNNFPNResNet`. + + """ + + _base = ResNet50 + _models = { + 'coco': { + 'param': {'n_fg_class': 80}, + 'url': 'https://chainercv-models.preferred.jp/' + 'faster_rcnn_fpn_resnet50_mask_coco_trained_2019_03_15.npz', + 'cv2': True + }, + } + + +class MaskRCNNFPNResNet101(MaskRCNNFPNResNet): + """Mask R-CNN with ResNet-101 and FPN. + + Please refer to :class:`~chainercv.links.model.fpn.FasterRCNNFPNResNet`. """ @@ -132,8 +182,42 @@ class FasterRCNNFPNResNet101(FasterRCNNFPNResNet): _models = { 'coco': { 'param': {'n_fg_class': 80}, + 'url': '', + 'cv2': True + }, + } + + +class KeypointRCNNFPNResNet50(KeypointRCNNFPNResNet): + """Keypoint R-CNN with ResNet-50 and FPN. + + Please refer to :class:`~chainercv.links.model.fpn.FasterRCNNFPNResNet`. + + """ + + _base = ResNet50 + _models = { + 'coco': { + 'param': {'n_fg_class': 1, 'n_point': 17}, 'url': 'https://chainercv-models.preferred.jp/' - 'faster_rcnn_fpn_resnet101_coco_trained_2018_12_13.npz', + 'faster_rcnn_fpn_resnet50_keypoint_coco_converted_2019_03_15.npz', + 'cv2': True + }, + } + + +class KeypointRCNNFPNResNet101(KeypointRCNNFPNResNet): + """Keypoint R-CNN with ResNet-101 and FPN. + + Please refer to :class:`~chainercv.links.model.fpn.FasterRCNNFPNResNet`. 
+
+    """
+
+    _base = ResNet101
+    _models = {
+        'coco': {
+            'param': {'n_fg_class': 1, 'n_point': 17},
+            'url': '',
+            'cv2': True
+        },
+    }
diff --git a/chainercv/links/model/fpn/keypoint_head.py b/chainercv/links/model/fpn/keypoint_head.py
new file mode 100644
index 0000000000..c0dd00679d
--- /dev/null
+++ b/chainercv/links/model/fpn/keypoint_head.py
@@ -0,0 +1,229 @@
+from __future__ import division
+
+import numpy as np
+
+import cv2
+
+import chainer
+import chainer.functions as F
+import chainer.links as L
+from chainer.backends import cuda
+from chainer.initializers import HeNormal
+
+from chainercv.links import Conv2DActiv
+from chainercv.utils.bbox.bbox_iou import bbox_iou
+
+from chainercv.links.model.fpn.keypoint_utils import point_to_roi_points
+from chainercv.links.model.fpn.keypoint_utils import within_bbox
+
+
+# make a bilinear interpolation kernel
+# credit @longjon
+def _upsample_filt(size):
+    factor = (size + 1) // 2
+    if size % 2 == 1:
+        center = factor - 1
+    else:
+        center = factor - 0.5
+    og = np.ogrid[:size, :size]
+    return (1 - abs(og[0] - center) / factor) * \
+        (1 - abs(og[1] - center) / factor)
+
+
+class KeypointHead(chainer.Chain):
+
+    _canonical_scale = 224
+    _roi_size = 14
+    _roi_sample_ratio = 2
+    point_map_size = 56
+
+    def __init__(self, n_point, scales):
+        super(KeypointHead, self).__init__()
+
+        initialW = HeNormal(1, fan_option='fan_out')
+        with self.init_scope():
+            self.conv1 = Conv2DActiv(512, 3, pad=1, initialW=initialW)
+            self.conv2 = Conv2DActiv(512, 3, pad=1, initialW=initialW)
+            self.conv3 = Conv2DActiv(512, 3, pad=1, initialW=initialW)
+            self.conv4 = Conv2DActiv(512, 3, pad=1, initialW=initialW)
+            self.conv5 = Conv2DActiv(512, 3, pad=1, initialW=initialW)
+            self.conv6 = Conv2DActiv(512, 3, pad=1, initialW=initialW)
+            self.conv7 = Conv2DActiv(512, 3, pad=1, initialW=initialW)
+            self.conv8 = Conv2DActiv(512, 3, pad=1, initialW=initialW)
+            self.point = L.Deconvolution2D(
+                n_point, 4, pad=1, stride=2, initialW=initialW)
+            # Do not update the weight of this link
+            self.upsample = L.Deconvolution2D(
+                n_point, n_point, 4, pad=1, stride=2, nobias=True)
+        self.upsample.W.data[:] = 0
+        self.upsample.W.data[np.arange(n_point), np.arange(n_point)] = _upsample_filt(4)
+
+        self._scales = scales
+        self.n_point = n_point
+
+    def __call__(self, hs, rois, roi_indices):
+        pooled_hs = []
+        for l, h in enumerate(hs):
+            if len(rois[l]) == 0:
+                continue
+
+            pooled_hs.append(F.roi_average_align_2d(
+                h, rois[l], roi_indices[l],
+                self._roi_size,
+                self._scales[l], self._roi_sample_ratio))
+
+        if len(pooled_hs) == 0:
+            return chainer.Variable(
+                self.xp.empty(
+                    (0, self.n_point, self.point_map_size, self.point_map_size),
+                    dtype=np.float32))
+
+        h = F.concat(pooled_hs, axis=0)
+        h = self.conv1(h)
+        h = self.conv2(h)
+        h = self.conv3(h)
+        h = self.conv4(h)
+        h = self.conv5(h)
+        h = self.conv6(h)
+        h = self.conv7(h)
+        h = self.conv8(h)
+        h = self.point(h)
+        return self.upsample(h)
+
+    def distribute(self, rois, roi_indices):
+        # Almost the same as MaskHead.distribute
+        size = self.xp.sqrt(self.xp.prod(rois[:, 2:] + 1 - rois[:, :2], axis=1))
+        level = self.xp.floor(self.xp.log2(
+            size / self._canonical_scale + 1e-6)).astype(np.int32)
+        # skip last level
+        level = self.xp.clip(
+            level + len(self._scales) // 2, 0, len(self._scales) - 2)
+
+        masks = [level == l for l in range(len(self._scales))]
+        rois = [rois[mask] for mask in masks]
+        roi_indices = [roi_indices[mask] for mask in masks]
+        order = self.xp.argsort(
self.xp.concatenate([self.xp.where(mask)[0] for mask in masks])) + return rois, roi_indices, order + + def decode(self, point_maps, bboxes): + points = [] + point_scores = [] + for bbox, point_map in zip(bboxes, point_maps): + point = np.zeros((len(bbox), self.n_point, 2), dtype=np.float32) + point_score = np.zeros((len(bbox), self.n_point), dtype=np.float32) + + hs = bbox[:, 2] - bbox[:, 0] + ws = bbox[:, 3] - bbox[:, 1] + h_ceils = np.ceil(np.maximum(hs, 1)) + w_ceils = np.ceil(np.maximum(ws, 1)) + h_corrections = hs / h_ceils + w_corrections = ws / w_ceils + for i, (bb, point_m) in enumerate(zip(bbox, point_map)): + point_m = cv2.resize( + point_m.transpose((1, 2, 0)), + (w_ceils[i], h_ceils[i]), + interpolation=cv2.INTER_CUBIC).transpose( + (2, 0, 1)) + _, H, W = point_m.shape + for k in range(self.n_point): + pos = point_m[k].argmax() + x_int = pos % W + y_int = (pos - x_int) // W + + y = (y_int + 0.5) * h_corrections[i] + x = (x_int + 0.5) * w_corrections[i] + point[i, k, 0] = y + bb[0] + point[i, k, 1] = x + bb[1] + point_score[i, k] = point_m[k, y_int, x_int] + points.append(point) + point_scores.append(point_score) + return points, point_scores + + +def keypoint_loss_pre(rois, roi_indices, gt_points, gt_visibles, + gt_bboxes, gt_head_labels, point_map_size): + batchsize_per_image = 512 + fg_ratio = 0.25 + + _, n_point, _ = gt_points[0].shape + + xp = cuda.get_array_module(*rois) + + n_level = len(rois) + + roi_levels = xp.hstack( + xp.array((l,) * len(rois[l])) for l in range(n_level)).astype(np.int32) + rois = xp.vstack(rois).astype(np.float32) + roi_indices = xp.hstack(roi_indices).astype(np.int32) + gt_head_labels = xp.hstack(gt_head_labels) + + # Ignore all negative samples + index = (gt_head_labels > 0).nonzero()[0] + roi_levels = roi_levels[index] + rois = rois[index] + roi_indices = roi_indices[index] + gt_head_labels = gt_head_labels[index] + + gt_head_points = xp.empty( + (len(rois), n_point, 2), dtype=np.float32) + gt_head_visibles = xp.empty( + (len(rois), n_point), dtype=np.bool) + for i in np.unique(cuda.to_cpu(roi_indices)): + gt_point = gt_points[i] + gt_visible = gt_visibles[i] + gt_bbox = gt_bboxes[i] + + index = (roi_indices == i).nonzero()[0] + gt_head_label = gt_head_labels[index] + roi = rois[index] + + iou = bbox_iou(roi, gt_bbox) + gt_index = iou.argmax(axis=1) + gt_head_point, gt_head_visible = point_to_roi_points( + gt_point[gt_index], gt_visible[gt_index], + roi, point_map_size) + gt_head_points[index] = xp.array(gt_head_point) + gt_head_visibles[index] = xp.array(gt_head_visible) + + # Ignore RoIs whose closest bounding box does not contain + # any valid keypoints. 
+ valid_point = within_bbox(gt_point[gt_index], roi) + valid_point = xp.logical_and(valid_point, gt_visible[gt_index]) + visible_roi = valid_point.sum(axis=1) > 0 + gt_head_label[xp.logical_not(visible_roi)] = -1 + gt_head_labels[index] = gt_head_label + + is_sampled = (gt_head_labels > 0).nonzero()[0] + rois = rois[is_sampled] + roi_indices = roi_indices[is_sampled] + roi_levels = roi_levels[is_sampled] + gt_head_points = gt_head_points[is_sampled] + gt_head_visibles = gt_head_visibles[is_sampled] + + flag_masks = [roi_levels == l for l in range(n_level)] + rois = [rois[m] for m in flag_masks] + roi_indices = [roi_indices[m] for m in flag_masks] + gt_head_points = [gt_head_points[m] for m in flag_masks] + gt_head_visibles = [gt_head_visibles[m] for m in flag_masks] + return rois, roi_indices, gt_head_points, gt_head_visibles + + +def keypoint_loss_post( + point_maps, point_roi_indices, gt_head_points, + gt_head_visibles, batchsize): + xp = cuda.get_array_module(point_maps.array) + + point_roi_indices = xp.hstack(point_roi_indices).astype(np.int32) + gt_head_points = xp.vstack(gt_head_points).astype(np.int32) + gt_head_visibles = xp.vstack(gt_head_visibles).astype(np.bool) + + B, K, H, W = point_maps.shape + point_maps = point_maps.reshape((B * K, H * W)) + spatial_labels = gt_head_points[:, :, 0] * W + gt_head_points[:, :, 1] + spatial_labels = spatial_labels.reshape((B * K,)) + spatial_labels[xp.logical_not(gt_head_visibles.reshape((B * K,)))] = -1 + # Remember that the loss is normalized by the total number of + # visible keypoints. + keypoint_loss = F.softmax_cross_entropy(point_maps, spatial_labels) + return keypoint_loss diff --git a/chainercv/links/model/fpn/keypoint_utils.py b/chainercv/links/model/fpn/keypoint_utils.py new file mode 100644 index 0000000000..adc5070528 --- /dev/null +++ b/chainercv/links/model/fpn/keypoint_utils.py @@ -0,0 +1,52 @@ +from __future__ import division + +import numpy as np + +import chainer + + +def point_to_roi_points( + point, visible, bbox, point_map_size): + xp = chainer.backends.cuda.get_array_module(point) + + R, K, _ = point.shape + + roi_point = xp.zeros((len(bbox), K, 2)) + roi_visible = xp.zeros((len(bbox), K), dtype=np.bool) + + offset_y = bbox[:, 0] + offset_x = bbox[:, 1] + scale_y = point_map_size / (bbox[:, 2] - bbox[:, 0]) + scale_x = point_map_size / (bbox[:, 3] - bbox[:, 1]) + + for k in range(K): + y_boundary_index = xp.where(point[:, k, 0] == bbox[:, 2])[0] + x_boundary_index = xp.where(point[:, k, 1] == bbox[:, 3])[0] + + ys = (point[:, k, 0] - offset_y) * scale_y + ys = xp.floor(ys) + if len(y_boundary_index) > 0: + ys[y_boundary_index] = point_map_size - 1 + xs = (point[:, k, 1] - offset_x) * scale_x + xs = xp.floor(xs) + if len(x_boundary_index) > 0: + xs[x_boundary_index] = point_map_size - 1 + + valid = xp.logical_and( + xp.logical_and( + xp.logical_and(ys >= 0, xs >= 0), + xp.logical_and(ys < point_map_size, xs < point_map_size)), + visible[:, k]) + + roi_point[:, k, 0] = ys + roi_point[:, k, 1] = xs + roi_visible[:, k] = valid + return roi_point, roi_visible + + +def within_bbox(point, bbox): + y_within = (point[:, :, 0] >= bbox[:, 0][:, None]) & ( + point[:, :, 0] <= bbox[:, 2][:, None]) + x_within = (point[:, :, 1] >= bbox[:, 1][:, None]) & ( + point[:, :, 1] <= bbox[:, 3][:, None]) + return y_within & x_within diff --git a/chainercv/links/model/fpn/mask_head.py b/chainercv/links/model/fpn/mask_head.py new file mode 100644 index 0000000000..602713838b --- /dev/null +++ b/chainercv/links/model/fpn/mask_head.py @@ -0,0 
+1,257 @@ +from __future__ import division + +import numpy as np + +import chainer +from chainer.backends import cuda +import chainer.functions as F +from chainer.initializers import HeNormal +import chainer.links as L + +from chainercv.links import Conv2DActiv +from chainercv.utils.bbox.bbox_iou import bbox_iou + +from chainercv.links.model.fpn.mask_utils import mask_to_segm +from chainercv.links.model.fpn.mask_utils import segm_to_mask + + +class MaskHead(chainer.Chain): + + """Mask Head network of Mask R-CNN. + + Args: + n_class (int): The number of classes including background. + scales (tuple of floats): The scales of feature maps. + + """ + + _canonical_level = 2 + _canonical_scale = 224 + _roi_size = 14 + _roi_sample_ratio = 2 + segm_size = _roi_size * 2 + + def __init__(self, n_class, scales): + super(MaskHead, self).__init__() + + initialW = HeNormal(1, fan_option='fan_out') + with self.init_scope(): + self.conv1 = Conv2DActiv(256, 3, pad=1, initialW=initialW) + self.conv2 = Conv2DActiv(256, 3, pad=1, initialW=initialW) + self.conv3 = Conv2DActiv(256, 3, pad=1, initialW=initialW) + self.conv4 = Conv2DActiv(256, 3, pad=1, initialW=initialW) + self.conv5 = L.Deconvolution2D( + 256, 2, pad=0, stride=2, initialW=initialW) + self.seg = L.Convolution2D(n_class, 1, pad=0, initialW=initialW) + + self._n_class = n_class + self._scales = scales + + def __call__(self, hs, rois, roi_indices): + pooled_hs = [] + for l, h in enumerate(hs): + if len(rois[l]) == 0: + continue + + pooled_hs.append(F.roi_average_align_2d( + h, rois[l], roi_indices[l], + self._roi_size, + self._scales[l], self._roi_sample_ratio)) + + if len(pooled_hs) == 0: + out_size = self.segm_size + segs = chainer.Variable( + self.xp.empty((0, self._n_class, out_size, out_size), + dtype=np.float32)) + return segs + + h = F.concat(pooled_hs, axis=0) + h = self.conv1(h) + h = self.conv2(h) + h = self.conv3(h) + h = self.conv4(h) + h = F.relu(self.conv5(h)) + return self.seg(h) + + def distribute(self, rois, roi_indices): + """Assigns feature levels to Rois based on their size. + + Args: + rois (array): An array of shape :math:`(R, 4)`, \ + where :math:`R` is the total number of RoIs in the given batch. + roi_indices (array): An array of shape :math:`(R,)`. + + Returns: + two lists and one array: + :obj:`out_rois`, :obj:`out_roi_indices` and :obj:`order`. + + * **out_rois**: A list of arrays of shape :math:`(R_l, 4)`, \ + where :math:`R_l` is the number of RoIs in the :math:`l`-th \ + feature map. + * **out_roi_indices** : A list of arrays of shape :math:`(R_l,)`. + * **order**: A correspondence between the output and the input. \ + The relationship below is satisfied. + + .. code:: python + + xp.concatenate(out_rois, axis=0)[order[i]] == rois[i] + + """ + + size = self.xp.sqrt(self.xp.prod(rois[:, 2:] - rois[:, :2], axis=1)) + level = self.xp.floor(self.xp.log2( + size / self._canonical_scale + 1e-6)).astype(np.int32) + # skip last level + level = self.xp.clip( + level + self._canonical_level, 0, len(self._scales) - 2) + + masks = [level == l for l in range(len(self._scales))] + out_rois = [rois[mask] for mask in masks] + out_roi_indices = [roi_indices[mask] for mask in masks] + order = self.xp.argsort( + self.xp.concatenate([self.xp.where(mask)[0] for mask in masks])) + return out_rois, out_roi_indices, order + + def decode(self, segms, bboxes, labels, sizes): + """Decodes back to masks. + + Args: + segms (iterable of arrays): An iterable of arrays of + shape :math:`(R_n, n\_class, M, M)`. 
+ bboxes (iterable of arrays): An iterable of arrays of + shape :math:`(R_n, 4)`. + labels (iterable of arrays): An iterable of arrays of + shape :math:`(R_n,)`. + sizes (list of tuples of two ints): A list of + :math:`(H_n, W_n)`, where :math:`H_n` and :math:`W_n` + are height and width of the :math:`n`-th image. + + Returns: + list of arrays: + This list contains instance segmentation for each image + in the batch. + More precisely, this is a list of boolean arrays of shape + :math:`(R'_n, H_n, W_n)`, where :math:`R'_n` is the number of + bounding boxes in the :math:`n`-th image. + """ + + xp = chainer.backends.cuda.get_array_module(*segms) + if xp != np: + raise ValueError( + 'MaskHead.decode only supports numpy inputs for now.') + masks = [] + for bbox, segm, label, size in zip( + bboxes, segms, labels, sizes): + if len(segm) > 0: + masks.append( + segm_to_mask(segm[np.arange(len(label)), label + 1], + bbox, size)) + else: + masks.append(np.zeros((0,) + size, dtype=np.bool)) + return masks + + +def mask_loss_pre(rois, roi_indices, gt_masks, gt_bboxes, + gt_head_labels, segm_size): + """Loss function for Mask Head (pre). + + This function processes RoIs for :func:`mask_loss_post` by + selecting RoIs for mask loss calculation and + preparing ground truth network output. + + Args: + rois (iterable of arrays): An iterable of arrays of + shape :math:`(R_l, 4)`, where :math:`R_l` is the number + of RoIs in the :math:`l`-th feature map. + roi_indices (iterable of arrays): An iterable of arrays of + shape :math:`(R_l,)`. + gt_masks (iterable of arrays): An iterable of arrays whose shape is + :math:`(R_n, H, W)`, where :math:`R_n` is the number of + ground truth objects. + gt_head_labels (iterable of arrays): An iterable of arrays of + shape :math:`(R_l,)`. This is a collection of ground-truth + labels assigned to :obj:`rois` during bounding box localization + stage. The range of value is :math:`(0, n\_class - 1)`. + segm_size (int): Size of the ground truth network output. + + Returns: + tuple of four lists: + :obj:`mask_rois`, :obj:`mask_roi_indices`, + :obj:`gt_segms`, and :obj:`gt_mask_labels`. + + * **mask_rois**: A list of arrays of shape :math:`(R'_l, 4)`, \ + where :math:`R'_l` is the number of RoIs in the :math:`l`-th \ + feature map. + * **mask_roi_indices**: A list of arrays of shape :math:`(R'_l,)`. + * **gt_segms**: A list of arrays of shape :math:`(R'_l, M, M)`. \ + :math:`M` is the argument :obj:`segm_size`. + * **gt_mask_labels**: A list of arrays of shape :math:`(R'_l,)` \ + indicating the classes of ground truth.
+ """ + + xp = cuda.get_array_module(*rois) + + n_level = len(rois) + + roi_levels = xp.hstack( + xp.array((l,) * len(rois[l])) for l in range(n_level)).astype(np.int32) + rois = xp.vstack(rois).astype(np.float32) + roi_indices = xp.hstack(roi_indices).astype(np.int32) + gt_head_labels = xp.hstack(gt_head_labels) + + index = (gt_head_labels > 0).nonzero()[0] + mask_roi_levels = roi_levels[index] + mask_rois = rois[index] + mask_roi_indices = roi_indices[index] + gt_mask_labels = gt_head_labels[index] + + gt_segms = xp.empty( + (len(mask_rois), segm_size, segm_size), dtype=np.float32) + for i in np.unique(cuda.to_cpu(mask_roi_indices)): + gt_mask = gt_masks[i] + gt_bbox = gt_bboxes[i] + + index = (mask_roi_indices == i).nonzero()[0] + mask_roi = mask_rois[index] + iou = bbox_iou(mask_roi, gt_bbox) + gt_index = iou.argmax(axis=1) + gt_segms[index] = xp.array( + mask_to_segm(gt_mask, mask_roi, segm_size, gt_index)) + + flag_masks = [mask_roi_levels == l for l in range(n_level)] + mask_rois = [mask_rois[m] for m in flag_masks] + mask_roi_indices = [mask_roi_indices[m] for m in flag_masks] + gt_segms = [gt_segms[m] for m in flag_masks] + gt_mask_labels = [gt_mask_labels[m] for m in flag_masks] + return mask_rois, mask_roi_indices, gt_segms, gt_mask_labels + + +def mask_loss_post(segms, mask_roi_indices, gt_segms, gt_mask_labels, + batchsize): + """Loss function for Mask Head (post). + + Args: + segms (array): An array whose shape is :math:`(R, n\_class, M, M)`, + where :math:`R` is the total number of RoIs in the given batch. + mask_roi_indices (array): A list of arrays returned by + :func:`mask_loss_pre`. + gt_segms (list of arrays): A list of arrays returned by + :func:`mask_loss_pre`. + gt_mask_labels (list of arrays): A list of arrays returned by + :func:`mask_loss_pre`. + batchsize (int): The size of batch. + + Returns: + chainer.Variable: + Mask loss. + """ + xp = cuda.get_array_module(segms.array) + + mask_roi_indices = xp.hstack(mask_roi_indices).astype(np.int32) + gt_segms = xp.vstack(gt_segms) + gt_mask_labels = xp.hstack(gt_mask_labels).astype(np.int32) + + mask_loss = F.sigmoid_cross_entropy( + segms[np.arange(len(gt_mask_labels)), gt_mask_labels], + gt_segms.astype(np.int32)) + return mask_loss diff --git a/chainercv/links/model/fpn/mask_utils.py b/chainercv/links/model/fpn/mask_utils.py new file mode 100644 index 0000000000..5c28e20232 --- /dev/null +++ b/chainercv/links/model/fpn/mask_utils.py @@ -0,0 +1,157 @@ +from __future__ import division + +import numpy as np + +import chainer + +from chainercv import transforms + + +def mask_to_segm(mask, bbox, segm_size, index=None, pad=1): + """Crop and resize mask. + + Args: + mask (~numpy.ndarray): See below. + bbox (~numpy.ndarray): See below. + segm_size (int): The size of segm :math:`S`. + index (~numpy.ndarray): See below. :math:`R = N` when + :obj:`index` is :obj:`None`. + pad (int): The amount of padding used for bbox. + + Returns: + ~numpy.ndarray: See below. + + .. 
csv-table:: + :header: name, shape, dtype, format + + :obj:`mask`, ":math:`(N, H, W)`", :obj:`bool`, -- + :obj:`bbox`, ":math:`(R, 4)`", :obj:`float32`, \ + ":math:`(y_{min}, x_{min}, y_{max}, x_{max})`" + :obj:`index` (optional), ":math:`(R,)`", :obj:`int32`, -- + :obj:`segms` (output), ":math:`(R, S, S)`", :obj:`float32`, \ + ":math:`[0, 1]`" + + """ + _, H, W = mask.shape + bbox = chainer.backends.cuda.to_cpu(bbox) + padded_segm_size = segm_size + pad * 2 + expand_scale = padded_segm_size / segm_size + bbox = _integerize_bbox(_expand_boxes(bbox, expand_scale)) + + segm = [] + if index is None: + index = np.arange(len(bbox)) + else: + index = chainer.backends.cuda.to_cpu(index) + + for i, bb in zip(index, bbox): + y_min = max(bb[0], 0) + x_min = max(bb[1], 0) + y_max = max(min(bb[2], H), 0) + x_max = max(min(bb[3], W), 0) + if y_max - y_min == 0 or x_max - x_min == 0: + segm.append(np.zeros((segm_size, segm_size), dtype=np.float32)) + continue + + bb_height = bb[2] - bb[0] + bb_width = bb[3] - bb[1] + cropped_m = np.zeros((bb_height, bb_width), dtype=np.bool) + + y_offset = y_min - bb[0] + x_offset = x_min - bb[1] + cropped_m[y_offset:y_offset + y_max - y_min, + x_offset:x_offset + x_max - x_min] =\ + chainer.backends.cuda.to_cpu(mask[i, y_min:y_max, x_min:x_max]) + + sgm = transforms.resize( + cropped_m[None].astype(np.float32), + (padded_segm_size, padded_segm_size))[0].astype(np.int32) + segm.append(sgm[pad:-pad, pad:-pad]) + + return np.array(segm, dtype=np.float32) + + +def segm_to_mask(segm, bbox, size, pad=1): + """Recover mask from cropped and resized mask. + + Args: + segm (~numpy.ndarray): See below. + bbox (~numpy.ndarray): See below. + size (tuple): This is a tuple of length 2. Its elements are + ordered as (height, width). + pad (int): The amount of padding used for bbox. + + Returns: + ~numpy.ndarray: See below. + + .. csv-table:: + :header: name, shape, dtype, format + + :obj:`segm`, ":math:`(R, S, S)`", :obj:`float32`, -- + :obj:`bbox`, ":math:`(R, 4)`", :obj:`float32`, \ + ":math:`(y_{min}, x_{min}, y_{max}, x_{max})`" + :obj:`mask` (output), ":math:`(R, H, W)`", :obj:`bool`, -- + + """ + H, W = size + _, segm_size, _ = segm.shape + + mask = np.zeros((len(bbox), H, W), dtype=np.bool) + + # To work around an issue with cv2.resize (it seems to automatically + # pad with repeated border values), we manually zero-pad the masks by 1 + # pixel prior to resizing back to the original image resolution. + # This prevents "top hat" artifacts. We therefore need to expand + # the reference boxes by an appropriate factor. 
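+ # With pad = 1 and an S x S patch, that factor is (S + 2) / S, + # which is exactly the expand_scale computed below.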
+ expand_scale = (segm_size + pad * 2) / segm_size + padded_mask = np.zeros( + (segm_size + pad * 2, segm_size + pad * 2), dtype=np.float32) + + bbox = _integerize_bbox(_expand_boxes(bbox, expand_scale)) + for i, (bb, sgm) in enumerate(zip(bbox, segm)): + padded_mask[pad:-pad, pad:-pad] = sgm + + bb_height = bb[2] - bb[0] + bb_width = bb[3] - bb[1] + if bb_height == 0 or bb_width == 0: + continue + + crop_mask = transforms.resize( + padded_mask[None], (bb_height, bb_width))[0] + crop_mask = crop_mask > 0.5 + + y_min = max(bb[0], 0) + x_min = max(bb[1], 0) + y_max = max(min(bb[2], H), 0) + x_max = max(min(bb[3], W), 0) + y_offset = y_min - bb[0] + x_offset = x_min - bb[1] + mask[i, y_min:y_max, x_min:x_max] = crop_mask[ + y_offset:y_offset + y_max - y_min, + x_offset:x_offset + x_max - x_min] + return mask + + +def _integerize_bbox(bbox): + return np.round(bbox).astype(np.int32) + + +def _expand_boxes(bbox, scale): + """Expand an array of boxes by a given scale.""" + xp = chainer.backends.cuda.get_array_module(bbox) + + h_half = (bbox[:, 2] - bbox[:, 0]) * .5 + w_half = (bbox[:, 3] - bbox[:, 1]) * .5 + y_c = (bbox[:, 2] + bbox[:, 0]) * .5 + x_c = (bbox[:, 3] + bbox[:, 1]) * .5 + + h_half *= scale + w_half *= scale + + expanded_bbox = xp.zeros(bbox.shape) + expanded_bbox[:, 0] = y_c - h_half + expanded_bbox[:, 1] = x_c - w_half + expanded_bbox[:, 2] = y_c + h_half + expanded_bbox[:, 3] = x_c + w_half + + return expanded_bbox diff --git a/chainercv/links/model/fpn/misc.py b/chainercv/links/model/fpn/misc.py index c699e3d2f6..7863255ab9 100644 --- a/chainercv/links/model/fpn/misc.py +++ b/chainercv/links/model/fpn/misc.py @@ -5,6 +5,8 @@ from chainer.backends import cuda import chainer.functions as F +from chainercv import transforms + exp_clip = np.log(1000 / 16) @@ -13,6 +15,23 @@ def smooth_l1(x, t, beta): return F.huber_loss(x, t, beta, reduce='no') / beta +def balanced_sampling(label, n_sample, fg_ratio): + label = label.copy() + + xp = cuda.get_array_module(label) + + fg_index = xp.where(label > 0)[0] + n_fg = int(n_sample * fg_ratio) + if len(fg_index) > n_fg: + label[choice(fg_index, size=len(fg_index) - n_fg)] = -1 + + bg_index = xp.where(label == 0)[0] + n_bg = n_sample - int((label > 0).sum()) + if len(bg_index) > n_bg: + label[choice(bg_index, size=len(bg_index) - n_bg)] = -1 + return label + + # to avoid out of memory def argsort(x): xp = cuda.get_array_module(x) @@ -31,3 +50,14 @@ def choice(x, size): return y else: return cuda.to_gpu(y) + + +def scale_img(img, min_size, max_size): + """Scale an image so that its shorter side equals min_size, + shrinking further if the longer side would exceed max_size.""" + _, H, W = img.shape + scale = min_size / min(H, W) + if scale * max(H, W) > max_size: + scale = max_size / max(H, W) + H, W = int(H * scale), int(W * scale) + img = transforms.resize(img, (H, W)) + return img, scale
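The pair `mask_to_segm` / `segm_to_mask` invert each other up to resampling error, which is easy to sanity-check. A minimal round-trip sketch, assuming the patched ChainerCV is importable and that both helpers are re-exported from `chainercv.links.model.fpn` as the updated docs later in this patch suggest:

```python
import numpy as np
from chainercv.links.model.fpn import mask_to_segm, segm_to_mask

H, W = 60, 80
mask = np.zeros((1, H, W), dtype=bool)
mask[0, 10:40, 20:50] = True
# (y_min, x_min, y_max, x_max), matching the mask above.
bbox = np.array([[10, 20, 40, 50]], dtype=np.float32)

segm = mask_to_segm(mask, bbox, segm_size=28)  # (1, 28, 28), float32
recon = segm_to_mask(segm, bbox, (H, W))       # (1, 60, 80), bool
# Up to interpolation at the box border, recon should agree with mask.
print((recon == mask).mean())
```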
diff --git a/chainercv/transforms/point/flip_point.py b/chainercv/transforms/point/flip_point.py index 104929e5bf..36e279ab7d 100644 --- a/chainercv/transforms/point/flip_point.py +++ b/chainercv/transforms/point/flip_point.py @@ -1,12 +1,11 @@ +import numpy as np + + def flip_point(point, size, y_flip=False, x_flip=False): """Modify points according to image flips. Args: - point (~numpy.ndarray): Points in the image. - The shape of this array is :math:`(P, 2)`. :math:`P` is the number - of points in the image. - The last dimension is composed of :math:`y` and :math:`x` - coordinates of the points. + point (~numpy.ndarray or list of arrays): See the table below. size (tuple): A tuple of length 2. The height and the width of the image, which is associated with the points. y_flip (bool): Modify points according to a vertical flip of @@ -14,15 +13,31 @@ def flip_point(point, size, y_flip=False, x_flip=False): x_flip (bool): Modify keypoints according to a horizontal flip of an image. + .. csv-table:: + :header: name, shape, dtype, format + + :obj:`point`, ":math:`[(K, 2)]` or :math:`(R, K, 2)`", \ + :obj:`float32`, ":math:`(y, x)`" + Returns: - ~numpy.ndarray: + ~numpy.ndarray or list of arrays: Points modified according to image flips. """ H, W = size - point = point.copy() - if y_flip: - point[:, 0] = H - point[:, 0] - if x_flip: - point[:, 1] = W - point[:, 1] - return point + if isinstance(point, np.ndarray): + out_point = point.copy() + if y_flip: + out_point[:, :, 0] = H - out_point[:, :, 0] + if x_flip: + out_point[:, :, 1] = W - out_point[:, :, 1] + else: + out_point = [] + for pnt in point: + pnt = pnt.copy() + if y_flip: + pnt[:, 0] = H - pnt[:, 0] + if x_flip: + pnt[:, 1] = W - pnt[:, 1] + out_point.append(pnt) + return out_point diff --git a/chainercv/transforms/point/resize_point.py b/chainercv/transforms/point/resize_point.py index 0991fd4170..061efc0410 100644 --- a/chainercv/transforms/point/resize_point.py +++ b/chainercv/transforms/point/resize_point.py @@ -1,25 +1,38 @@ +import numpy as np + + def resize_point(point, in_size, out_size): """Adapt point coordinates to the rescaled image space. Args: - point (~numpy.ndarray): Points in the image. - The shape of this array is :math:`(P, 2)`. :math:`P` is the number - of points in the image. - The last dimension is composed of :math:`y` and :math:`x` - coordinates of the points. + point (~numpy.ndarray or list of arrays): See the table below. in_size (tuple): A tuple of length 2. The height and the width of the image before resized. out_size (tuple): A tuple of length 2. The height and the width of the image after resized. + .. csv-table:: + :header: name, shape, dtype, format + + :obj:`point`, ":math:`[(K, 2)]` or :math:`(R, K, 2)`", \ + :obj:`float32`, ":math:`(y, x)`" + Returns: - ~numpy.ndarray: + ~numpy.ndarray or list of arrays: Points rescaled according to the given image shapes. """ - point = point.copy() y_scale = float(out_size[0]) / in_size[0] x_scale = float(out_size[1]) / in_size[1] - point[:, 0] = y_scale * point[:, 0] - point[:, 1] = x_scale * point[:, 1] - return point + if isinstance(point, np.ndarray): + out_point = point.copy() + out_point[:, :, 0] = y_scale * point[:, :, 0] + out_point[:, :, 1] = x_scale * point[:, :, 1] + else: + out_point = [] + for pnt in point: + out_pnt = pnt.copy() + out_pnt[:, 0] = y_scale * pnt[:, 0] + out_pnt[:, 1] = x_scale * pnt[:, 1] + out_point.append(out_pnt) + return out_point
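These point transforms now accept either a single `(R, K, 2)` array or a list of per-instance `(K, 2)` arrays. A short sketch of both call styles, assuming the patched ChainerCV is installed (the coordinate values are made up):

```python
import numpy as np
from chainercv.transforms import flip_point, resize_point

# Two instances with K = 3 keypoints each, as one (R, K, 2) array.
point = np.array(
    [[[10., 20.], [30., 40.], [50., 60.]],
     [[15., 25.], [35., 45.], [55., 65.]]], dtype=np.float32)

size = (100, 200)  # (H, W) of the image
flipped = flip_point(point, size, x_flip=True)   # x -> W - x
resized = resize_point(point, size, (200, 400))  # doubles y and x

# The list interface allows a different K per instance.
point_list = [point[0], point[1][:2]]
flipped_list = flip_point(point_list, size, x_flip=True)
```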
diff --git a/chainercv/transforms/point/translate_point.py b/chainercv/transforms/point/translate_point.py index bd05f91244..c4a9e911bf 100644 --- a/chainercv/transforms/point/translate_point.py +++ b/chainercv/transforms/point/translate_point.py @@ -1,3 +1,6 @@ +import numpy as np + + def translate_point(point, y_offset=0, x_offset=0): """Translate points. @@ -6,23 +9,32 @@ to the coordinate :math:`(y, x) = (y_{offset}, x_{offset})`. Args: - point (~numpy.ndarray): Points in the image. - The shape of this array is :math:`(P, 2)`. :math:`P` is the number - of points in the image. - The last dimension is composed of :math:`y` and :math:`x` - coordinates of the points. + point (~numpy.ndarray or list of arrays): See the table below. y_offset (int or float): The offset along y axis. x_offset (int or float): The offset along x axis. + .. csv-table:: + :header: name, shape, dtype, format + + :obj:`point`, ":math:`[(K, 2)]` or :math:`(R, K, 2)`", \ + :obj:`float32`, ":math:`(y, x)`" + Returns: - ~numpy.ndarray: - Points modified translation of an image. + ~numpy.ndarray or list of arrays: + Points modified according to the translation of an image. """ - out_point = point.copy() - - out_point[:, 0] += y_offset - out_point[:, 1] += x_offset + if isinstance(point, np.ndarray): + out_point = point.copy() + out_point[:, :, 0] += y_offset + out_point[:, :, 1] += x_offset + else: + out_point = [] + for pnt in point: + out_pnt = pnt.copy() + out_pnt[:, 0] += y_offset + out_pnt[:, 1] += x_offset + out_point.append(out_pnt) return out_point diff --git a/chainercv/utils/testing/assertions/assert_is_instance_segmentation_link.py b/chainercv/utils/testing/assertions/assert_is_instance_segmentation_link.py index 1faf7aaf7e..09f55c900c 100644 --- a/chainercv/utils/testing/assertions/assert_is_instance_segmentation_link.py +++ b/chainercv/utils/testing/assertions/assert_is_instance_segmentation_link.py @@ -21,7 +21,6 @@ def assert_is_instance_segmentation_link(link, n_fg_class): np.random.randint(0, 256, size=(3, 480, 320)).astype(np.float32)] result = link.predict(imgs) - print(result) assert len(result) == 3, \ 'Link must return three elements: masks, labels and scores.' masks, labels, scores = result diff --git a/chainercv/visualizations/__init__.py b/chainercv/visualizations/__init__.py index 2adf9f7ba8..bf77cf892c 100644 --- a/chainercv/visualizations/__init__.py +++ b/chainercv/visualizations/__init__.py @@ -1,5 +1,6 @@ from chainercv.visualizations.vis_bbox import vis_bbox # NOQA from chainercv.visualizations.vis_image import vis_image # NOQA from chainercv.visualizations.vis_instance_segmentation import vis_instance_segmentation # NOQA +from chainercv.visualizations.vis_keypoint_coco import vis_keypoint_coco # NOQA from chainercv.visualizations.vis_point import vis_point # NOQA from chainercv.visualizations.vis_semantic_segmentation import vis_semantic_segmentation # NOQA diff --git a/chainercv/visualizations/vis_keypoint_coco.py b/chainercv/visualizations/vis_keypoint_coco.py new file mode 100644 index 0000000000..61f47e8a27 --- /dev/null +++ b/chainercv/visualizations/vis_keypoint_coco.py @@ -0,0 +1,171 @@ +from __future__ import division + +import numpy as np + +from chainercv.datasets import coco_keypoint_names +from chainercv.visualizations.vis_image import vis_image + + +human_id = 0 + +coco_point_skeleton = [ + [coco_keypoint_names[human_id].index('left_eye'), + coco_keypoint_names[human_id].index('right_eye')], + [coco_keypoint_names[human_id].index('left_eye'), + coco_keypoint_names[human_id].index('nose')], + [coco_keypoint_names[human_id].index('right_eye'), + coco_keypoint_names[human_id].index('nose')], + [coco_keypoint_names[human_id].index('right_eye'), + coco_keypoint_names[human_id].index('right_ear')], + [coco_keypoint_names[human_id].index('left_eye'), + coco_keypoint_names[human_id].index('left_ear')], + [coco_keypoint_names[human_id].index('right_shoulder'), + coco_keypoint_names[human_id].index('right_elbow')], + [coco_keypoint_names[human_id].index('right_elbow'), + coco_keypoint_names[human_id].index('right_wrist')], + [coco_keypoint_names[human_id].index('left_shoulder'), + coco_keypoint_names[human_id].index('left_elbow')], + [coco_keypoint_names[human_id].index('left_elbow'), + coco_keypoint_names[human_id].index('left_wrist')], + [coco_keypoint_names[human_id].index('right_hip'), + coco_keypoint_names[human_id].index('right_knee')],
[coco_keypoint_names[human_id].index('right_knee'), + coco_keypoint_names[human_id].index('right_ankle')], + [coco_keypoint_names[human_id].index('left_hip'), + coco_keypoint_names[human_id].index('left_knee')], + [coco_keypoint_names[human_id].index('left_knee'), + coco_keypoint_names[human_id].index('left_ankle')], + [coco_keypoint_names[human_id].index('right_shoulder'), + coco_keypoint_names[human_id].index('left_shoulder')], + [coco_keypoint_names[human_id].index('right_hip'), + coco_keypoint_names[human_id].index('left_hip')] +] + + +def vis_keypoint_coco( + img, point, visible=None, + point_score=None, thresh=2, + markersize=3, linewidth=1, ax=None): + """Visualize keypoints organized as in COCO. + + Example: + + >>> from chainercv.datasets import COCOKeypointDataset + >>> from chainercv.visualizations import vis_keypoint_coco + >>> import matplotlib.pyplot as plt + >>> data = COCOKeypointDataset(split='val') + >>> img, point, visible = data[10][:3] + >>> vis_keypoint_coco(img, point, visible) + >>> plt.show() + + Args: + img (~numpy.ndarray): See the table below. + If this is :obj:`None`, no image is displayed. + point (~numpy.ndarray): See the table below. + visible (~numpy.ndarray): See the table below. If this is + :obj:`None`, all points are assumed to be visible. + point_score (~numpy.ndarray): See the table below. If this + is :obj:`None`, the confidence of all points is infinitely + large. + thresh (float): Points with confidence below :obj:`thresh` are + not visualized. + markersize (float): The size of vertices. + linewidth (float): The thickness of edges. + ax (matplotlib.axes.Axis): The visualization is displayed on this + axis. If this is :obj:`None` (default), a new axis is created. + + .. csv-table:: + :header: name, shape, dtype, format + + :obj:`img`, ":math:`(3, H, W)`", :obj:`float32`, \ + "RGB, :math:`[0, 255]`" + :obj:`point`, ":math:`(R, K, 2)`", :obj:`float32`, \ + ":math:`(y, x)`" + :obj:`visible`, ":math:`(R, K)`", :obj:`bool`, \ + "true when a keypoint is visible." + :obj:`point_score`, ":math:`(R, K)`", :obj:`float32`, -- + + Returns: + ~matplotlib.axes.Axes: + Returns the Axes object with the plot for further tweaking. + + """ + from matplotlib import pyplot as plt + + # Returns newly instantiated matplotlib.axes.Axes object if ax is None + ax = vis_image(img, ax=ax) + + cmap = plt.get_cmap('rainbow') + colors = [cmap(i) for i in np.linspace(0, 1, len(coco_point_skeleton) + 2)] + + if point_score is None: + point_score = np.inf * np.ones(point.shape[:2], dtype=np.float32) + if point_score.shape != point.shape[:2]: + raise ValueError('Mismatch in the number of instances or joints.') + if point.shape[1:] != (len(coco_keypoint_names[human_id]), 2): + raise ValueError('point has an invalid shape') + + if visible is not None: + if visible.dtype != np.bool: + raise ValueError('The dtype of `visible` should be np.bool') + if visible.shape != point.shape[:2]: + raise ValueError('Mismatch in the number of instances or joints.') + for i, vld in enumerate(visible): + point_score[i, np.logical_not(vld)] = -np.inf + + for pnt, pnt_sc in zip(point, point_score): + for l in range(len(coco_point_skeleton)): + i0 = coco_point_skeleton[l][0] + i1 = coco_point_skeleton[l][1] + s0 = pnt_sc[i0] + y0 = pnt[i0, 0] + x0 = pnt[i0, 1] + s1 = pnt_sc[i1] + y1 = pnt[i1, 0] + x1 = pnt[i1, 1] + if s0 > thresh and s1 > thresh: + line = ax.plot([x0, x1], [y0, y1]) + plt.setp(line, color=colors[l], + linewidth=linewidth, alpha=0.7) + if s0 > thresh: + ax.plot( + x0, y0, '.', color=colors[l], + markersize=markersize, alpha=0.7) + if s1 > thresh: + ax.plot( + x1, y1, '.', color=colors[l], + markersize=markersize, alpha=0.7) + + # for better visualization, add mid shoulder / mid hip + mid_shoulder = ( + pnt[coco_keypoint_names[human_id].index('right_shoulder'), :2] + + pnt[coco_keypoint_names[human_id].index('left_shoulder'), :2]) / 2 + mid_shoulder_sc = np.minimum( + pnt_sc[coco_keypoint_names[human_id].index('right_shoulder')], + pnt_sc[coco_keypoint_names[human_id].index('left_shoulder')]) + + mid_hip = ( + pnt[coco_keypoint_names[human_id].index('right_hip'), :2] + + pnt[coco_keypoint_names[human_id].index('left_hip'), :2]) / 2 + mid_hip_sc = np.minimum( + pnt_sc[coco_keypoint_names[human_id].index('right_hip')], + pnt_sc[coco_keypoint_names[human_id].index('left_hip')]) + if (mid_shoulder_sc > thresh and + pnt_sc[coco_keypoint_names[human_id].index('nose')] > thresh): + y = [mid_shoulder[0], + pnt[coco_keypoint_names[human_id].index('nose'), 0]] + x = [mid_shoulder[1], + pnt[coco_keypoint_names[human_id].index('nose'), 1]] + line = ax.plot(x, y) + plt.setp( + line, color=colors[len(coco_point_skeleton)], + linewidth=linewidth, alpha=0.7) + if (mid_shoulder_sc > thresh and mid_hip_sc > thresh): + y = [mid_shoulder[0], mid_hip[0]] + x = [mid_shoulder[1], mid_hip[1]] + line = ax.plot(x, y) + plt.setp( + line, color=colors[len(coco_point_skeleton) + 1], + linewidth=linewidth, alpha=0.7) + + return ax
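`vis_keypoint_coco` composes the image, per-keypoint markers, and skeleton edges on one axes. Besides the dataset-driven example in the docstring, here is a self-contained sketch with random inputs whose shapes follow the table above (the values are meaningless, but it exercises the shape and dtype checks):

```python
import matplotlib.pyplot as plt
import numpy as np
from chainercv.visualizations import vis_keypoint_coco

K = 17  # number of COCO person keypoints
img = np.random.uniform(0, 255, size=(3, 480, 640)).astype(np.float32)
point = np.random.uniform(0, 480, size=(1, K, 2)).astype(np.float32)
visible = np.ones((1, K), dtype=bool)

ax = vis_keypoint_coco(img, point, visible)
plt.show()
```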
diff --git a/docs/source/reference/datasets.rst b/docs/source/reference/datasets.rst index ebf878354e..276c3249d9 100644 --- a/docs/source/reference/datasets.rst +++ b/docs/source/reference/datasets.rst @@ -73,6 +73,10 @@ COCOInstanceSegmentationDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: COCOInstanceSegmentationDataset +COCOKeypointDataset +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: COCOKeypointDataset + COCOSemanticSegmentationDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
autoclass:: COCOSemanticSegmentationDataset diff --git a/docs/source/reference/evaluations.rst b/docs/source/reference/evaluations.rst index 2befc38e47..553f1b52f6 100644 --- a/docs/source/reference/evaluations.rst +++ b/docs/source/reference/evaluations.rst @@ -45,6 +45,10 @@ calc_instance_segmentation_voc_prec_rec ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: calc_instance_segmentation_voc_prec_rec +Keypoint Detection COCO +----------------------- +.. autofunction:: eval_keypoint_detection_coco + Semantic Segmentation IoU ------------------------- diff --git a/docs/source/reference/links.rst b/docs/source/reference/links.rst index 7ccb0ef406..7b4c9709b1 100644 --- a/docs/source/reference/links.rst +++ b/docs/source/reference/links.rst @@ -33,7 +33,6 @@ For more details, please read :func:`FasterRCNN.predict`. .. toctree:: links/faster_rcnn - links/fpn links/ssd links/yolo @@ -52,6 +51,14 @@ For more details, please read :func:`SegNetBasic.predict`. links/deeplab +Links for Multiple Tasks +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. toctree:: + + links/fpn + + Classifiers ~~~~~~~~~~~ diff --git a/docs/source/reference/links/fpn.rst b/docs/source/reference/links/fpn.rst index d97aa3599f..4c01e2a44a 100644 --- a/docs/source/reference/links/fpn.rst +++ b/docs/source/reference/links/fpn.rst @@ -18,6 +18,20 @@ FasterRCNNFPNResnet101 :members: +Instance Segmentation Links +--------------------------- + +MaskRCNNFPNResNet50 +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: MaskRCNNFPNResNet50 + :members: + +MaskRCNNFPNResNet101 +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: MaskRCNNFPNResNet101 + :members: + + Utility ------- @@ -26,14 +40,20 @@ FasterRCNN .. autoclass:: FasterRCNN :members: +FasterRCNNFPNResNet +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: FasterRCNNFPNResNet + :members: + + FPN ~~~ .. autoclass:: FPN :members: -Head -~~~~ -.. autoclass:: Head +BboxHead +~~~~~~~~ +.. autoclass:: BboxHead :members: :special-members: __call__ @@ -43,17 +63,40 @@ RPN :members: :special-members: __call__ +MaskHead +~~~~~~~~ +.. autoclass:: MaskHead + :members: + :special-members: __call__ + +segm_to_mask +~~~~~~~~~~~~ +.. autofunction:: segm_to_mask + + Train-only Utility ------------------ -head_loss_pre +bbox_loss_pre ~~~~~~~~~~~~~ -.. autofunction:: head_loss_pre +.. autofunction:: bbox_loss_pre -head_loss_post +bbox_loss_post ~~~~~~~~~~~~~~ -.. autofunction:: head_loss_post +.. autofunction:: bbox_loss_post rpn_loss ~~~~~~~~ .. autofunction:: rpn_loss + +mask_loss_pre +~~~~~~~~~~~~~ +.. autofunction:: mask_loss_pre + +mask_loss_post +~~~~~~~~~~~~~~ +.. autofunction:: mask_loss_post + +mask_to_segm +~~~~~~~~~~~~ +.. autofunction:: mask_to_segm diff --git a/docs/source/reference/visualizations.rst b/docs/source/reference/visualizations.rst index 685b498e43..c316209839 100644 --- a/docs/source/reference/visualizations.rst +++ b/docs/source/reference/visualizations.rst @@ -12,6 +12,10 @@ vis_image ~~~~~~~~~ .. autofunction:: vis_image +vis_keypoint_coco +~~~~~~~~~~~~~~~~~ +.. autofunction:: vis_keypoint_coco + vis_instance_segmentation ~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
autofunction:: vis_instance_segmentation diff --git a/examples/fpn/demo.py b/examples/fpn/demo.py index 053d0351e2..b11a844eb6 100644 --- a/examples/fpn/demo.py +++ b/examples/fpn/demo.py @@ -4,17 +4,27 @@ import chainer from chainercv.datasets import coco_bbox_label_names +from chainercv.datasets import coco_instance_segmentation_label_names +from chainercv.datasets import coco_keypoint_names from chainercv.links import FasterRCNNFPNResNet101 from chainercv.links import FasterRCNNFPNResNet50 +from chainercv.links import KeypointRCNNFPNResNet101 +from chainercv.links import KeypointRCNNFPNResNet50 +from chainercv.links import MaskRCNNFPNResNet101 +from chainercv.links import MaskRCNNFPNResNet50 from chainercv import utils from chainercv.visualizations import vis_bbox +from chainercv.visualizations import vis_instance_segmentation +from chainercv.visualizations import vis_keypoint_coco def main(): parser = argparse.ArgumentParser() parser.add_argument( '--model', - choices=('faster_rcnn_fpn_resnet50', 'faster_rcnn_fpn_resnet101'), + choices=('faster_rcnn_fpn_resnet50', 'faster_rcnn_fpn_resnet101', + 'mask_rcnn_fpn_resnet50', 'mask_rcnn_fpn_resnet101', + 'keypoint_rcnn_fpn_resnet50', 'keypoint_rcnn_fpn_resnet101'), default='faster_rcnn_fpn_resnet50') parser.add_argument('--gpu', type=int, default=-1) parser.add_argument('--pretrained-model', default='coco') @@ -22,26 +32,71 @@ def main(): args = parser.parse_args() if args.model == 'faster_rcnn_fpn_resnet50': + mode = 'bbox' model = FasterRCNNFPNResNet50( n_fg_class=len(coco_bbox_label_names), pretrained_model=args.pretrained_model) elif args.model == 'faster_rcnn_fpn_resnet101': + mode = 'bbox' model = FasterRCNNFPNResNet101( n_fg_class=len(coco_bbox_label_names), pretrained_model=args.pretrained_model) + elif args.model == 'mask_rcnn_fpn_resnet50': + mode = 'instance_segmentation' + model = MaskRCNNFPNResNet50( + n_fg_class=len(coco_instance_segmentation_label_names), + pretrained_model=args.pretrained_model) + elif args.model == 'mask_rcnn_fpn_resnet101': + mode = 'instance_segmentation' + model = MaskRCNNFPNResNet101( + n_fg_class=len(coco_instance_segmentation_label_names), + pretrained_model=args.pretrained_model) + elif args.model == 'keypoint_rcnn_fpn_resnet50': + mode = 'keypoint' + model = KeypointRCNNFPNResNet50( + n_fg_class=1, + pretrained_model=args.pretrained_model, + n_point=len(coco_keypoint_names[0])) + elif args.model == 'keypoint_rcnn_fpn_resnet101': + mode = 'keypoint' + model = KeypointRCNNFPNResNet101( + n_fg_class=1, + pretrained_model=args.pretrained_model, + n_point=len(coco_keypoint_names[0])) if args.gpu >= 0: chainer.cuda.get_device_from_id(args.gpu).use() model.to_gpu() img = utils.read_image(args.image) - bboxes, labels, scores = model.predict([img]) - bbox = bboxes[0] - label = labels[0] - score = scores[0] - vis_bbox( - img, bbox, label, score, label_names=coco_bbox_label_names) + if mode == 'bbox': + bboxes, labels, scores = model.predict([img]) + bbox = bboxes[0] + label = labels[0] + score = scores[0] + + vis_bbox( + img, bbox, label, score, label_names=coco_bbox_label_names) + elif mode == 'instance_segmentation': + masks, labels, scores = model.predict([img]) + mask = masks[0] + label = labels[0] + score = scores[0] + vis_instance_segmentation( + img, mask, label, score, + label_names=coco_instance_segmentation_label_names) + elif mode == 'keypoint': + points, labels, scores, point_scores, bboxes = model.predict([img]) + point = points[0] + label = labels[0] + score = scores[0] + point_score = 
point_scores[0] + bbox = bboxes[0] + ax = vis_keypoint_coco( + img, point, None, point_score) + vis_bbox(None, bbox, label, score=score, + label_names=coco_bbox_label_names, ax=ax) plt.show() diff --git a/examples/fpn/train_multi.py b/examples/fpn/train_multi.py index 1adff045d8..f0dd4d0e30 100644 --- a/examples/fpn/train_multi.py +++ b/examples/fpn/train_multi.py @@ -1,10 +1,11 @@ -from __future__ import division - import argparse import multiprocessing import numpy as np +import random +import PIL import chainer +import chainer.functions as F import chainer.links as L from chainer.optimizer_hooks import WeightDecay from chainer import serializers @@ -15,14 +16,30 @@ from chainercv.chainer_experimental.datasets.sliceable import TransformDataset from chainercv.chainer_experimental.training.extensions import make_shift +from chainercv.links.model.fpn.misc import scale_img +from chainercv import transforms + +from chainercv.datasets import coco_instance_segmentation_label_names +from chainercv.datasets import COCOInstanceSegmentationDataset +from chainercv.links import MaskRCNNFPNResNet101 +from chainercv.links import MaskRCNNFPNResNet50 + from chainercv.datasets import coco_bbox_label_names from chainercv.datasets import COCOBboxDataset from chainercv.links import FasterRCNNFPNResNet101 from chainercv.links import FasterRCNNFPNResNet50 -from chainercv import transforms -from chainercv.links.model.fpn import head_loss_post -from chainercv.links.model.fpn import head_loss_pre +from chainercv.datasets import coco_keypoint_names +from chainercv.datasets import COCOKeypointDataset +from chainercv.links import KeypointRCNNFPNResNet101 +from chainercv.links import KeypointRCNNFPNResNet50 + +from chainercv.links.model.fpn import bbox_loss_post +from chainercv.links.model.fpn import bbox_loss_pre +from chainercv.links.model.fpn import keypoint_loss_post +from chainercv.links.model.fpn import keypoint_loss_pre +from chainercv.links.model.fpn import mask_loss_post +from chainercv.links.model.fpn import mask_loss_pre from chainercv.links.model.fpn import rpn_loss # https://docs.chainer.org/en/stable/tips.html#my-training-process-gets-stuck-when-using-multiprocessiterator @@ -35,16 +52,30 @@ class TrainChain(chainer.Chain): - def __init__(self, model): + def __init__(self, model, mode): super(TrainChain, self).__init__() with self.init_scope(): self.model = model - - def __call__(self, imgs, bboxes, labels): - x, scales = self.model.prepare(imgs) - bboxes = [self.xp.array(bbox) * scale - for bbox, scale in zip(bboxes, scales)] + self.mode = mode + + def __call__(self, imgs, bboxes, labels, masks=None, + points=None, visibles=None): + B = len(imgs) + pad_size = np.array( + [im.shape[1:] for im in imgs]).max(axis=0) + pad_size = ( + np.ceil( + pad_size / self.model.stride) * self.model.stride).astype(int) + x = np.zeros( + (len(imgs), 3, pad_size[0], pad_size[1]), dtype=np.float32) + for i, img in enumerate(imgs): + _, H, W = img.shape + x[i, :, :H, :W] = img + x = self.xp.array(x) + + bboxes = [self.xp.array(bbox) for bbox in bboxes] labels = [self.xp.array(label) for label in labels] + sizes = [img.shape[1:] for img in imgs] with chainer.using_config('train', False): hs = self.model.extractor(x) @@ -52,10 +83,7 @@ def __call__(self, imgs, bboxes, labels): rpn_locs, rpn_confs = self.model.rpn(hs) anchors = self.model.rpn.anchors(h.shape[2:] for h in hs) rpn_loc_loss, rpn_conf_loss = rpn_loss( - rpn_locs, rpn_confs, anchors, - [(int(img.shape[1] * scale), int(img.shape[2] * scale)) - for img, scale in 
zip(imgs, scales)], - bboxes) + rpn_locs, rpn_confs, anchors, sizes, bboxes) rois, roi_indices = self.model.rpn.decode( rpn_locs, rpn_confs, anchors, x.shape) @@ -64,33 +92,127 @@ def __call__(self, imgs, bboxes, labels): [roi_indices] + [self.xp.array((i,) * len(bbox)) for i, bbox in enumerate(bboxes)]) - rois, roi_indices = self.model.head.distribute(rois, roi_indices) - rois, roi_indices, head_gt_locs, head_gt_labels = head_loss_pre( - rois, roi_indices, self.model.head.std, bboxes, labels) - head_locs, head_confs = self.model.head(hs, rois, roi_indices) - head_loc_loss, head_conf_loss = head_loss_post( + rois, roi_indices = self.model.bbox_head.distribute(rois, roi_indices) + rois, roi_indices, head_gt_locs, head_gt_labels = bbox_loss_pre( + rois, roi_indices, self.model.bbox_head.std, bboxes, labels) + head_locs, head_confs = self.model.bbox_head(hs, rois, roi_indices) + head_loc_loss, head_conf_loss = bbox_loss_post( head_locs, head_confs, - roi_indices, head_gt_locs, head_gt_labels, len(x)) - - loss = rpn_loc_loss + rpn_conf_loss + head_loc_loss + head_conf_loss + roi_indices, head_gt_locs, head_gt_labels, B) + + mask_loss = 0 + if self.mode == 'instance_segmentation': + # For reducing unnecessary CPU/GPU copy, `masks` is kept in CPU. + pad_masks = [ + np.zeros( + (mask.shape[0], pad_size[0], pad_size[1]), dtype=np.bool) + for mask in masks] + for i, mask in enumerate(masks): + _, H, W = mask.shape + pad_masks[i][:, :H, :W] = mask + masks = pad_masks + + mask_rois, mask_roi_indices, gt_segms, gt_mask_labels =\ + mask_loss_pre( + rois, roi_indices, masks, bboxes, + head_gt_labels, self.model.mask_head.segm_size) + n_roi = sum([len(roi) for roi in mask_rois]) + if n_roi > 0: + segms = self.model.mask_head(hs, mask_rois, mask_roi_indices) + mask_loss = mask_loss_post( + segms, mask_roi_indices, gt_segms, gt_mask_labels, B) + else: + # Compute dummy variables to complete the computational graph + mask_rois[0] = self.xp.array([[0, 0, 1, 1]], dtype=np.float32) + mask_roi_indices[0] = self.xp.array([0], dtype=np.int32) + segms = self.model.mask_head(hs, mask_rois, mask_roi_indices) + mask_loss = 0 * F.sum(segms) + + point_loss = 0 + if self.mode == 'keypoint': + points = [self.xp.array(point) for point in points] + visibles = [self.xp.array(visible) for visible in visibles] + + point_rois, point_roi_indices, gt_head_points, gt_head_visibles =\ + keypoint_loss_pre( + rois, roi_indices, points, visibles, bboxes, + head_gt_labels, self.model.keypoint_head.point_map_size) + n_roi = sum([len(roi) for roi in point_rois]) + if n_roi > 0: + point_maps = self.model.keypoint_head( + hs, point_rois, point_roi_indices) + point_loss = keypoint_loss_post( + point_maps, point_roi_indices, + gt_head_points, gt_head_visibles, B) + else: + # Compute dummy variables to complete the computational graph + point_rois[0] = self.xp.array([[0, 0, 1, 1]], dtype=np.float32) + point_roi_indices[0] = self.xp.array([0], dtype=np.int32) + point_maps = self.model.keypoint_head( + hs, point_rois, point_roi_indices) + point_loss = 0 * F.sum(point_maps) + + loss = (rpn_loc_loss + rpn_conf_loss + + head_loc_loss + head_conf_loss + mask_loss + point_loss) chainer.reporter.report({ 'loss': loss, 'loss/rpn/loc': rpn_loc_loss, 'loss/rpn/conf': rpn_conf_loss, - 'loss/head/loc': head_loc_loss, 'loss/head/conf': head_conf_loss}, + 'loss/bbox_head/loc': head_loc_loss, + 'loss/bbox_head/conf': head_conf_loss, + 'loss/mask_head': mask_loss, + 'loss/keypoint_head': point_loss}, self) - return loss -def transform(in_data): - img, bbox, 
label = in_data - - img, params = transforms.random_flip( - img, x_random=True, return_param=True) - bbox = transforms.flip_bbox( - bbox, img.shape[1:], x_flip=params['x_flip']) - - return img, bbox, label +class Transform(object): + + def __init__(self, min_size, max_size, mean, mode): + if not isinstance(min_size, (tuple, list)): + min_size = (min_size,) + self.min_size = min_size + self.max_size = max_size + self.mean = mean + self.mode = mode + + def __call__(self, in_data): + if self.mode == 'bbox': + img, bbox, label = in_data + elif self.mode == 'instance_segmentation': + img, mask, label, bbox = in_data + elif self.mode == 'keypoint': + img, point, visible, label, bbox = in_data + + original_size = img.shape[1:] + # Flipping + img, params = transforms.random_flip( + img, x_random=True, return_param=True) + x_flip = params['x_flip'] + bbox = transforms.flip_bbox( + bbox, img.shape[1:], x_flip=x_flip) + + # Scaling and mean subtraction + min_size = random.choice(self.min_size) + img, scale = scale_img( + img, min_size, self.max_size) + img -= self.mean + bbox = bbox * scale + + if self.mode == 'bbox': + return img, bbox, label + elif self.mode == 'instance_segmentation': + mask = transforms.flip(mask, x_flip=x_flip) + mask = transforms.resize( + mask.astype(np.float32), + img.shape[1:], + interpolation=PIL.Image.NEAREST).astype(np.bool) + return img, bbox, label, mask + elif self.mode == 'keypoint': + point = transforms.flip_point( + point, original_size, x_flip=x_flip) + point = transforms.resize_point( + point, original_size, img.shape[1:]) + return img, bbox, label, None, point, visible def converter(batch, device=None): @@ -98,17 +220,29 @@ def converter(batch, device=None): return tuple(list(v) for v in zip(*batch)) +def valid_point_annotation(visible): + if len(visible) == 0: + return False + min_keypoint_per_image = 10 + n_visible = visible.sum() + return n_visible >= min_keypoint_per_image + + def main(): parser = argparse.ArgumentParser() + parser.add_argument('--data-dir', default='auto') parser.add_argument( '--model', - choices=('faster_rcnn_fpn_resnet50', 'faster_rcnn_fpn_resnet101'), + choices=('mask_rcnn_fpn_resnet50', 'mask_rcnn_fpn_resnet101', + 'faster_rcnn_fpn_resnet50', 'faster_rcnn_fpn_resnet101', + 'keypoint_rcnn_fpn_resnet50', 'keypoint_rcnn_fpn_resnet101'), default='faster_rcnn_fpn_resnet50') parser.add_argument('--batchsize', type=int, default=16) parser.add_argument('--iteration', type=int, default=90000) parser.add_argument('--step', type=int, nargs='*', default=[60000, 80000]) parser.add_argument('--out', default='result') parser.add_argument('--resume') + parser.add_argument('--communicator', default='hierarchical') args = parser.parse_args() # https://docs.chainer.org/en/stable/chainermn/tutorial/tips_faqs.html#using-multiprocessiterator @@ -118,24 +252,70 @@ def main(): p.start() p.join() - comm = chainermn.create_communicator() + comm = chainermn.create_communicator(args.communicator) device = comm.intra_rank if args.model == 'faster_rcnn_fpn_resnet50': + mode = 'bbox' model = FasterRCNNFPNResNet50( - n_fg_class=len(coco_bbox_label_names), pretrained_model='imagenet') + n_fg_class=len(coco_bbox_label_names), + pretrained_model='imagenet') elif args.model == 'faster_rcnn_fpn_resnet101': + mode = 'bbox' model = FasterRCNNFPNResNet101( - n_fg_class=len(coco_bbox_label_names), pretrained_model='imagenet') + n_fg_class=len(coco_bbox_label_names), + pretrained_model='imagenet') + elif args.model == 'mask_rcnn_fpn_resnet50': + mode = 'instance_segmentation' + 
model = MaskRCNNFPNResNet50( + n_fg_class=len(coco_instance_segmentation_label_names), + pretrained_model='imagenet') + elif args.model == 'mask_rcnn_fpn_resnet101': + mode = 'instance_segmentation' + model = MaskRCNNFPNResNet101( + n_fg_class=len(coco_instance_segmentation_label_names), + pretrained_model='imagenet') + elif args.model == 'keypoint_rcnn_fpn_resnet50': + mode = 'keypoint' + model = KeypointRCNNFPNResNet50( + n_fg_class=1, pretrained_model='imagenet', + n_point=len(coco_keypoint_names[0])) + elif args.model == 'keypoint_rcnn_fpn_resnet101': + mode = 'keypoint' + model = KeypointRCNNFPNResNet101( + n_fg_class=1, pretrained_model='imagenet', + n_point=len(coco_keypoint_names[0])) model.use_preset('evaluate') - train_chain = TrainChain(model) + train_chain = TrainChain(model, mode) chainer.cuda.get_device_from_id(device).use() train_chain.to_gpu() - train = TransformDataset( - COCOBboxDataset(year='2017', split='train'), - ('img', 'bbox', 'label'), transform) + if mode == 'bbox': + transform = Transform( + model.min_size, model.max_size, model.extractor.mean, mode) + train = TransformDataset( + COCOBboxDataset( + data_dir=args.data_dir, year='2017', split='train'), + ('img', 'bbox', 'label'), transform) + elif mode == 'instance_segmentation': + transform = Transform( + model.min_size, model.max_size, model.extractor.mean, mode) + train = TransformDataset( + COCOInstanceSegmentationDataset( + data_dir=args.data_dir, split='train', return_bbox=True), + ('img', 'bbox', 'label', 'mask'), transform) + elif mode == 'keypoint': + train = COCOKeypointDataset(data_dir=args.data_dir, split='train') + indices = [i for i, visible in enumerate(train.slice[:, 'visible']) + if valid_point_annotation(visible)] + train = train.slice[indices] + transform = Transform( + (640, 672, 704, 736, 768, 800), + model.max_size, model.extractor.mean, mode) + train = TransformDataset( + train, + ('img', 'bbox', 'label', 'mask', 'point', 'visible'), transform) if comm.rank == 0: indices = np.arange(len(train)) @@ -144,8 +324,10 @@ def main(): indices = chainermn.scatter_dataset(indices, comm, shuffle=True) train = train.slice[indices] - train_iter = chainer.iterators.MultithreadIterator( - train, args.batchsize // comm.size) + train_iter = chainer.iterators.MultiprocessIterator( + train, args.batchsize // comm.size, + n_processes=args.batchsize // comm.size, + shared_mem=100 * 1000 * 1000 * 4) optimizer = chainermn.create_multi_node_optimizer( chainer.optimizers.MomentumSGD(), comm) @@ -157,11 +339,14 @@ def main(): for link in model.links(): if isinstance(link, L.BatchNormalization): link.disable_update() + if mode == 'keypoint': + model.keypoint_head.upsample.disable_update() + n_iteration = args.iteration * 16 / args.batchsize updater = training.updaters.StandardUpdater( train_iter, optimizer, converter=converter, device=device) trainer = training.Trainer( - updater, (args.iteration * 16 / args.batchsize, 'iteration'), args.out) + updater, (n_iteration, 'iteration'), args.out) @make_shift('lr') def lr_schedule(trainer): @@ -190,7 +375,9 @@ def lr_schedule(trainer): trainer.extend(extensions.PrintReport( ['epoch', 'iteration', 'lr', 'main/loss', 'main/loss/rpn/loc', 'main/loss/rpn/conf', - 'main/loss/head/loc', 'main/loss/head/conf']), + 'main/loss/bbox_head/loc', 'main/loss/bbox_head/conf', + 'main/loss/mask_head', 'main/loss/keypoint_head' + ]), trigger=log_interval) trainer.extend(extensions.ProgressBar(update_interval=10)) @@ -198,7 +385,7 @@ def lr_schedule(trainer): trainer.extend( 
extensions.snapshot_object( model, 'model_iter_{.updater.iteration}'), - trigger=(90000 * 16 / args.batchsize, 'iteration')) + trigger=(n_iteration, 'iteration')) if args.resume: serializers.load_npz(args.resume, trainer, strict=False) diff --git a/examples/instance_segmentation/eval_instance_segmentation.py b/examples/instance_segmentation/eval_instance_segmentation.py index 3dba29b5b9..9fbb158bdd 100755 --- a/examples/instance_segmentation/eval_instance_segmentation.py +++ b/examples/instance_segmentation/eval_instance_segmentation.py @@ -10,12 +10,18 @@ from chainercv.evaluations import eval_instance_segmentation_coco from chainercv.evaluations import eval_instance_segmentation_voc from chainercv.experimental.links import FCISResNet101 +from chainercv.links import MaskRCNNFPNResNet101 +from chainercv.links import MaskRCNNFPNResNet50 from chainercv.utils import apply_to_iterator from chainercv.utils import ProgressHook models = { # model: (class, dataset -> pretrained_model, default batchsize) 'fcis_resnet101': (FCISResNet101, {'sbd': 'sbd', 'coco': 'coco'}, 1), + 'mask_rcnn_fpn_resnet50': (MaskRCNNFPNResNet50, + {}, 1), + 'mask_rcnn_fpn_resnet101': (MaskRCNNFPNResNet101, + {}, 1), } diff --git a/examples/keypoint_detection/eval_keypoint_detection.py b/examples/keypoint_detection/eval_keypoint_detection.py new file mode 100644 index 0000000000..a5a7ca68d1 --- /dev/null +++ b/examples/keypoint_detection/eval_keypoint_detection.py @@ -0,0 +1,90 @@ +import argparse + +import chainer +from chainer import iterators + +from chainercv.datasets import COCOKeypointDataset +from chainercv.evaluations import eval_keypoint_detection_coco +from chainercv.links import KeypointRCNNFPNResNet101 +from chainercv.links import KeypointRCNNFPNResNet50 +from chainercv.utils import apply_to_iterator +from chainercv.utils import ProgressHook + +models = { + # model: (class, dataset -> pretrained_model, default batchsize) + 'keypoint_rcnn_fpn_resnet50': (KeypointRCNNFPNResNet50, {}, 1), + 'keypoint_rcnn_fpn_resnet101': (KeypointRCNNFPNResNet101, {}, 1), +} + + +def setup(dataset, model_name, pretrained_model, batchsize): + cls, pretrained_models, default_batchsize = models[model_name] + dataset_name = dataset + if pretrained_model is None: + pretrained_model = pretrained_models.get(dataset_name, dataset_name) + if batchsize is None: + batchsize = default_batchsize + + if dataset_name == 'coco': + dataset = COCOKeypointDataset( + split='val', + use_crowded=True, return_crowded=True, + return_area=True) + n_fg_class = 1 + n_point = 17 + model = cls( + n_fg_class=n_fg_class, + pretrained_model=pretrained_model, + n_point=n_point, + ) + model.use_preset('evaluate') + + def eval_(out_values, rest_values): + (pred_points, pred_labels, pred_scores, pred_point_scores, + pred_bboxes) = out_values + (gt_points, gt_visibles, gt_labels, gt_bboxes, + gt_areas, gt_crowdeds) = rest_values + + result = eval_keypoint_detection_coco( + pred_points, pred_labels, pred_scores, + gt_points, gt_visibles, gt_labels, gt_bboxes, + gt_areas, gt_crowdeds) + + print() + for area in ('all', 'large', 'medium'): + print('mmAP ({}):'.format(area), + result['map/iou=0.50:0.95/area={}/max_dets=20'.format( + area)]) + + return dataset, eval_, model, batchsize + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', choices=('coco',), default='coco') + parser.add_argument('--model', choices=sorted(models.keys())) + parser.add_argument('--pretrained-model') + parser.add_argument('--batchsize', type=int) + 
parser.add_argument('--gpu', type=int, default=-1) + args = parser.parse_args() + + dataset, eval_, model, batchsize = setup( + args.dataset, args.model, args.pretrained_model, args.batchsize) + + if args.gpu >= 0: + chainer.cuda.get_device_from_id(args.gpu).use() + model.to_gpu() + + iterator = iterators.MultithreadIterator( + dataset, batchsize, repeat=False, shuffle=False) + + in_values, out_values, rest_values = apply_to_iterator( + model.predict, iterator, hook=ProgressHook(len(dataset))) + # delete unused iterators explicitly + del in_values + + eval_(out_values, rest_values) + + +if __name__ == '__main__': + main() diff --git a/examples/keypoint_detection/eval_keypoint_detection_multi.py b/examples/keypoint_detection/eval_keypoint_detection_multi.py new file mode 100644 index 0000000000..8a49017c21 --- /dev/null +++ b/examples/keypoint_detection/eval_keypoint_detection_multi.py @@ -0,0 +1,48 @@ +import argparse + +import chainer +from chainer import iterators +import chainermn + +from chainercv.utils import apply_to_iterator +from chainercv.utils import ProgressHook + +from eval_keypoint_detection import models +from eval_keypoint_detection import setup + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--dataset', choices=('coco',), default='coco') + parser.add_argument('--model', choices=sorted(models.keys())) + parser.add_argument('--pretrained-model') + parser.add_argument('--batchsize', type=int) + args = parser.parse_args() + + comm = chainermn.create_communicator() + device = comm.intra_rank + + dataset, eval_, model, batchsize = setup( + args.dataset, args.model, args.pretrained_model, args.batchsize) + + chainer.cuda.get_device_from_id(device).use() + model.to_gpu() + + if not comm.rank == 0: + apply_to_iterator(model.predict, None, comm=comm) + return + + iterator = iterators.MultithreadIterator( + dataset, batchsize * comm.size, repeat=False, shuffle=False) + + in_values, out_values, rest_values = apply_to_iterator( + model.predict, iterator, hook=ProgressHook(len(dataset)), comm=comm) + # delete unused iterators explicitly + del in_values + + eval_(out_values, rest_values) + + +if __name__ == '__main__': + main() diff --git a/examples/mask_rcnn/train_multi_keypoint.py b/examples/mask_rcnn/train_multi_keypoint.py new file mode 100644 index 0000000000..e751aae619 --- /dev/null +++ b/examples/mask_rcnn/train_multi_keypoint.py @@ -0,0 +1,293 @@ +import argparse +import multiprocessing +import numpy as np +import random + +import chainer +import chainer.functions as F +import chainer.links as L +from chainer.optimizer_hooks import WeightDecay +from chainer import serializers +from chainer import training +from chainer.training import extensions + +import chainermn + +from chainercv.chainer_experimental.datasets.sliceable import TransformDataset +from chainercv.chainer_experimental.training.extensions import make_shift +from chainercv.datasets import COCOKeypointDataset +from chainercv.links import MaskRCNNFPNResNet101 +from chainercv.links import MaskRCNNFPNResNet50 +from chainercv.links.model.mask_rcnn.misc import scale_img +from chainercv import transforms + +from chainercv.links.model.fpn import head_loss_post +from chainercv.links.model.fpn import head_loss_pre +from chainercv.links.model.fpn import rpn_loss +from chainercv.links.model.mask_rcnn import keypoint_loss_pre +from chainercv.links.model.mask_rcnn import keypoint_loss_post + +# https://docs.chainer.org/en/stable/tips.html#my-training-process-gets-stuck-when-using-multiprocessiterator 
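+# OpenCV runs its own thread pool, which can deadlock in worker +# processes forked by the data iterator, so it is disabled up front.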
+try: + import cv2 + cv2.setNumThreads(0) +except ImportError: + pass + + +class TrainChain(chainer.Chain): + + def __init__(self, model): + super(TrainChain, self).__init__() + with self.init_scope(): + self.model = model + + def __call__(self, imgs, points, visibles, labels, bboxes): + B = len(imgs) + pad_size = np.array( + [im.shape[1:] for im in imgs]).max(axis=0) + pad_size = ( + np.ceil( + pad_size / self.model.stride) * self.model.stride).astype(int) + x = np.zeros( + (len(imgs), 3, pad_size[0], pad_size[1]), dtype=np.float32) + for i, img in enumerate(imgs): + _, H, W = img.shape + x[i, :, :H, :W] = img + x = self.xp.array(x) + + points = [self.xp.array(point) for point in points] + visibles = [self.xp.array(visible) for visible in visibles] + + bboxes = [self.xp.array(bbox) for bbox in bboxes] + assert all([np.all(label == 0) for label in labels]) + labels = [self.xp.array(label) for label in labels] + sizes = [img.shape[1:] for img in imgs] + + with chainer.using_config('train', False): + hs = self.model.extractor(x) + + rpn_locs, rpn_confs = self.model.rpn(hs) + anchors = self.model.rpn.anchors(h.shape[2:] for h in hs) + rpn_loc_loss, rpn_conf_loss = rpn_loss( + rpn_locs, rpn_confs, anchors, sizes, bboxes) + + rois, roi_indices = self.model.rpn.decode( + rpn_locs, rpn_confs, anchors, x.shape) + rois = self.xp.vstack([rois] + bboxes) + roi_indices = self.xp.hstack( + [roi_indices] + + [self.xp.array((i,) * len(bbox)) + for i, bbox in enumerate(bboxes)]) + rois, roi_indices = self.model.head.distribute(rois, roi_indices) + rois, roi_indices, head_gt_locs, head_gt_labels = head_loss_pre( + rois, roi_indices, self.model.head.std, bboxes, labels) + head_locs, head_confs = self.model.head(hs, rois, roi_indices) + head_loc_loss, head_conf_loss = head_loss_post( + head_locs, head_confs, + roi_indices, head_gt_locs, head_gt_labels, B) + losses = [ + rpn_loc_loss + rpn_conf_loss + head_loc_loss + head_conf_loss] + + point_rois, point_roi_indices, gt_head_points, gt_head_visibles = keypoint_loss_pre( + rois, roi_indices, points, visibles, bboxes, head_gt_labels, + self.model.keypoint_head.point_map_size) + n_roi = sum([len(roi) for roi in point_rois]) + if n_roi > 0: + point_maps = self.model.keypoint_head(hs, point_rois, point_roi_indices) + point_loss = keypoint_loss_post( + point_maps, point_roi_indices, + gt_head_points, gt_head_visibles, B) + else: + # Compute dummy variables to complete the computational graph + point_rois[0] = self.xp.array([[0, 0, 1, 1]], dtype=np.float32) + point_roi_indices[0] = self.xp.array([0], dtype=np.int32) + point_maps = self.model.keypoint_head(hs, point_rois, point_roi_indices) + point_loss = 0 * F.sum(point_maps) + losses.append(point_loss) + loss = sum(losses) + chainer.reporter.report({ + 'loss': loss, + 'loss/rpn/loc': rpn_loc_loss, 'loss/rpn/conf': rpn_conf_loss, + 'loss/head/loc': head_loc_loss, 'loss/head/conf': head_conf_loss, + 'loss/keypoint': point_loss}, + self) + return loss + + +class Transform(object): + + def __init__(self, min_size, max_size, mean): + if not isinstance(min_size, (tuple, list)): + min_size = (min_size,) + self.min_size = min_size + self.max_size = max_size + self.mean = mean + + def __call__(self, in_data): + img, point, visible, label, bbox = in_data + # Flipping + size = img.shape[1:] + img, params = transforms.random_flip( + img, x_random=True, return_param=True) + point = transforms.flip_point( + point, size, x_flip=params['x_flip']) + bbox = transforms.flip_bbox( + bbox, size, x_flip=params['x_flip']) + + # Scaling 
and mean subtraction + min_size = random.choice(self.min_size) + img, scale = scale_img(img, min_size, self.max_size) + img -= self.mean + point = transforms.resize_point(point, size, img.shape[1:]) + bbox = bbox * scale + return img, point, visible, label, bbox + + +def converter(batch, device=None): + # do not send data to gpu (device is ignored) + return tuple(list(v) for v in zip(*batch)) + + +def valid_annotation(visible): + if len(visible) == 0: + return False + min_keypoint_per_image = 10 + n_visible = visible.sum() + return n_visible >= min_keypoint_per_image + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--model', + choices=('mask_rcnn_fpn_resnet50', 'mask_rcnn_fpn_resnet101'), + default='mask_rcnn_fpn_resnet50') + parser.add_argument('--batchsize', type=int, default=16) + parser.add_argument('--iteration', type=int, default=90000) + parser.add_argument('--step', type=int, nargs='*', default=[60000, 80000]) + parser.add_argument('--out', default='result') + parser.add_argument('--resume') + parser.add_argument('--communicator', default='hierarchical') + args = parser.parse_args() + + + # from chainer.configuration import global_config + # global_config.cv_resize_backend = 'PIL' + # global_config.cv_read_image_backend = 'PIL' + + # https://docs.chainer.org/en/stable/chainermn/tutorial/tips_faqs.html#using-multiprocessiterator + if hasattr(multiprocessing, 'set_start_method'): + multiprocessing.set_start_method('forkserver') + p = multiprocessing.Process() + p.start() + p.join() + + comm = chainermn.create_communicator(args.communicator) + device = comm.intra_rank + + if args.model == 'mask_rcnn_fpn_resnet50': + model = MaskRCNNFPNResNet50( + n_fg_class=1, + pretrained_model='imagenet', + mode='keypoint' + ) + elif args.model == 'mask_rcnn_fpn_resnet101': + model = MaskRCNNFPNResNet101( + n_fg_class=1, + pretrained_model='imagenet', + mode='keypoint' + ) + + model.use_preset('evaluate') + train_chain = TrainChain(model) + chainer.cuda.get_device_from_id(device).use() + train_chain.to_gpu() + + train = COCOKeypointDataset(split='train') + indices = [i for i, visible in enumerate(train.slice[:, 'visible']) + if valid_annotation(visible)] + train = train.slice[indices] + train = TransformDataset( + train, ('img', 'point', 'visible', 'label', 'bbox'), + Transform( + (640, 672, 704, 736, 768, 800), model.max_size, + model.extractor.mean)) + + if comm.rank == 0: + indices = np.arange(len(train)) + else: + indices = None + indices = chainermn.scatter_dataset(indices, comm, shuffle=True) + train = train.slice[indices] + + train_iter = chainer.iterators.MultiprocessIterator( + train, args.batchsize // comm.size, + n_processes=args.batchsize // comm.size, + shared_mem=10 * 1000 * 1000 * 3) + + optimizer = chainermn.create_multi_node_optimizer( + chainer.optimizers.MomentumSGD(), comm) + optimizer.setup(train_chain) + optimizer.add_hook(WeightDecay(0.0001)) + + model.extractor.base.conv1.disable_update() + model.extractor.base.res2.disable_update() + for link in model.links(): + if isinstance(link, L.BatchNormalization): + link.disable_update() + model.keypoint_head.upsample.disable_update() + + n_iteration = args.iteration * 16 / args.batchsize + updater = training.updaters.StandardUpdater( + train_iter, optimizer, converter=converter, device=device) + trainer = training.Trainer( + updater, (n_iteration, 'iteration'), args.out) + + @make_shift('lr') + def lr_schedule(trainer): + base_lr = 0.02 * args.batchsize / 16 + warm_up_duration = 500 + warm_up_rate = 1 / 
3
+
+        iteration = trainer.updater.iteration
+        if iteration < warm_up_duration:
+            rate = warm_up_rate \
+                + (1 - warm_up_rate) * iteration / warm_up_duration
+        else:
+            rate = 1
+            for step in args.step:
+                if iteration >= step * 16 / args.batchsize:
+                    rate *= 0.1
+
+        return base_lr * rate
+
+    trainer.extend(lr_schedule)
+
+    if comm.rank == 0:
+        log_interval = 10, 'iteration'
+        trainer.extend(extensions.LogReport(trigger=log_interval))
+        trainer.extend(extensions.observe_lr(), trigger=log_interval)
+        trainer.extend(extensions.PrintReport(
+            ['epoch', 'iteration', 'lr', 'main/loss',
+             'main/loss/rpn/loc', 'main/loss/rpn/conf',
+             'main/loss/head/loc', 'main/loss/head/conf',
+             'main/loss/keypoint'
+             ]),
+            trigger=log_interval)
+        trainer.extend(extensions.ProgressBar(update_interval=10))
+
+        trainer.extend(extensions.snapshot(), trigger=(10000, 'iteration'))
+        trainer.extend(
+            extensions.snapshot_object(
+                model, 'model_iter_{.updater.iteration}'),
+            trigger=(n_iteration, 'iteration'))
+
+    if args.resume:
+        serializers.load_npz(args.resume, trainer, strict=False)
+
+    trainer.run()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples_tests/mask_rcnn_tests/test_demo.sh b/examples_tests/mask_rcnn_tests/test_demo.sh
new file mode 100644
index 0000000000..344ae45c19
--- /dev/null
+++ b/examples_tests/mask_rcnn_tests/test_demo.sh
@@ -0,0 +1,8 @@
+cd examples/mask_rcnn
+curl -L https://cloud.githubusercontent.com/assets/2062128/26187667/9cb236da-3bd5-11e7-8bcf-7dbd4302e2dc.jpg \
+    -o sample.jpg
+
+$PYTHON demo.py --model mask_rcnn_fpn_resnet50 sample.jpg
+$PYTHON demo.py --model mask_rcnn_fpn_resnet50 --gpu 0 sample.jpg
+$PYTHON demo.py --model mask_rcnn_fpn_resnet101 sample.jpg
+$PYTHON demo.py --model mask_rcnn_fpn_resnet101 --gpu 0 sample.jpg
diff --git a/examples_tests/mask_rcnn_tests/test_train_multi.sh b/examples_tests/mask_rcnn_tests/test_train_multi.sh
new file mode 100644
index 0000000000..5f5227d2f7
--- /dev/null
+++ b/examples_tests/mask_rcnn_tests/test_train_multi.sh
@@ -0,0 +1,4 @@
+cd examples/mask_rcnn
+
+$MPIEXEC $PYTHON train_multi.py --model mask_rcnn_fpn_resnet50 --batchsize 4 --iteration 9 --step 6 8
+$MPIEXEC $PYTHON train_multi.py --model mask_rcnn_fpn_resnet101 --batchsize 4 --iteration 9 --step 6 8
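As a sanity check on the `lr_schedule` arithmetic in `train_multi_keypoint.py` above, here is a minimal sketch that factors the same warmup-and-step logic into a pure function. `lr_at` is a hypothetical helper, not part of this diff, and the defaults assume `--batchsize 16` and `--step 60000 80000`:

    def lr_at(iteration, batchsize=16, steps=(60000, 80000)):
        # Linear warmup from base_lr / 3 over the first 500 iterations,
        # then a x0.1 decay at each step (steps scale with batchsize).
        base_lr = 0.02 * batchsize / 16
        warm_up_duration = 500
        warm_up_rate = 1 / 3
        if iteration < warm_up_duration:
            rate = warm_up_rate \
                + (1 - warm_up_rate) * iteration / warm_up_duration
        else:
            rate = 1
            for step in steps:
                if iteration >= step * 16 / batchsize:
                    rate *= 0.1
        return base_lr * rate

    assert abs(lr_at(0) - 0.02 / 3) < 1e-9    # warmup starts at a third of base_lr
    assert lr_at(500) == 0.02                 # full rate once warmup ends
    assert abs(lr_at(60000) - 0.002) < 1e-9   # first decay step
    assert abs(lr_at(80000) - 0.0002) < 1e-9  # second decay step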
diff --git a/tests/datasets_tests/coco_tests/test_coco_keypoint_dataset.py b/tests/datasets_tests/coco_tests/test_coco_keypoint_dataset.py
new file mode 100644
index 0000000000..984245f9ba
--- /dev/null
+++ b/tests/datasets_tests/coco_tests/test_coco_keypoint_dataset.py
@@ -0,0 +1,70 @@
+import unittest
+
+import numpy as np
+
+from chainer import testing
+from chainer.testing import attr
+
+from chainercv.datasets import coco_keypoint_names
+from chainercv.datasets import COCOKeypointDataset
+from chainercv.utils import assert_is_bbox
+from chainercv.utils import assert_is_point_dataset
+
+
+@testing.parameterize(*testing.product(
+    {
+        'split': ['train', 'val'],
+        'year': ['2014', '2017'],
+        'use_crowded': [False, True],
+        'return_crowded': [False, True],
+        'return_area': [False, True],
+    }
+))
+class TestCOCOKeypointDataset(unittest.TestCase):
+
+    def setUp(self):
+        self.dataset = COCOKeypointDataset(
+            split=self.split, year=self.year,
+            use_crowded=self.use_crowded, return_area=self.return_area,
+            return_crowded=self.return_crowded)
+
+    @attr.slow
+    def test_coco_keypoint_dataset(self):
+        human_id = 0
+        assert_is_point_dataset(
+            self.dataset, len(coco_keypoint_names[human_id]),
+            n_example=30)
+
+        for _ in range(10):
+            i = np.random.randint(0, len(self.dataset))
+            img, point, _, label, bbox = self.dataset[i][:5]
+            assert_is_bbox(bbox, img.shape[1:])
+            self.assertEqual(len(bbox), len(point))
+
+            self.assertIsInstance(label, np.ndarray)
+            self.assertEqual(label.dtype, np.int32)
+            self.assertEqual(label.shape, (point.shape[0],))
+
+        if self.return_area:
+            for _ in range(10):
+                i = np.random.randint(0, len(self.dataset))
+                _, point, _, _, _, area = self.dataset[i][:6]
+                self.assertIsInstance(area, np.ndarray)
+                self.assertEqual(area.dtype, np.float32)
+                self.assertEqual(area.shape, (point.shape[0],))
+
+        if self.return_crowded:
+            for _ in range(10):
+                i = np.random.randint(0, len(self.dataset))
+                example = self.dataset[i]
+                crowded = example[-1]
+                point = example[1]
+                self.assertIsInstance(crowded, np.ndarray)
+                self.assertEqual(crowded.dtype, np.bool)
+                self.assertEqual(crowded.shape, (point.shape[0],))
+
+                if not self.use_crowded:
+                    np.testing.assert_equal(crowded, 0)
+
+
+testing.run_module(__name__, __file__)
diff --git a/tests/evaluations_tests/test_eval_keypoint_detection_coco.py b/tests/evaluations_tests/test_eval_keypoint_detection_coco.py
new file mode 100644
index 0000000000..8112f007f8
--- /dev/null
+++ b/tests/evaluations_tests/test_eval_keypoint_detection_coco.py
@@ -0,0 +1,171 @@
+import numpy as np
+import os
+from six.moves.urllib import request
+import unittest
+
+from chainer import testing
+
+from chainercv.datasets import coco_keypoint_names
+from chainercv.evaluations import eval_keypoint_detection_coco
+
+try:
+    import pycocotools  # NOQA
+    _available = True
+except ImportError:
+    _available = False
+
+
+human_id = 0
+
+
+def _generate_point(n_inst, size):
+    H, W = size
+    n_joint = len(coco_keypoint_names[human_id])
+    ys = np.random.uniform(0, H, size=(n_inst, n_joint))
+    xs = np.random.uniform(0, W, size=(n_inst, n_joint))
+    point = np.stack((ys, xs), axis=2).astype(np.float32)
+
+    valid = np.random.randint(0, 2, size=(n_inst, n_joint)).astype(np.bool)
+    return point, valid
+
+
+@unittest.skipUnless(_available, 'pycocotools is not installed')
+class TestEvalKeypointDetectionCOCOSimple(unittest.TestCase):
+
+    n_inst = 3
+
+    def setUp(self):
+        self.pred_points = []
+        self.pred_labels = []
+        self.pred_scores = []
+        self.gt_points = []
+        self.gt_visibles = []
+        self.gt_bboxes = []
+        self.gt_labels = []
+        for _ in range(2):
+            point, valid = _generate_point(self.n_inst, (32, 48))
+            self.pred_points.append(point)
+            self.pred_labels.append(np.zeros((self.n_inst,), dtype=np.int32))
+            self.pred_scores.append(np.random.uniform(
+                0.5, 1, size=(self.n_inst,)).astype(np.float32))
+            self.gt_points.append(point)
+            self.gt_visibles.append(valid)
+            bbox = np.zeros((self.n_inst, 4), dtype=np.float32)
+            for i, pnt in enumerate(point):
+                y_min = np.min(pnt[:, 0])
+                x_min = np.min(pnt[:, 1])
+                y_max = np.max(pnt[:, 0])
+                x_max = np.max(pnt[:, 1])
+                bbox[i] = [y_min, x_min, y_max, x_max]
+            self.gt_bboxes.append(bbox)
+            self.gt_labels.append(np.zeros((self.n_inst,), dtype=np.int32))
+
+    def _check(self, result):
+        self.assertEqual(result['map/iou=0.50:0.95/area=all/max_dets=20'], 1)
+
self.assertEqual(result['map/iou=0.50/area=all/max_dets=20'], 1) + self.assertEqual(result['map/iou=0.75/area=all/max_dets=20'], 1) + self.assertEqual(result['mar/iou=0.50:0.95/area=all/max_dets=20'], 1) + self.assertEqual(result['mar/iou=0.50/area=all/max_dets=20'], 1) + self.assertEqual(result['mar/iou=0.75/area=all/max_dets=20'], 1) + + def test_gt_bboxes_not_supplied(self): + result = eval_keypoint_detection_coco( + self.pred_points, self.pred_labels, self.pred_scores, + self.gt_points, self.gt_visibles, self.gt_labels, None) + self._check(result) + + def test_area_not_supplied(self): + result = eval_keypoint_detection_coco( + self.pred_points, self.pred_labels, self.pred_scores, + self.gt_points, self.gt_visibles, self.gt_labels, self.gt_bboxes) + self._check(result) + + self.assertFalse( + 'map/iou=0.50:0.95/area=medium/max_dets=20' in result) + self.assertFalse( + 'map/iou=0.50:0.95/area=large/max_dets=20' in result) + self.assertFalse( + 'mar/iou=0.50:0.95/area=medium/max_dets=20' in result) + self.assertFalse( + 'mar/iou=0.50:0.95/area=large/max_dets=20' in result) + + def test_area_supplied(self): + gt_areas = [[100] * self.n_inst for _ in range(2)] + result = eval_keypoint_detection_coco( + self.pred_points, self.pred_labels, self.pred_scores, + self.gt_points, self.gt_visibles, self.gt_labels, self.gt_bboxes, + gt_areas=gt_areas, + ) + self._check(result) + self.assertTrue( + 'map/iou=0.50:0.95/area=medium/max_dets=20' in result) + self.assertTrue( + 'map/iou=0.50:0.95/area=large/max_dets=20' in result) + self.assertTrue( + 'mar/iou=0.50:0.95/area=medium/max_dets=20' in result) + self.assertTrue( + 'mar/iou=0.50:0.95/area=large/max_dets=20' in result) + + def test_crowded_supplied(self): + gt_crowdeds = [[True] * self.n_inst for _ in range(2)] + result = eval_keypoint_detection_coco( + self.pred_points, self.pred_labels, self.pred_scores, + self.gt_points, self.gt_visibles, self.gt_labels, self.gt_bboxes, + gt_crowdeds=gt_crowdeds, + ) + # When the only ground truth is crowded, nothing is evaluated. + # In that case, all the results are nan. 
+ self.assertTrue( + np.isnan(result['map/iou=0.50:0.95/area=all/max_dets=20'])) + + +@unittest.skipUnless(_available, 'pycocotools is not installed') +class TestEvalKeypointDetectionCOCO(unittest.TestCase): + + @classmethod + def setUpClass(cls): + base_url = 'https://chainercv-models.preferred.jp/tests' + + cls.dataset = np.load(request.urlretrieve(os.path.join( + base_url, + 'eval_keypoint_detection_coco_dataset_2019_02_21.npz'))[0]) + cls.result = np.load(request.urlretrieve(os.path.join( + base_url, + 'eval_keypoint_detection_coco_result_2019_02_20.npz'))[0]) + + def test_eval_keypoint_detection_coco(self): + pred_points = self.result['points'] + pred_labels = self.result['labels'] + pred_scores = self.result['scores'] + + gt_points = self.dataset['points'] + gt_visibles = self.dataset['visibles'] + gt_labels = self.dataset['labels'] + gt_bboxes = self.dataset['bboxes'] + gt_areas = self.dataset['areas'] + gt_crowdeds = self.dataset['crowdeds'] + + result = eval_keypoint_detection_coco( + pred_points, pred_labels, pred_scores, + gt_points, gt_visibles, gt_labels, gt_bboxes, + gt_areas, gt_crowdeds) + + expected = { + 'map/iou=0.50:0.95/area=all/max_dets=20': 0.37733572721481323, + 'map/iou=0.50/area=all/max_dets=20': 0.6448841691017151, + 'map/iou=0.75/area=all/max_dets=20': 0.35469090938568115, + 'map/iou=0.50:0.95/area=medium/max_dets=20': 0.3894105851650238, + 'map/iou=0.50:0.95/area=large/max_dets=20': 0.39169296622276306, + 'mar/iou=0.50:0.95/area=all/max_dets=20': 0.5218977928161621, + 'mar/iou=0.50/area=all/max_dets=20': 0.7445255517959595, + 'mar/iou=0.75/area=all/max_dets=20': 0.510948896408081, + 'mar/iou=0.50:0.95/area=medium/max_dets=20': 0.5150684714317322, + 'mar/iou=0.50:0.95/area=large/max_dets=20': 0.5296875238418579, + } + + for key, item in expected.items(): + np.testing.assert_almost_equal( + result[key], expected[key], decimal=5) + + +testing.run_module(__name__, __file__) diff --git a/tests/links_tests/model_tests/fpn_tests/test_faster_rcnn.py b/tests/links_tests/model_tests/fpn_tests/test_faster_rcnn.py index 1d245ac0bd..bebfa4a79b 100644 --- a/tests/links_tests/model_tests/fpn_tests/test_faster_rcnn.py +++ b/tests/links_tests/model_tests/fpn_tests/test_faster_rcnn.py @@ -7,10 +7,13 @@ from chainer import testing from chainer.testing import attr +from chainercv.links.model.fpn import BboxHead from chainercv.links.model.fpn import FasterRCNN -from chainercv.links.model.fpn import Head +from chainercv.links.model.fpn import MaskHead from chainercv.links.model.fpn import RPN +from chainercv.utils import assert_is_bbox from chainercv.utils import assert_is_detection_link +from chainercv.utils import assert_is_instance_segmentation_link def _random_array(xp, shape): @@ -31,28 +34,35 @@ def __call__(self, x): class DummyFasterRCNN(FasterRCNN): - def __init__(self, n_fg_class, min_size, max_size): + def __init__(self, n_fg_class, return_values, min_size, max_size): extractor = DummyExtractor() super(DummyFasterRCNN, self).__init__( extractor=extractor, rpn=RPN(extractor.scales), - head=Head(n_fg_class + 1, extractor.scales), + bbox_head=BboxHead(n_fg_class + 1, extractor.scales), + mask_head=MaskHead(n_fg_class + 1, extractor.scales), + return_values=return_values, min_size=min_size, max_size=max_size, ) @testing.parameterize(*testing.product_dict( + [ + {'return_values': 'detection'}, + {'return_values': 'instance_segmentation'}, + {'return_values': 'rpn'} + ], [ {'n_fg_class': 1}, {'n_fg_class': 5}, {'n_fg_class': 20}, ], [ - { - 'in_sizes': [(480, 640), (320, 320)], 
-            'min_size': 800, 'max_size': 1333,
-            'expected_shape': (800, 1088),
-        },
+        # {
+        #     'in_sizes': [(480, 640), (320, 320)],
+        #     'min_size': 800, 'max_size': 1333,
+        #     'expected_shape': (800, 1088),
+        # },
         {
             'in_sizes': [(200, 50), (400, 100)],
             'min_size': 200, 'max_size': 320,
@@ -63,7 +73,14 @@ def __init__(self, n_fg_class, min_size, max_size):
 class TestFasterRCNN(unittest.TestCase):
 
     def setUp(self):
+        if self.return_values == 'detection':
+            return_values = ['bboxes', 'labels', 'scores']
+        elif self.return_values == 'instance_segmentation':
+            return_values = ['masks', 'labels', 'scores']
+        elif self.return_values == 'rpn':
+            return_values = ['rois']
         self.link = DummyFasterRCNN(n_fg_class=self.n_fg_class,
+                                    return_values=return_values,
                                     min_size=self.min_size,
                                     max_size=self.max_size)
@@ -88,29 +105,20 @@ def test_use_preset(self):
     def _check_call(self):
         x = _random_array(self.link.xp, (2, 3, 32, 32))
         with chainer.using_config('train', False):
-            rois, roi_indices, head_locs, head_confs = self.link(x)
+            hs, rois, roi_indices = self.link(x)
 
-        self.assertEqual(len(rois), len(self.link.extractor.scales))
-        self.assertEqual(len(roi_indices), len(self.link.extractor.scales))
+        self.assertEqual(len(hs), len(self.link.extractor.scales))
         for l in range(len(self.link.extractor.scales)):
-            self.assertIsInstance(rois[l], self.link.xp.ndarray)
-            self.assertEqual(rois[l].shape[1:], (4,))
-
-            self.assertIsInstance(roi_indices[l], self.link.xp.ndarray)
-            self.assertEqual(roi_indices[l].shape[1:], ())
-
-            self.assertEqual(rois[l].shape[0], roi_indices[l].shape[0])
+            self.assertIsInstance(hs[l], chainer.Variable)
+            self.assertIsInstance(hs[l].data, self.link.xp.ndarray)
 
-        n_roi = sum(
-            len(rois[l]) for l in range(len(self.link.extractor.scales)))
+        self.assertIsInstance(rois, self.link.xp.ndarray)
+        self.assertEqual(rois.shape[1:], (4,))
 
-        self.assertIsInstance(head_locs, chainer.Variable)
-        self.assertIsInstance(head_locs.array, self.link.xp.ndarray)
-        self.assertEqual(head_locs.shape, (n_roi, self.n_fg_class + 1, 4))
+        self.assertIsInstance(roi_indices, self.link.xp.ndarray)
+        self.assertEqual(roi_indices.shape[1:], ())
 
-        self.assertIsInstance(head_confs, chainer.Variable)
-        self.assertIsInstance(head_confs.array, self.link.xp.ndarray)
-        self.assertEqual(head_confs.shape, (n_roi, self.n_fg_class + 1))
+        self.assertEqual(rois.shape[0], roi_indices.shape[0])
 
     def test_call_cpu(self):
         self._check_call()
@@ -126,13 +134,32 @@ def test_call_train_mode(self):
         with chainer.using_config('train', True):
             self.link(x)
 
+    def _check_predict(self):
+        if self.return_values == 'detection':
+            assert_is_detection_link(self.link, self.n_fg_class)
+        elif self.return_values == 'instance_segmentation':
+            assert_is_instance_segmentation_link(self.link, self.n_fg_class)
+        elif self.return_values == 'rpn':
+            imgs = [
+                np.random.randint(
+                    0, 256, size=(3, 480, 320)).astype(np.float32),
+                np.random.randint(
+                    0, 256, size=(3, 480, 320)).astype(np.float32)]
+            result = self.link.predict(imgs)
+            assert len(result) == 1
+            # predict returns one set of RoIs per input image.
+            assert len(result[0]) == 2
+            for i in range(len(result[0])):
+                roi = result[0][i]
+                assert_is_bbox(roi)
+
+    @attr.slow
     def test_predict_cpu(self):
-        assert_is_detection_link(self.link, self.n_fg_class)
+        self._check_predict()
 
     @attr.gpu
     def test_predict_gpu(self):
         self.link.to_gpu()
-        assert_is_detection_link(self.link, self.n_fg_class)
+        self._check_predict()
 
     def test_prepare(self):
         imgs = [_random_array(np, (3, s[0], s[1])) for s in self.in_sizes]
diff --git
a/tests/links_tests/model_tests/fpn_tests/test_faster_rcnn_fpn_resnet.py b/tests/links_tests/model_tests/fpn_tests/test_faster_rcnn_fpn_resnet.py index cf5537ed3e..3ac43292fc 100644 --- a/tests/links_tests/model_tests/fpn_tests/test_faster_rcnn_fpn_resnet.py +++ b/tests/links_tests/model_tests/fpn_tests/test_faster_rcnn_fpn_resnet.py @@ -6,17 +6,21 @@ from chainercv.links import FasterRCNNFPNResNet101 from chainercv.links import FasterRCNNFPNResNet50 +from chainercv.links import MaskRCNNFPNResNet101 +from chainercv.links import MaskRCNNFPNResNet50 from chainercv.utils.testing import attr @testing.parameterize(*testing.product({ - 'model': [FasterRCNNFPNResNet50, FasterRCNNFPNResNet101], + 'model': [FasterRCNNFPNResNet50, FasterRCNNFPNResNet101, + MaskRCNNFPNResNet50, MaskRCNNFPNResNet101], 'n_fg_class': [1, 5, 20], })) class TestFasterRCNNFPNResNet(unittest.TestCase): def setUp(self): - self.link = self.model(n_fg_class=self.n_fg_class) + self.link = self.model( + n_fg_class=self.n_fg_class, min_size=66) def _check_call(self): imgs = [ @@ -40,7 +44,8 @@ def test_call_gpu(self): @testing.parameterize(*testing.product({ - 'model': [FasterRCNNFPNResNet50, FasterRCNNFPNResNet101], + 'model': [FasterRCNNFPNResNet50, FasterRCNNFPNResNet101, + MaskRCNNFPNResNet50, MaskRCNNFPNResNet101], 'n_fg_class': [None, 10, 80], 'pretrained_model': ['coco', 'imagenet'], })) diff --git a/tests/links_tests/model_tests/fpn_tests/test_mask_head.py b/tests/links_tests/model_tests/fpn_tests/test_mask_head.py new file mode 100644 index 0000000000..ba48d58646 --- /dev/null +++ b/tests/links_tests/model_tests/fpn_tests/test_mask_head.py @@ -0,0 +1,236 @@ +from __future__ import division + +import numpy as np +import unittest + +import chainer +from chainer import testing +from chainer.testing import attr + +from chainercv.links.model.fpn import mask_loss_post +from chainercv.links.model.fpn import mask_loss_pre +from chainercv.links.model.fpn import MaskHead + +from chainercv.utils import mask_to_bbox + + +def _random_array(xp, shape): + return xp.array( + np.random.uniform(-1, 1, size=shape), dtype=np.float32) + + +@testing.parameterize( + {'n_class': 1 + 1}, + {'n_class': 5 + 1}, + {'n_class': 20 + 1}, +) +class TestMaskHead(unittest.TestCase): + + def setUp(self): + self.link = MaskHead( + n_class=self.n_class, scales=(1 / 2, 1 / 4, 1 / 8)) + + def _check_call(self): + hs = [ + chainer.Variable(_random_array(self.link.xp, (2, 64, 32, 32))), + chainer.Variable(_random_array(self.link.xp, (2, 64, 16, 16))), + chainer.Variable(_random_array(self.link.xp, (2, 64, 8, 8))), + ] + rois = [ + self.link.xp.array(((4, 1, 6, 3),), dtype=np.float32), + self.link.xp.array( + ((0, 1, 2, 3), (5, 4, 10, 6)), dtype=np.float32), + self.link.xp.array(((10, 4, 12, 10),), dtype=np.float32), + ] + roi_indices = [ + self.link.xp.array((0,), dtype=np.int32), + self.link.xp.array((1, 0), dtype=np.int32), + self.link.xp.array((1,), dtype=np.int32), + ] + + segs = self.link(hs, rois, roi_indices) + + self.assertIsInstance(segs, chainer.Variable) + self.assertIsInstance(segs.array, self.link.xp.ndarray) + self.assertEqual( + segs.shape, + (4, self.n_class, self.link.segm_size, self.link.segm_size)) + + def test_call_cpu(self): + self._check_call() + + @attr.gpu + def test_call_gpu(self): + self.link.to_gpu() + self._check_call() + + def _check_distribute(self): + rois = self.link.xp.array(( + (0, 0, 10, 10), + (0, 1000, 0, 1000), + (0, 0, 224, 224), + (100, 100, 224, 224), + ), dtype=np.float32) + roi_indices = self.link.xp.array((0, 1, 0, 
0), dtype=np.int32) + n_roi = len(roi_indices) + + rois, roi_indices, order = self.link.distribute(rois, roi_indices) + + self.assertEqual(len(rois), 3) + self.assertEqual(len(roi_indices), 3) + for l in range(3): + self.assertIsInstance(rois[l], self.link.xp.ndarray) + self.assertIsInstance(roi_indices[l], self.link.xp.ndarray) + + self.assertEqual(rois[l].shape[0], roi_indices[l].shape[0]) + self.assertEqual(rois[l].shape[1:], (4,)) + self.assertEqual(roi_indices[l].shape[1:], ()) + + self.assertEqual(sum(rois[l].shape[0] for l in range(3)), 4) + + self.assertEqual(len(order), n_roi) + self.assertIsInstance(order, self.link.xp.ndarray) + + def test_distribute_cpu(self): + self._check_distribute() + + @attr.gpu + def test_distribute_gpu(self): + self.link.to_gpu() + self._check_distribute() + + def _check_decode(self): + segms = [ + _random_array( + self.link.xp, + (1, self.n_class, self.link.segm_size, self.link.segm_size)), + _random_array( + self.link.xp, + (2, self.n_class, self.link.segm_size, self.link.segm_size)), + _random_array( + self.link.xp, + (1, self.n_class, self.link.segm_size, self.link.segm_size)) + ] + bboxes = [ + self.link.xp.array(((4, 1, 6, 3),), dtype=np.float32), + self.link.xp.array( + ((0, 1, 2, 3), (5, 4, 10, 6)), dtype=np.float32), + self.link.xp.array(((10, 4, 12, 10),), dtype=np.float32), + ] + labels = [ + self.link.xp.random.randint( + 0, self.n_class - 1, size=(1,), dtype=np.int32), + self.link.xp.random.randint( + 0, self.n_class - 1, size=(2,), dtype=np.int32), + self.link.xp.random.randint( + 0, self.n_class - 1, size=(1,), dtype=np.int32), + ] + + sizes = [(56, 56), (48, 48), (72, 72)] + masks = self.link.decode( + segms, bboxes, labels, sizes) + + self.assertEqual(len(masks), 3) + for n in range(3): + self.assertIsInstance(masks[n], self.link.xp.ndarray) + + self.assertEqual(masks[n].shape[0], labels[n].shape[0]) + self.assertEqual(masks[n].shape[1:], sizes[n]) + + def test_decode_cpu(self): + self._check_decode() + + +class TestMaskHeadLoss(unittest.TestCase): + + def _check_mask_loss_pre(self, xp): + n_inst = 12 + segm_size = 28 + rois = [ + xp.array(((4, 1, 6, 3),), dtype=np.float32), + xp.array( + ((0, 1, 2, 3), (5, 4, 10, 6)), dtype=np.float32), + xp.array(((10, 4, 12, 10),), dtype=np.float32), + ] + roi_indices = [ + xp.array((0,), dtype=np.int32), + xp.array((1, 0), dtype=np.int32), + xp.array((1,), dtype=np.int32), + ] + masks = [ + _random_array(xp, (n_inst, 60, 70)), + _random_array(xp, (n_inst, 60, 70)), + ] + bboxes = [mask_to_bbox(mask) for mask in masks] + labels = [ + xp.array((1,), dtype=np.int32), + xp.array((10, 4), dtype=np.int32), + xp.array((3,), dtype=np.int32), + ] + rois, roi_indices, gt_segms, gt_mask_labels = mask_loss_pre( + rois, roi_indices, masks, bboxes, labels, segm_size) + + self.assertEqual(len(rois), 3) + self.assertEqual(len(roi_indices), 3) + self.assertEqual(len(gt_segms), 3) + self.assertEqual(len(gt_mask_labels), 3) + for l in range(3): + self.assertIsInstance(rois[l], xp.ndarray) + self.assertIsInstance(roi_indices[l], xp.ndarray) + self.assertIsInstance(gt_segms[l], xp.ndarray) + self.assertIsInstance(gt_mask_labels[l], xp.ndarray) + + self.assertEqual(rois[l].shape[0], roi_indices[l].shape[0]) + self.assertEqual(rois[l].shape[0], gt_segms[l].shape[0]) + self.assertEqual(rois[l].shape[0], gt_mask_labels[l].shape[0]) + self.assertEqual(rois[l].shape[1:], (4,)) + self.assertEqual(roi_indices[l].shape[1:], ()) + self.assertEqual(gt_segms[l].shape[1:], (segm_size, segm_size)) + 
self.assertEqual(gt_mask_labels[l].shape[1:], ()) + self.assertEqual(gt_segms[l].dtype, np.float32) + self.assertEqual(gt_mask_labels[l].dtype, np.int32) + + def test_mask_loss_pre_cpu(self): + self._check_mask_loss_pre(np) + + @attr.gpu + def test_mask_loss_pre_gpu(self): + import cupy + self._check_mask_loss_pre(cupy) + + def _check_mask_loss_post(self, xp): + B = 2 + segms = chainer.Variable(_random_array(xp, (20, 81, 28, 28))) + mask_roi_indices = [ + xp.random.randint(0, B, size=5).astype(np.int32), + xp.random.randint(0, B, size=7).astype(np.int32), + xp.random.randint(0, B, size=8).astype(np.int32), + ] + gt_segms = [ + _random_array(xp, (5, 28, 28)), + _random_array(xp, (7, 28, 28)), + _random_array(xp, (8, 28, 28)), + ] + gt_mask_labels = [ + xp.random.randint(0, 80, size=5).astype(np.int32), + xp.random.randint(0, 80, size=7).astype(np.int32), + xp.random.randint(0, 80, size=8).astype(np.int32), + ] + + mask_loss = mask_loss_post( + segms, mask_roi_indices, gt_segms, gt_mask_labels, B) + + self.assertIsInstance(mask_loss, chainer.Variable) + self.assertIsInstance(mask_loss.array, xp.ndarray) + self.assertEqual(mask_loss.shape, ()) + + def test_mask_loss_post_cpu(self): + self._check_mask_loss_post(np) + + @attr.gpu + def test_mask_loss_post_gpu(self): + import cupy + self._check_mask_loss_post(cupy) + + +testing.run_module(__name__, __file__) diff --git a/tests/links_tests/model_tests/fpn_tests/test_mask_utils.py b/tests/links_tests/model_tests/fpn_tests/test_mask_utils.py new file mode 100644 index 0000000000..c6bcd360d0 --- /dev/null +++ b/tests/links_tests/model_tests/fpn_tests/test_mask_utils.py @@ -0,0 +1,53 @@ +from __future__ import division + +import numpy as np +import unittest + +from chainer import testing + +from chainercv.links.model.fpn.mask_utils import mask_to_segm +from chainercv.links.model.fpn.mask_utils import segm_to_mask + + +class TestSegmToMask(unittest.TestCase): + + def setUp(self): + # When n_inst >= 3, the test fails. + # This is due to the fact that the transformed + # image of `transforms.resize` is misaligned to the corners. 
+ n_inst = 2 + self.segm_size = 3 + self.size = (36, 48) + + self.segm = np.ones( + (n_inst, self.segm_size, self.segm_size), dtype=np.float32) + self.bbox = np.zeros((n_inst, 4), dtype=np.float32) + for i in range(n_inst): + self.bbox[i, 0] = 10 + i + self.bbox[i, 1] = 10 + i + self.bbox[i, 2] = self.bbox[i, 0] + self.segm_size * (1 + i) + self.bbox[i, 3] = self.bbox[i, 1] + self.segm_size * (1 + i) + + self.mask = np.zeros((n_inst,) + self.size, dtype=np.bool) + for i, bb in enumerate(self.bbox): + bb = bb.astype(np.int32) + self.mask[i, bb[0]:bb[2], bb[1]:bb[3]] = 1 + + def test_segm_to_mask(self): + mask = segm_to_mask(self.segm, self.bbox, self.size) + np.testing.assert_equal(mask, self.mask) + + def test_mask_to_segm(self): + segm = mask_to_segm(self.mask, self.bbox, self.segm_size) + np.testing.assert_equal(segm, self.segm) + + def test_mask_to_segm_index(self): + index = np.arange(len(self.bbox))[::-1] + segm = mask_to_segm( + self.mask, self.bbox[::-1], + self.segm_size, index=index) + segm = segm[::-1] + np.testing.assert_equal(segm, self.segm) + + +testing.run_module(__name__, __file__) diff --git a/tests/links_tests/model_tests/mask_rcnn_tests/test_keypoint_head.py b/tests/links_tests/model_tests/mask_rcnn_tests/test_keypoint_head.py new file mode 100644 index 0000000000..17616f156c --- /dev/null +++ b/tests/links_tests/model_tests/mask_rcnn_tests/test_keypoint_head.py @@ -0,0 +1,144 @@ +from __future__ import division + +import numpy as np +import unittest + +import chainer +from chainer import testing +from chainer.testing import attr + +from chainercv.links.model.mask_rcnn import KeypointHead +from chainercv.links.model.mask_rcnn import keypoint_loss_post +from chainercv.links.model.mask_rcnn import keypoint_loss_pre + + +def _random_array(xp, shape): + return xp.array( + np.random.uniform(-1, 1, size=shape), dtype=np.float32) + + +def _point_to_bbox(point, visible=None): + xp = chainer.backends.cuda.get_array_module(point) + + bbox = xp.zeros((len(point), 4), dtype=np.float32) + + for i, pnt in enumerate(point): + if visible is None: + vsbl = xp.ones((len(pnt),), dtype=np.bool) + else: + vsbl = visible[i] + pnt = pnt[vsbl] + bbox[i, 0] = xp.min(pnt[:, 0]) + bbox[i, 1] = xp.min(pnt[:, 1]) + bbox[i, 2] = xp.max(pnt[:, 0]) + bbox[i, 3] = xp.max(pnt[:, 1]) + return bbox + + +class TestKeypointHeadLoss(unittest.TestCase): + + def _check_keypoint_loss_pre(self, xp): + point_map_size = 28 + n_point = 17 + rois = [ + xp.array(((4, 1, 6, 3),), dtype=np.float32), + xp.array( + ((0, 1, 2, 3), (5, 4, 10, 6)), dtype=np.float32), + xp.array(((10, 4, 12, 10),), dtype=np.float32), + ] + roi_indices = [ + xp.array((0,), dtype=np.int32), + xp.array((1, 0), dtype=np.int32), + xp.array((1,), dtype=np.int32), + ] + points = [ + xp.zeros((1, n_point, 2), dtype=np.float32), + xp.zeros((3, n_point, 2), dtype=np.float32), + ] + visibles = [ + xp.ones((1, n_point), dtype=np.bool), + xp.ones((3, n_point), dtype=np.bool), + ] + bboxes = [_point_to_bbox(point, visible) + for point, visible in zip(points, visibles)] + labels = [ + xp.array((1,), dtype=np.int32), + xp.array((1, 1), dtype=np.int32), + xp.array((1,), dtype=np.int32), + ] + rois, roi_indices, gt_roi_points, gt_roi_visibles = keypoint_loss_pre( + rois, roi_indices, points, visibles, bboxes, + labels, point_map_size) + + self.assertEqual(len(rois), 3) + self.assertEqual(len(roi_indices), 3) + self.assertEqual(len(gt_roi_points), 3) + self.assertEqual(len(gt_roi_visibles), 3) + for l in range(3): + self.assertIsInstance(rois[l], xp.ndarray) + 
self.assertIsInstance(roi_indices[l], xp.ndarray) + self.assertIsInstance(gt_roi_points[l], xp.ndarray) + self.assertIsInstance(gt_roi_visibles[l], xp.ndarray) + + self.assertEqual(rois[l].shape[0], roi_indices[l].shape[0]) + self.assertEqual(rois[l].shape[0], gt_roi_points[l].shape[0]) + self.assertEqual(rois[l].shape[0], gt_roi_visibles[l].shape[0]) + self.assertEqual(rois[l].shape[1:], (4,)) + self.assertEqual(roi_indices[l].shape[1:], ()) + self.assertEqual( + gt_roi_points[l].shape[1:], (n_point, 2)) + self.assertEqual( + gt_roi_visibles[l].shape[1:], (n_point,)) + + self.assertEqual( + gt_roi_points[l].dtype, np.float32) + self.assertEqual( + gt_roi_visibles[l].dtype, np.bool) + + def test_keypoint_loss_pre_cpu(self): + self._check_keypoint_loss_pre(np) + + @attr.gpu + def test_keypoint_loss_pre_gpu(self): + import cupy + self._check_keypoint_loss_pre(cupy) + + def _check_keypoint_loss_post(self, xp): + B = 2 + n_point = 17 + + point_maps = chainer.Variable(_random_array(xp, (20, n_point, 28, 28))) + point_roi_indices = [ + xp.random.randint(0, B, size=5).astype(np.int32), + xp.random.randint(0, B, size=7).astype(np.int32), + xp.random.randint(0, B, size=8).astype(np.int32), + ] + gt_roi_points = [ + xp.random.randint(0, 28, size=(5, n_point, 2)).astype(np.int32), + xp.random.randint(0, 28, size=(7, n_point, 2)).astype(np.int32), + xp.random.randint(0, 28, size=(8, n_point, 2)).astype(np.int32), + ] + gt_roi_visibles = [ + xp.random.randint(0, 2, size=(5, n_point)).astype(np.bool), + xp.random.randint(0, 2, size=(7, n_point)).astype(np.bool), + xp.random.randint(0, 2, size=(8, n_point)).astype(np.bool), + ] + + keypoint_loss = keypoint_loss_post( + point_maps, point_roi_indices, gt_roi_points, + gt_roi_visibles, B) + + self.assertIsInstance(keypoint_loss, chainer.Variable) + self.assertIsInstance(keypoint_loss.array, xp.ndarray) + self.assertEqual(keypoint_loss.shape, ()) + + def test_keypoint_loss_post_cpu(self): + self._check_keypoint_loss_post(np) + + @attr.gpu + def test_keypoint_loss_post_gpu(self): + import cupy + self._check_keypoint_loss_post(cupy) + + +testing.run_module(__name__, __file__) diff --git a/tests/transforms_tests/point_tests/test_flip_point.py b/tests/transforms_tests/point_tests/test_flip_point.py index ac6dc4d690..f02ae8b33d 100644 --- a/tests/transforms_tests/point_tests/test_flip_point.py +++ b/tests/transforms_tests/point_tests/test_flip_point.py @@ -8,19 +8,35 @@ class TestFlipPoint(unittest.TestCase): - def test_flip_point(self): + def test_flip_point_ndarray(self): point = np.random.uniform( - low=0., high=32., size=(12, 2)) + low=0., high=32., size=(3, 12, 2)) out = flip_point(point, size=(34, 32), y_flip=True) point_expected = point.copy() - point_expected[:, 0] = 34 - point[:, 0] + point_expected[:, :, 0] = 34 - point[:, :, 0] np.testing.assert_equal(out, point_expected) out = flip_point(point, size=(34, 32), x_flip=True) point_expected = point.copy() - point_expected[:, 1] = 32 - point[:, 1] + point_expected[:, :, 1] = 32 - point[:, :, 1] np.testing.assert_equal(out, point_expected) + def test_flip_point_list(self): + point = [np.random.uniform( + low=0., high=32., size=(12, 2))] + + out = flip_point(point, size=(34, 32), y_flip=True) + for i, pnt in enumerate(point): + pnt_expected = pnt.copy() + pnt_expected[:, 0] = 34 - pnt[:, 0] + np.testing.assert_equal(out[i], pnt_expected) + + out = flip_point(point, size=(34, 32), x_flip=True) + for i, pnt in enumerate(point): + pnt_expected = pnt.copy() + pnt_expected[:, 1] = 32 - pnt[:, 1] + 
np.testing.assert_equal(out[i], pnt_expected) + testing.run_module(__name__, __file__) diff --git a/tests/transforms_tests/point_tests/test_resize_point.py b/tests/transforms_tests/point_tests/test_resize_point.py index a3fb7b172b..79ce01daff 100644 --- a/tests/transforms_tests/point_tests/test_resize_point.py +++ b/tests/transforms_tests/point_tests/test_resize_point.py @@ -8,14 +8,24 @@ class TestResizePoint(unittest.TestCase): - def test_resize_point(self): + def test_resize_point_ndarray(self): point = np.random.uniform( - low=0., high=32., size=(12, 2)) + low=0., high=32., size=(3, 12, 2)) out = resize_point(point, in_size=(16, 32), out_size=(8, 64)) - point[:, 0] *= 0.5 - point[:, 1] *= 2 + point[:, :, 0] *= 0.5 + point[:, :, 1] *= 2 np.testing.assert_equal(out, point) + def test_resize_point_list(self): + point = [np.random.uniform( + low=0., high=32., size=(12, 2))] + + out = resize_point(point, in_size=(16, 32), out_size=(8, 64)) + for i, pnt in enumerate(point): + pnt[:, 0] *= 0.5 + pnt[:, 1] *= 2 + np.testing.assert_equal(out[i], pnt) + testing.run_module(__name__, __file__) diff --git a/tests/transforms_tests/point_tests/test_translate_point.py b/tests/transforms_tests/point_tests/test_translate_point.py index 1030bf22cb..8851d13e3d 100644 --- a/tests/transforms_tests/point_tests/test_translate_point.py +++ b/tests/transforms_tests/point_tests/test_translate_point.py @@ -8,15 +8,26 @@ class TestTranslatePoint(unittest.TestCase): - def test_translate_point(self): + def test_translate_point_ndarray(self): point = np.random.uniform( - low=0., high=32., size=(10, 2)) + low=0., high=32., size=(3, 10, 2)) out = translate_point(point, y_offset=3, x_offset=5) expected = np.empty_like(point) - expected[:, 0] = point[:, 0] + 3 - expected[:, 1] = point[:, 1] + 5 + expected[:, :, 0] = point[:, :, 0] + 3 + expected[:, :, 1] = point[:, :, 1] + 5 np.testing.assert_equal(out, expected) + def test_translate_point_list(self): + point = [np.random.uniform( + low=0., high=32., size=(10, 2))] + + out = translate_point(point, y_offset=3, x_offset=5) + for i, pnt in enumerate(point): + expected = np.empty_like(pnt) + expected[:, 0] = pnt[:, 0] + 3 + expected[:, 1] = pnt[:, 1] + 5 + np.testing.assert_equal(out[i], expected) + testing.run_module(__name__, __file__) diff --git a/tests/visualizations_tests/test_vis_keypoint_coco.py b/tests/visualizations_tests/test_vis_keypoint_coco.py new file mode 100644 index 0000000000..97c2f09a9b --- /dev/null +++ b/tests/visualizations_tests/test_vis_keypoint_coco.py @@ -0,0 +1,101 @@ +import unittest + +import numpy as np + +from chainer import testing + +from chainercv.datasets import coco_keypoint_names +from chainercv.visualizations import vis_keypoint_coco + +try: + import matplotlib # NOQA + _available = True +except ImportError: + _available = False + + +human_id = 0 + + +def _generate_point(n_inst, size): + H, W = size + n_joint = len(coco_keypoint_names[human_id]) + ys = np.random.uniform(0, H, size=(n_inst, n_joint)) + xs = np.random.uniform(0, W, size=(n_inst, n_joint)) + point = np.stack((ys, xs), axis=2).astype(np.float32) + + visible = np.random.randint(0, 2, size=(n_inst, n_joint)).astype(np.bool) + + point_score = np.random.uniform( + 0, 6, size=(n_inst, n_joint)).astype(np.float32) + return point, visible, point_score + + +@testing.parameterize(*testing.product({ + 'n_inst': [3, 0], + 'use_img': [False, True], + 'use_visible': [False, True], + 'use_point_score': [False, True] +})) +@unittest.skipUnless(_available, 'matplotlib is not installed') 
+class TestVisKeypointCOCO(unittest.TestCase):
+
+    def setUp(self):
+        size = (32, 48)
+        self.point, visible, point_score = _generate_point(self.n_inst, size)
+        self.img = (np.random.randint(
+            0, 255, size=(3,) + size).astype(np.float32)
+            if self.use_img else None)
+        self.visible = visible if self.use_visible else None
+        self.point_score = point_score if self.use_point_score else None
+
+    def test_vis_keypoint_coco(self):
+        ax = vis_keypoint_coco(
+            self.img, self.point, self.visible,
+            self.point_score)
+
+        self.assertIsInstance(ax, matplotlib.axes.Axes)
+
+
+@unittest.skipUnless(_available, 'matplotlib is not installed')
+class TestVisKeypointCOCOInvalidInputs(unittest.TestCase):
+
+    def setUp(self):
+        size = (32, 48)
+        n_inst = 10
+        self.point, self.visible, self.point_score = _generate_point(
+            n_inst, size)
+        self.img = np.random.randint(
+            0, 255, size=(3,) + size).astype(np.float32)
+
+    def _check(self, img, point, visible, point_score):
+        with self.assertRaises(ValueError):
+            vis_keypoint_coco(img, point, visible, point_score)
+
+    def test_invalid_n_inst_point(self):
+        self._check(self.img, self.point[:5], self.visible, self.point_score)
+
+    def test_invalid_n_inst_visible(self):
+        self._check(self.img, self.point, self.visible[:5], self.point_score)
+
+    def test_invalid_n_inst_point_score(self):
+        self._check(self.img, self.point, self.visible, self.point_score[:5])
+
+    def test_invalid_n_joint_point(self):
+        self._check(
+            self.img, self.point[:, :15], self.visible, self.point_score)
+
+    def test_invalid_n_joint_visible(self):
+        self._check(
+            self.img, self.point, self.visible[:, :15], self.point_score)
+
+    def test_invalid_n_joint_point_score(self):
+        self._check(
+            self.img, self.point, self.visible, self.point_score[:, :15])
+
+    def test_invalid_visible_dtype(self):
+        self._check(self.img, self.point, self.visible.astype(np.int32),
+                    self.point_score)
+
+
+testing.run_module(__name__, __file__)
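For reference, a minimal usage sketch of the visualizer exercised by the tests above. This is not part of the diff; the random inputs mirror what `_generate_point` builds, the output file name is illustrative, and it assumes `visible` and `point_score` are optional as the tests imply:

    import matplotlib.pyplot as plt
    import numpy as np

    from chainercv.datasets import coco_keypoint_names
    from chainercv.visualizations import vis_keypoint_coco

    # Random inputs shaped like the dataset's (R, K, 2) points.
    n_inst, (H, W) = 2, (240, 320)
    n_joint = len(coco_keypoint_names[0])  # joints of the "person" category
    ys = np.random.uniform(0, H, size=(n_inst, n_joint))
    xs = np.random.uniform(0, W, size=(n_inst, n_joint))
    point = np.stack((ys, xs), axis=2).astype(np.float32)
    visible = np.ones((n_inst, n_joint), dtype=np.bool)
    img = np.random.randint(0, 255, size=(3, H, W)).astype(np.float32)

    # Returns a matplotlib Axes, as the tests assert.
    ax = vis_keypoint_coco(img, point, visible)
    plt.savefig('keypoints.png')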