diff --git a/README.md b/README.md index 7d49cb232..b8dae1ea6 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,13 @@ A simple, fully convolutional model for real-time instance segmentation. This is the code for [our paper](https://arxiv.org/abs/1904.02689). +#### YOLACT++ implementation and models released! +YOLACT++ resnet50 model runs at 33.5 fps on a Titan Xp and achieves 34.1 mAP on COCO's `test-dev`. + +Related paper will be posted on arXiv soon. + +In order to use YOLACT++, make sure you compile the DCNv2 code. (See [Installation](https://github.com/dbolya/yolact#installation)) + #### ICCV update (v1.1) released! Check out the ICCV trailer here: [![IMAGE ALT TEXT HERE](https://img.youtube.com/vi/0pMfmo8qfpQ/0.jpg)](https://www.youtube.com/watch?v=0pMfmo8qfpQ) @@ -37,6 +44,11 @@ Some examples from our base model (33.5 fps on a Titan Xp and 29.8 mAP on COCO's git clone https://github.com/dbolya/yolact.git cd yolact ``` + - Compile deformable convolutional layers (from [DCNv2](https://github.com/CharlesShang/DCNv2/tree/pytorch_1.0)) + ```Shell + cd external/DCNv2 + ./make.sh + ``` - If you'd like to train YOLACT, download the COCO dataset and the 2014/2017 annotations. Note that this script will take a while and dump 21gb of files into `./data/coco`. ```Shell sh data/scripts/COCO.sh @@ -57,6 +69,13 @@ As of April 5th, 2019 here are our latest models along with their FPS on a Titan | 550 | Resnet101-FPN | 33.0 | 29.8 | [yolact_base_54_800000.pth](https://drive.google.com/file/d/1UYy3dMapbH1BnmtZU4WH1zbYgOzzHHf_/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EYRWxBEoKU9DiblrWx2M89MBGFkVVB_drlRd_v5sdT3Hgg) | 700 | Resnet101-FPN | 23.6 | 31.2 | [yolact_im700_54_800000.pth](https://drive.google.com/file/d/1lE4Lz5p25teiXV-6HdTiOJSnS7u7GBzg/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/Eagg5RSc5hFEhp7sPtvLNyoBjhlf2feog7t8OQzHKKphjw) +YOLACT++ models (released on Dec. 6th, 2019): + +| Image Size | Backbone | FPS | mAP | Weights | | +|:----------:|:-------------:|:----:|:----:|----------------------------------------------------------------------------------------------------------------------|--------| +| 550 | Resnet50-FPN | 33.5 | 34.1 | [yolact_plus_resnet50_54_800000.pth](https://drive.google.com/file/d/1ZPu1YR2UzGHQD0o1rEqy-j5bmEm3lbyP/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EcJAtMiEFlhAnVsDf00yWRIBUC4m8iE9NEEiV05XwtEoGw) | +| 550 | Resnet101-FPN | 27.3 | 34.6 | [yolact_plus_base_54_800000.pth](https://drive.google.com/file/d/15id0Qq5eqRbkD-N3ZjDZXdCvRyIaHpFB/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EVQ62sF0SrJPrl_68onyHF8BpG7c05A8PavV4a849sZgEA) + To evalute the model, put the corresponding weights file in the `./weights` directory and run one of the following commands. ## Quantitative Results on COCO ```Shell diff --git a/backbone.py b/backbone.py index e7360423e..100e8ff02 100644 --- a/backbone.py +++ b/backbone.py @@ -4,16 +4,25 @@ from collections import OrderedDict +from dcn_v2 import DCN + class Bottleneck(nn.Module): """ Adapted from torchvision.models.resnet """ expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=nn.BatchNorm2d, dilation=1): + def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=nn.BatchNorm2d, dilation=1, use_dcn=False): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, dilation=dilation) self.bn1 = norm_layer(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=dilation, bias=False, dilation=dilation) + if use_dcn: + self.conv2 = DCN(planes, planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, deformable_groups=1) + self.conv2.bias.data.zero_() + self.conv2.conv_offset_mask.weight.data.zero_() + self.conv2.conv_offset_mask.bias.data.zero_() + else: + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=dilation, bias=False, dilation=dilation) self.bn2 = norm_layer(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False, dilation=dilation) self.bn3 = norm_layer(planes * 4) @@ -47,7 +56,7 @@ def forward(self, x): class ResNetBackbone(nn.Module): """ Adapted from torchvision.models.resnet """ - def __init__(self, layers, atrous_layers=[], block=Bottleneck, norm_layer=nn.BatchNorm2d): + def __init__(self, layers, dcn_layers=[0, 0, 0, 0], dcn_interval=1, atrous_layers=[], block=Bottleneck, norm_layer=nn.BatchNorm2d): super().__init__() # These will be populated by _make_layer @@ -66,10 +75,10 @@ def __init__(self, layers, atrous_layers=[], block=Bottleneck, norm_layer=nn.Bat self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self._make_layer(block, 64, layers[0]) - self._make_layer(block, 128, layers[1], stride=2) - self._make_layer(block, 256, layers[2], stride=2) - self._make_layer(block, 512, layers[3], stride=2) + self._make_layer(block, 64, layers[0], dcn_layers=dcn_layers[0], dcn_interval=dcn_interval) + self._make_layer(block, 128, layers[1], stride=2, dcn_layers=dcn_layers[1], dcn_interval=dcn_interval) + self._make_layer(block, 256, layers[2], stride=2, dcn_layers=dcn_layers[2], dcn_interval=dcn_interval) + self._make_layer(block, 512, layers[3], stride=2, dcn_layers=dcn_layers[3], dcn_interval=dcn_interval) # This contains every module that should be initialized by loading in pretrained weights. # Any extra layers added onto this that won't be initialized by init_backbone will not be @@ -78,7 +87,7 @@ def __init__(self, layers, atrous_layers=[], block=Bottleneck, norm_layer=nn.Bat self.backbone_modules = [m for m in self.modules() if isinstance(m, nn.Conv2d)] - def _make_layer(self, block, planes, blocks, stride=1): + def _make_layer(self, block, planes, blocks, stride=1, dcn_layers=0, dcn_interval=1): """ Here one layer means a string of n Bottleneck blocks. """ downsample = None @@ -97,11 +106,12 @@ def _make_layer(self, block, planes, blocks, stride=1): ) layers = [] - layers.append(block(self.inplanes, planes, stride, downsample, self.norm_layer, self.dilation)) + use_dcn = (dcn_layers >= blocks) + layers.append(block(self.inplanes, planes, stride, downsample, self.norm_layer, self.dilation, use_dcn=use_dcn)) self.inplanes = planes * block.expansion for i in range(1, blocks): - layers.append(block(self.inplanes, planes, norm_layer=self.norm_layer)) - + use_dcn = ((i+dcn_layers) >= blocks) and (i % dcn_interval == 0) + layers.append(block(self.inplanes, planes, norm_layer=self.norm_layer, use_dcn=use_dcn)) layer = nn.Sequential(*layers) self.channels.append(planes * block.expansion) diff --git a/data/config.py b/data/config.py index 6ebacab45..b945dd0fb 100644 --- a/data/config.py +++ b/data/config.py @@ -247,6 +247,11 @@ def print(self): 'pred_aspect_ratios': [ [[0.66685089, 1.7073535, 0.87508774, 1.16524493, 0.49059086]] ] * 6, }) +resnet101_dcn_inter3_backbone = resnet101_backbone.copy({ + 'name': 'ResNet101_DCN_Interval3', + 'args': ([3, 4, 23, 3], [0, 4, 23, 3], 3), +}) + resnet50_backbone = resnet101_backbone.copy({ 'name': 'ResNet50', 'path': 'resnet50-19c8e357.pth', @@ -255,6 +260,11 @@ def print(self): 'transform': resnet_transform, }) +resnet50_dcnv2_backbone = resnet50_backbone.copy({ + 'name': 'ResNet50_DCNv2', + 'args': ([3, 4, 6, 3], [0, 4, 6, 3]), +}) + darknet53_backbone = backbone_base.copy({ 'name': 'DarkNet53', 'path': 'darknet53.pth', @@ -614,6 +624,19 @@ def print(self): 'backbone': None, 'name': 'base_config', + + # Fast Mask Re-scoring Network + # Inspried by Mask Scoring R-CNN (https://arxiv.org/abs/1903.00241) + # Do not crop out the mask with bbox but slide a convnet on the image-size mask, + # then use global pooling to get the final mask score + 'use_maskiou': False, + 'maskiou_net': [], + 'remove_small_gt_mask': -1, + + 'maskiou_alpha': 1.0, + 'rescore_mask': False, + 'rescore_bbox': False, + 'maskious_to_train': -1, }) @@ -736,6 +759,44 @@ def print(self): }) }) +# ----------------------- YOLACT++ CONFIGS ----------------------- # + +yolact_plus_base_config = yolact_base_config.copy({ + 'name': 'yolact_plus_base', + + 'backbone': resnet101_dcn_inter3_backbone.copy({ + 'selected_layers': list(range(1, 4)), + + 'pred_aspect_ratios': [ [[1, 1/2, 2]] ]*5, + 'pred_scales': [[i * 2 ** (j / 3.0) for j in range(3)] for i in [24, 48, 96, 192, 384]], + 'use_pixel_scales': True, + 'preapply_sqrt': False, + 'use_square_anchors': False, + }), + + 'use_maskiou': True, + 'maskiou_net': [(8, 3, {'stride': 2}), (16, 3, {'stride': 2}), (32, 3, {'stride': 2}), (64, 3, {'stride': 2}), (128, 3, {'stride': 2}), (80, 1, {})], + 'maskiou_alpha': 25, + 'rescore_bbox': False, + 'rescore_mask': True, + + 'remove_small_gt_mask': 5*5, +}) + +yolact_plus_resnet50_config = yolact_plus_base_config.copy({ + 'name': 'yolact_plus_resnet50', + + 'backbone': resnet50_dcnv2_backbone.copy({ + 'selected_layers': list(range(1, 4)), + + 'pred_aspect_ratios': [ [[1, 1/2, 2]] ]*5, + 'pred_scales': [[i * 2 ** (j / 3.0) for j in range(3)] for i in [24, 48, 96, 192, 384]], + 'use_pixel_scales': True, + 'preapply_sqrt': False, + 'use_square_anchors': False, + }), +}) + # Default config cfg = yolact_base_config.copy() diff --git a/eval.py b/eval.py index d26400856..f569b690e 100644 --- a/eval.py +++ b/eval.py @@ -132,7 +132,7 @@ def parse_args(argv=None): coco_cats_inv = {} color_cache = defaultdict(lambda: {}) -def prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, mask_alpha=0.45, fps_str=''): +def prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, mask_alpha=0.45, fps_str='', maskiou_net=None): """ Note: If undo_transform=False then im_h and im_w are allowed to be None. """ @@ -146,14 +146,34 @@ def prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, ma with timer.env('Postprocess'): t = postprocess(dets_out, w, h, visualize_lincomb = args.display_lincomb, crop_masks = args.crop, - score_threshold = args.score_threshold) + score_threshold = args.score_threshold, + maskiou_net = maskiou_net) torch.cuda.synchronize() + # FIXME reduce copy with timer.env('Copy'): if cfg.eval_mask_branch: # Masks are drawn on the GPU, so don't copy - masks = t[3][:args.top_k] - classes, scores, boxes = [x[:args.top_k].cpu().numpy() for x in t[:3]] + masks = t[3] + classes, scores, boxes = [x for x in t[:3]] + if isinstance(scores, list): + box_scores = scores[0].cpu().numpy() + mask_scores = scores[1].cpu().numpy() + # Re-rank predictions by mask scores + _scores = mask_scores * box_scores + idx = np.argsort(-_scores) + scores = box_scores[idx] + classes = classes.cpu().numpy()[idx] + boxes = boxes.cpu().numpy()[idx] + masks = masks[idx] + else: + scores = scores.cpu().numpy() + classes = classes.cpu().numpy() + boxes = boxes.cpu().numpy() + scores = scores[:args.top_k] + classes = classes[:args.top_k] + boxes = boxes[:args.top_k] + masks = masks[:args.top_k] num_dets_to_consider = min(args.top_k, classes.shape[0]) for j in range(num_dets_to_consider): @@ -257,12 +277,20 @@ def get_color(j, on_gpu=None): return img_numpy -def prep_benchmark(dets_out, h, w): +def prep_benchmark(dets_out, h, w, maskiou_net=None): with timer.env('Postprocess'): - t = postprocess(dets_out, w, h, crop_masks=args.crop, score_threshold=args.score_threshold) + t = postprocess(dets_out, w, h, crop_masks=args.crop, score_threshold=args.score_threshold, maskiou_net=maskiou_net) with timer.env('Copy'): - classes, scores, boxes, masks = [x[:args.top_k].cpu().numpy() for x in t] + classes, scores, boxes, masks = [x[:args.top_k] for x in t] + if isinstance(scores, list): + box_scores = scores[0].cpu().numpy() + mask_scores = scores[1].cpu().numpy() + else: + scores = scores.cpu().numpy() + classes = classes.cpu().numpy() + boxes = boxes.cpu().numpy() + masks = masks.cpu().numpy() with timer.env('Sync'): # Just in case @@ -371,7 +399,7 @@ def _bbox_iou(bbox1, bbox2, iscrowd=False): ret = jaccard(bbox1, bbox2, iscrowd) return ret.cpu() -def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, detections:Detections=None): +def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, detections:Detections=None, maskiou_net=None): """ Returns a list of APs for this image, with each element being for a class """ if not args.output_coco_json: with timer.env('Prepare gt'): @@ -388,13 +416,19 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de crowd_classes, gt_classes = split(gt_classes) with timer.env('Postprocess'): - classes, scores, boxes, masks = postprocess(dets, w, h, crop_masks=args.crop, score_threshold=args.score_threshold) + classes, scores, boxes, masks = postprocess(dets, w, h, crop_masks=args.crop, score_threshold=args.score_threshold, maskiou_net=maskiou_net) if classes.size(0) == 0: return classes = list(classes.cpu().numpy().astype(int)) - scores = list(scores.cpu().numpy().astype(float)) + if isinstance(scores, list): + box_scores = list(scores[0].cpu().numpy().astype(float)) + mask_scores = list(scores[1].cpu().numpy().astype(float)) + else: + scores = list(scores.cpu().numpy().astype(float)) + box_scores = scores + mask_scores = scores masks = masks.view(-1, h*w).cuda() boxes = boxes.cuda() @@ -406,8 +440,8 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de for i in range(masks.shape[0]): # Make sure that the bounding box actually makes sense and a mask was produced if (boxes[i, 3] - boxes[i, 1]) * (boxes[i, 2] - boxes[i, 0]) > 0: - detections.add_bbox(image_id, classes[i], boxes[i,:], scores[i]) - detections.add_mask(image_id, classes[i], masks[i,:,:], scores[i]) + detections.add_bbox(image_id, classes[i], boxes[i,:], box_scores[i]) + detections.add_mask(image_id, classes[i], masks[i,:,:], mask_scores[i]) return with timer.env('Eval Setup'): @@ -425,8 +459,8 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de crowd_bbox_iou_cache = None iou_types = [ - ('box', lambda i,j: bbox_iou_cache[i, j].item(), lambda i,j: crowd_bbox_iou_cache[i,j].item()), - ('mask', lambda i,j: mask_iou_cache[i, j].item(), lambda i,j: crowd_mask_iou_cache[i,j].item()) + ('box', lambda i,j: bbox_iou_cache[i, j].item(), lambda i,j: crowd_bbox_iou_cache[i,j].item(), lambda i: box_scores[i]), + ('mask', lambda i,j: mask_iou_cache[i, j].item(), lambda i,j: crowd_mask_iou_cache[i,j].item(), lambda i: mask_scores[i]) ] timer.start('Main loop') @@ -437,7 +471,7 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de for iouIdx in range(len(iou_thresholds)): iou_threshold = iou_thresholds[iouIdx] - for iou_type, iou_func, crowd_func in iou_types: + for iou_type, iou_func, crowd_func, score_func in iou_types: gt_used = [False] * len(gt_classes) ap_obj = ap_data[iou_type][iouIdx][_class] @@ -461,7 +495,7 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de if max_match_idx >= 0: gt_used[max_match_idx] = True - ap_obj.push(scores[i], True) + ap_obj.push(score_func(i), True) else: # If the detection matches a crowd, we can just ignore it matched_crowd = False @@ -481,7 +515,7 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de # same result as COCOEval. There aren't even that many crowd annotations to # begin with, but accuracy is of the utmost importance. if not matched_crowd: - ap_obj.push(scores[i], False) + ap_obj.push(score_func(i), False) timer.stop('Main loop') @@ -846,6 +880,7 @@ def evaluate(net:Yolact, dataset, train_mode=False): net.detect.use_cross_class_nms = args.cross_class_nms cfg.mask_proto_debug = args.mask_proto_debug + # TODO Currently we do not support Fast Mask Re-scroing in evalimage, evalimages, and evalvideo if args.image is not None: if ':' in args.image: inp, out = args.image.split(':') @@ -921,13 +956,14 @@ def evaluate(net:Yolact, dataset, train_mode=False): with timer.env('Network Extra'): preds = net(batch) + maskiou_net = net.get_maskiou_net() # Perform the meat of the operation here depending on our mode. if args.display: - img_numpy = prep_display(preds, img, h, w) + img_numpy = prep_display(preds, img, h, w, maskiou_net=maskiou_net) elif args.benchmark: - prep_benchmark(preds, h, w) + prep_benchmark(preds, h, w, maskiou_net=maskiou_net) else: - prep_metrics(ap_data, preds, img, gt, gt_masks, h, w, num_crowd, dataset.ids[image_idx], detections) + prep_metrics(ap_data, preds, img, gt, gt_masks, h, w, num_crowd, dataset.ids[image_idx], detections, maskiou_net=maskiou_net) # First couple of images take longer because we're constructing the graph. # Since that's technically initialization, don't include those in the FPS calculations. diff --git a/external/DCNv2/LICENSE b/external/DCNv2/LICENSE new file mode 100644 index 000000000..b2e3b5207 --- /dev/null +++ b/external/DCNv2/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2019, Charles Shang +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/external/DCNv2/README.md b/external/DCNv2/README.md new file mode 100644 index 000000000..9787cfa2c --- /dev/null +++ b/external/DCNv2/README.md @@ -0,0 +1,65 @@ +## Deformable Convolutional Networks V2 with Pytorch 1.0 + +### Build +```bash + ./make.sh # build + python test.py # run examples and gradient check +``` + +### An Example +- deformable conv +```python + from dcn_v2 import DCN + input = torch.randn(2, 64, 128, 128).cuda() + # wrap all things (offset and mask) in DCN + dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda() + output = dcn(input) + print(output.shape) +``` +- deformable roi pooling +```python + from dcn_v2 import DCNPooling + input = torch.randn(2, 32, 64, 64).cuda() + batch_inds = torch.randint(2, (20, 1)).cuda().float() + x = torch.randint(256, (20, 1)).cuda().float() + y = torch.randint(256, (20, 1)).cuda().float() + w = torch.randint(64, (20, 1)).cuda().float() + h = torch.randint(64, (20, 1)).cuda().float() + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + + # mdformable pooling (V2) + # wrap all things (offset and mask) in DCNPooling + dpooling = DCNPooling(spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1).cuda() + + dout = dpooling(input, rois) +``` +### Note +Now the master branch is for pytorch 1.0 (new ATen API), you can switch back to pytorch 0.4 with, +```bash +git checkout pytorch_0.4 +``` + +### Known Issues: + +- [x] Gradient check w.r.t offset (solved) +- [ ] Backward is not reentrant (minor) + +This is an adaption of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op). + +I have ran the gradient check for many times with DOUBLE type. Every tensor **except offset** passes. +However, when I set the offset to 0.5, it passes. I'm still wondering what cause this problem. Is it because some +non-differential points? + +Update: all gradient check passes with double precision. + +Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for +float `<1e-15` for double), +so it may not be a serious problem (?) + +Please post an issue or PR if you have any comments. + \ No newline at end of file diff --git a/external/DCNv2/__init__.py b/external/DCNv2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/external/DCNv2/dcn_v2.py b/external/DCNv2/dcn_v2.py new file mode 100644 index 000000000..982bef512 --- /dev/null +++ b/external/DCNv2/dcn_v2.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import math +import torch +from torch import nn +from torch.autograd import Function +from torch.nn.modules.utils import _pair +from torch.autograd.function import once_differentiable + +import _ext as _backend + + +class _DCNv2(Function): + @staticmethod + def forward(ctx, input, offset, mask, weight, bias, + stride, padding, dilation, deformable_groups): + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.kernel_size = _pair(weight.shape[2:4]) + ctx.deformable_groups = deformable_groups + output = _backend.dcn_v2_forward(input, weight, bias, + offset, mask, + ctx.kernel_size[0], ctx.kernel_size[1], + ctx.stride[0], ctx.stride[1], + ctx.padding[0], ctx.padding[1], + ctx.dilation[0], ctx.dilation[1], + ctx.deformable_groups) + ctx.save_for_backward(input, offset, mask, weight, bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \ + _backend.dcn_v2_backward(input, weight, + bias, + offset, mask, + grad_output, + ctx.kernel_size[0], ctx.kernel_size[1], + ctx.stride[0], ctx.stride[1], + ctx.padding[0], ctx.padding[1], + ctx.dilation[0], ctx.dilation[1], + ctx.deformable_groups) + + return grad_input, grad_offset, grad_mask, grad_weight, grad_bias,\ + None, None, None, None, + + +dcn_v2_conv = _DCNv2.apply + + +class DCNv2(nn.Module): + + def __init__(self, in_channels, out_channels, + kernel_size, stride, padding, dilation=1, deformable_groups=1): + super(DCNv2, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.deformable_groups = deformable_groups + + self.weight = nn.Parameter(torch.Tensor( + out_channels, in_channels, *self.kernel_size)) + self.bias = nn.Parameter(torch.Tensor(out_channels)) + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + self.bias.data.zero_() + + def forward(self, input, offset, mask): + assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ + offset.shape[1] + assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ + mask.shape[1] + return dcn_v2_conv(input, offset, mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.deformable_groups) + + +class DCN(DCNv2): + + def __init__(self, in_channels, out_channels, + kernel_size, stride, padding, + dilation=1, deformable_groups=1): + super(DCN, self).__init__(in_channels, out_channels, + kernel_size, stride, padding, dilation, deformable_groups) + + channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1] + self.conv_offset_mask = nn.Conv2d(self.in_channels, + channels_, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset_mask.weight.data.zero_() + self.conv_offset_mask.bias.data.zero_() + + def forward(self, input): + out = self.conv_offset_mask(input) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return dcn_v2_conv(input, offset, mask, + self.weight, self.bias, + self.stride, + self.padding, + self.dilation, + self.deformable_groups) + + + +class _DCNv2Pooling(Function): + @staticmethod + def forward(ctx, input, rois, offset, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=.0): + ctx.spatial_scale = spatial_scale + ctx.no_trans = int(no_trans) + ctx.output_dim = output_dim + ctx.group_size = group_size + ctx.pooled_size = pooled_size + ctx.part_size = pooled_size if part_size is None else part_size + ctx.sample_per_part = sample_per_part + ctx.trans_std = trans_std + + output, output_count = \ + _backend.dcn_v2_psroi_pooling_forward(input, rois, offset, + ctx.no_trans, ctx.spatial_scale, + ctx.output_dim, ctx.group_size, + ctx.pooled_size, ctx.part_size, + ctx.sample_per_part, ctx.trans_std) + ctx.save_for_backward(input, rois, offset, output_count) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, rois, offset, output_count = ctx.saved_tensors + grad_input, grad_offset = \ + _backend.dcn_v2_psroi_pooling_backward(grad_output, + input, + rois, + offset, + output_count, + ctx.no_trans, + ctx.spatial_scale, + ctx.output_dim, + ctx.group_size, + ctx.pooled_size, + ctx.part_size, + ctx.sample_per_part, + ctx.trans_std) + + return grad_input, None, grad_offset, \ + None, None, None, None, None, None, None, None + + +dcn_v2_pooling = _DCNv2Pooling.apply + + +class DCNv2Pooling(nn.Module): + + def __init__(self, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=.0): + super(DCNv2Pooling, self).__init__() + self.spatial_scale = spatial_scale + self.pooled_size = pooled_size + self.output_dim = output_dim + self.no_trans = no_trans + self.group_size = group_size + self.part_size = pooled_size if part_size is None else part_size + self.sample_per_part = sample_per_part + self.trans_std = trans_std + + def forward(self, input, rois, offset): + assert input.shape[1] == self.output_dim + if self.no_trans: + offset = input.new() + return dcn_v2_pooling(input, rois, offset, + self.spatial_scale, + self.pooled_size, + self.output_dim, + self.no_trans, + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std) + + +class DCNPooling(DCNv2Pooling): + + def __init__(self, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=.0, + deform_fc_dim=1024): + super(DCNPooling, self).__init__(spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size, + part_size, + sample_per_part, + trans_std) + + self.deform_fc_dim = deform_fc_dim + + if not no_trans: + self.offset_mask_fc = nn.Sequential( + nn.Linear(self.pooled_size * self.pooled_size * + self.output_dim, self.deform_fc_dim), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_dim, self.deform_fc_dim), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_dim, self.pooled_size * + self.pooled_size * 3) + ) + self.offset_mask_fc[4].weight.data.zero_() + self.offset_mask_fc[4].bias.data.zero_() + + def forward(self, input, rois): + offset = input.new() + + if not self.no_trans: + + # do roi_align first + n = rois.shape[0] + roi = dcn_v2_pooling(input, rois, offset, + self.spatial_scale, + self.pooled_size, + self.output_dim, + True, # no trans + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std) + + # build mask and offset + offset_mask = self.offset_mask_fc(roi.view(n, -1)) + offset_mask = offset_mask.view( + n, 3, self.pooled_size, self.pooled_size) + o1, o2, mask = torch.chunk(offset_mask, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + # do pooling with offset and mask + return dcn_v2_pooling(input, rois, offset, + self.spatial_scale, + self.pooled_size, + self.output_dim, + self.no_trans, + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std) * mask + # only roi_align + return dcn_v2_pooling(input, rois, offset, + self.spatial_scale, + self.pooled_size, + self.output_dim, + self.no_trans, + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std) diff --git a/external/DCNv2/make.sh b/external/DCNv2/make.sh new file mode 100755 index 000000000..f1f15c0e3 --- /dev/null +++ b/external/DCNv2/make.sh @@ -0,0 +1,2 @@ +#!/usr/bin/env bash +python setup.py build develop diff --git a/external/DCNv2/setup.py b/external/DCNv2/setup.py new file mode 100644 index 000000000..108249428 --- /dev/null +++ b/external/DCNv2/setup.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python + +import os +import glob + +import torch + +from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CppExtension +from torch.utils.cpp_extension import CUDAExtension + +from setuptools import find_packages +from setuptools import setup + +requirements = ["torch", "torchvision"] + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "src") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {"cxx": []} + define_macros = [] + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + else: + raise NotImplementedError('Cuda is not availabel') + + sources = [os.path.join(extensions_dir, s) for s in sources] + include_dirs = [extensions_dir] + ext_modules = [ + extension( + "_ext", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + return ext_modules + +setup( + name="DCNv2", + version="0.1", + author="charlesshang", + url="https://github.com/charlesshang/DCNv2", + description="deformable convolutional networks", + packages=find_packages(exclude=("configs", "tests",)), + # install_requires=requirements, + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) \ No newline at end of file diff --git a/external/DCNv2/src/cpu/dcn_v2_cpu.cpp b/external/DCNv2/src/cpu/dcn_v2_cpu.cpp new file mode 100644 index 000000000..a68ccef8e --- /dev/null +++ b/external/DCNv2/src/cpu/dcn_v2_cpu.cpp @@ -0,0 +1,74 @@ +#include + +#include +#include + + +at::Tensor +dcn_v2_cpu_forward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int deformable_group) +{ + AT_ERROR("Not implement on cpu"); +} + +std::vector +dcn_v2_cpu_backward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const at::Tensor &grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + AT_ERROR("Not implement on cpu"); +} + +std::tuple +dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + AT_ERROR("Not implement on cpu"); +} + +std::tuple +dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, + const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const at::Tensor &top_count, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + AT_ERROR("Not implement on cpu"); +} \ No newline at end of file diff --git a/external/DCNv2/src/cpu/vision.h b/external/DCNv2/src/cpu/vision.h new file mode 100644 index 000000000..d5fbf1f07 --- /dev/null +++ b/external/DCNv2/src/cpu/vision.h @@ -0,0 +1,60 @@ +#pragma once +#include + +at::Tensor +dcn_v2_cpu_forward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int deformable_group); + +std::vector +dcn_v2_cpu_backward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const at::Tensor &grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group); + + +std::tuple +dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std); + +std::tuple +dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, + const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const at::Tensor &top_count, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std); \ No newline at end of file diff --git a/external/DCNv2/src/cuda/dcn_v2_cuda.cu b/external/DCNv2/src/cuda/dcn_v2_cuda.cu new file mode 100644 index 000000000..767ed8fb1 --- /dev/null +++ b/external/DCNv2/src/cuda/dcn_v2_cuda.cu @@ -0,0 +1,335 @@ +#include +#include "cuda/dcn_v2_im2col_cuda.h" + +#include +#include + +#include +#include +#include + +extern THCState *state; + +// author: Charles Shang +// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu + +// [batch gemm] +// https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu + +__global__ void createBatchGemmBuffer(const float **input_b, float **output_b, + float **columns_b, const float **ones_b, + const float **weight_b, const float **bias_b, + float *input, float *output, + float *columns, float *ones, + float *weight, float *bias, + const int input_stride, const int output_stride, + const int columns_stride, const int ones_stride, + const int num_batches) +{ + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < num_batches) + { + input_b[idx] = input + idx * input_stride; + output_b[idx] = output + idx * output_stride; + columns_b[idx] = columns + idx * columns_stride; + ones_b[idx] = ones + idx * ones_stride; + // share weights and bias within a Mini-Batch + weight_b[idx] = weight; + bias_b[idx] = bias; + } +} + +at::Tensor +dcn_v2_cuda_forward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int deformable_group) +{ + using scalar_t = float; + // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask)); + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor"); + AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor"); + AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); + AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h); + // printf("Channels: %d %d\n", channels, channels_kernel); + // printf("Channels: %d %d\n", channels_out, channels_kernel); + + AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, + "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); + + AT_ASSERTM(channels == channels_kernel, + "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + auto ones = at::ones({batch, height_out, width_out}, input.options()); + auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); + auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); + + // prepare for batch-wise computing, which is significantly faster than instance-wise computing + // when batch size is large. + // launch batch threads + int matrices_size = batch * sizeof(float *); + auto input_b = static_cast(THCudaMalloc(state, matrices_size)); + auto output_b = static_cast(THCudaMalloc(state, matrices_size)); + auto columns_b = static_cast(THCudaMalloc(state, matrices_size)); + auto ones_b = static_cast(THCudaMalloc(state, matrices_size)); + auto weight_b = static_cast(THCudaMalloc(state, matrices_size)); + auto bias_b = static_cast(THCudaMalloc(state, matrices_size)); + + const int block = 128; + const int grid = (batch + block - 1) / block; + + createBatchGemmBuffer<<>>( + input_b, output_b, + columns_b, ones_b, + weight_b, bias_b, + input.data(), + output.data(), + columns.data(), + ones.data(), + weight.data(), + bias.data(), + channels * width * height, + channels_out * width_out * height_out, + channels * kernel_h * kernel_w * height_out * width_out, + height_out * width_out, + batch); + + long m_ = channels_out; + long n_ = height_out * width_out; + long k_ = 1; + THCudaBlas_SgemmBatched(state, + 't', + 'n', + n_, + m_, + k_, + 1.0f, + ones_b, k_, + bias_b, k_, + 0.0f, + output_b, n_, + batch); + + modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), + input.data(), + offset.data(), + mask.data(), + batch, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + deformable_group, + columns.data()); + + long m = channels_out; + long n = height_out * width_out; + long k = channels * kernel_h * kernel_w; + THCudaBlas_SgemmBatched(state, + 'n', + 'n', + n, + m, + k, + 1.0f, + (const float **)columns_b, n, + weight_b, k, + 1.0f, + output_b, n, + batch); + + THCudaFree(state, input_b); + THCudaFree(state, output_b); + THCudaFree(state, columns_b); + THCudaFree(state, ones_b); + THCudaFree(state, weight_b); + THCudaFree(state, bias_b); + return output; +} + +__global__ void createBatchGemmBufferBackward( + float **grad_output_b, + float **columns_b, + float **ones_b, + float **weight_b, + float **grad_weight_b, + float **grad_bias_b, + float *grad_output, + float *columns, + float *ones, + float *weight, + float *grad_weight, + float *grad_bias, + const int grad_output_stride, + const int columns_stride, + const int ones_stride, + const int num_batches) +{ + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < num_batches) + { + grad_output_b[idx] = grad_output + idx * grad_output_stride; + columns_b[idx] = columns + idx * columns_stride; + ones_b[idx] = ones + idx * ones_stride; + + // share weights and bias within a Mini-Batch + weight_b[idx] = weight; + grad_weight_b[idx] = grad_weight; + grad_bias_b[idx] = grad_bias; + } +} + +std::vector dcn_v2_cuda_backward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const at::Tensor &grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + + THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous"); + THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous"); + + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor"); + AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor"); + AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); + AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, + "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); + + AT_ASSERTM(channels == channels_kernel, + "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + auto ones = at::ones({height_out, width_out}, input.options()); + auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); + auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); + + auto grad_input = at::zeros_like(input); + auto grad_weight = at::zeros_like(weight); + auto grad_bias = at::zeros_like(bias); + auto grad_offset = at::zeros_like(offset); + auto grad_mask = at::zeros_like(mask); + + using scalar_t = float; + + for (int b = 0; b < batch; b++) + { + auto input_n = input.select(0, b); + auto offset_n = offset.select(0, b); + auto mask_n = mask.select(0, b); + auto grad_output_n = grad_output.select(0, b); + auto grad_input_n = grad_input.select(0, b); + auto grad_offset_n = grad_offset.select(0, b); + auto grad_mask_n = grad_mask.select(0, b); + + long m = channels * kernel_h * kernel_w; + long n = height_out * width_out; + long k = channels_out; + + THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, + grad_output_n.data(), n, + weight.data(), m, 0.0f, + columns.data(), n); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state), + columns.data(), + input_n.data(), + offset_n.data(), + mask_n.data(), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + grad_offset_n.data(), + grad_mask_n.data()); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda(THCState_getCurrentStream(state), + columns.data(), + offset_n.data(), + mask_n.data(), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + grad_input_n.data()); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and group + modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), + input_n.data(), + offset_n.data(), + mask_n.data(), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + columns.data()); + + long m_ = channels_out; + long n_ = channels * kernel_h * kernel_w; + long k_ = height_out * width_out; + + THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, + columns.data(), k_, + grad_output_n.data(), k_, 1.0f, + grad_weight.data(), n_); + + // gradient w.r.t. bias + // long m_ = channels_out; + // long k__ = height_out * width_out; + THCudaBlas_Sgemv(state, + 't', + k_, m_, 1.0f, + grad_output_n.data(), k_, + ones.data(), 1, 1.0f, + grad_bias.data(), 1); + } + + return { + grad_input, grad_offset, grad_mask, grad_weight, grad_bias + }; +} \ No newline at end of file diff --git a/external/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu b/external/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu new file mode 100644 index 000000000..4183793ba --- /dev/null +++ b/external/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu @@ -0,0 +1,402 @@ +#include "dcn_v2_im2col_cuda.h" +#include +#include +#include + +#include +#include + +#include +#include +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + + +__device__ float dmcn_im2col_bilinear(const float *bottom_data, const int data_width, + const int height, const int width, float h, float w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h - h_low; + float lw = w - w_low; + float hh = 1 - lh, hw = 1 - lw; + + float v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + float v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + float v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + float v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +__device__ float dmcn_get_gradient_weight(float argmax_h, float argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + float weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +__device__ float dmcn_get_coordinate_weight(float argmax_h, float argmax_w, + const int height, const int width, const float *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + float weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const float *data_im, const float *data_offset, const float *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + float *data_col) +{ + // launch channels * batch_size * height_col * width_col cores + CUDA_KERNEL_LOOP(index, n) + { + // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow) + // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis + + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + // const int b_col = (index / width_col / height_col) % batch_size; + const int b_col = (index / width_col / height_col / num_channels) % batch_size; + // const int c_im = (index / width_col / height_col) / batch_size; + const int c_im = (index / width_col / height_col) % num_channels; + // const int c_col = c_im * kernel_h * kernel_w; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + // float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + float val = static_cast(0); + const float h_im = h_in + i * dilation_h + offset_h; + const float w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + // data_col_ptr += batch_size * height_col * width_col; + data_col_ptr += height_col * width_col; + } + } + } +} + +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const float *data_col, const float *data_offset, const float *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + float *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + const float cur_inv_h_data = h_in + i * dilation_h + offset_h; + const float cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const float cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + float weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const float *data_col, const float *data_im, + const float *data_offset, const float *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + float *grad_offset, float *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + float val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + float inv_h = h_in + i * dilation_h + offset_h; + float inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const float weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda(cudaStream_t stream, + const float* data_im, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float* data_col) { + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + modulated_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +void modulated_deformable_col2im_cuda(cudaStream_t stream, + const float* data_col, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float* grad_im){ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + modulated_deformable_col2im_gpu_kernel + <<>>( + num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, + const float* data_col, const float* data_im, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + float* grad_offset, float* grad_mask) { + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + modulated_deformable_col2im_coord_gpu_kernel + <<>>( + num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset, grad_mask); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} \ No newline at end of file diff --git a/external/DCNv2/src/cuda/dcn_v2_im2col_cuda.h b/external/DCNv2/src/cuda/dcn_v2_im2col_cuda.h new file mode 100644 index 000000000..c85683198 --- /dev/null +++ b/external/DCNv2/src/cuda/dcn_v2_im2col_cuda.h @@ -0,0 +1,101 @@ + +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.h + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1811.11168 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu + */ + +/***************** Adapted by Charles Shang *********************/ + +#ifndef DCN_V2_IM2COL_CUDA +#define DCN_V2_IM2COL_CUDA + +#ifdef __cplusplus +extern "C" +{ +#endif + + void modulated_deformable_im2col_cuda(cudaStream_t stream, + const float *data_im, const float *data_offset, const float *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float *data_col); + + void modulated_deformable_col2im_cuda(cudaStream_t stream, + const float *data_col, const float *data_offset, const float *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float *grad_im); + + void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, + const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + float *grad_offset, float *grad_mask); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/external/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu b/external/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu new file mode 100644 index 000000000..07b438e19 --- /dev/null +++ b/external/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu @@ -0,0 +1,419 @@ +/*! + * Copyright (c) 2017 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file deformable_psroi_pooling.cu + * \brief + * \author Yi Li, Guodong Zhang, Jifeng Dai +*/ +/***************** Adapted by Charles Shang *********************/ + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +template +__device__ T bilinear_interp( + const T *data, + const T x, + const T y, + const int width, + const int height) +{ + int x1 = floor(x); + int x2 = ceil(x); + int y1 = floor(y); + int y2 = ceil(y); + T dist_x = static_cast(x - x1); + T dist_y = static_cast(y - y1); + T value11 = data[y1 * width + x1]; + T value12 = data[y2 * width + x1]; + T value21 = data[y1 * width + x2]; + T value22 = data[y2 * width + x2]; + T value = (1 - dist_x) * (1 - dist_y) * value11 + + (1 - dist_x) * dist_y * value12 + + dist_x * (1 - dist_y) * value21 + + dist_x * dist_y * value22; + return value; +} + +template +__global__ void DeformablePSROIPoolForwardKernel( + const int count, + const T *bottom_data, + const T spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const T *bottom_rois, const T *bottom_trans, + const int no_trans, + const T trans_std, + const int sample_per_part, + const int output_dim, + const int group_size, + const int part_size, + const int num_classes, + const int channels_each_class, + T *top_data, + T *top_count) +{ + CUDA_KERNEL_LOOP(index, count) + { + // The output is in order (n, ctop, ph, pw) + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const T *offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 + T roi_height = max(roi_end_h - roi_start_h, 0.1); + + // Compute w and h at bottom + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); + T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); + + int part_h = floor(static_cast(ph) / pooled_height * part_size); + int part_w = floor(static_cast(pw) / pooled_width * part_size); + int class_id = ctop / channels_each_class; + T trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; + T trans_y = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; + + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + T sum = 0; + int count = 0; + int gw = floor(static_cast(pw) * group_size / pooled_width); + int gh = floor(static_cast(ph) * group_size / pooled_height); + gw = min(max(gw, 0), group_size - 1); + gh = min(max(gh, 0), group_size - 1); + + const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; + for (int ih = 0; ih < sample_per_part; ih++) + { + for (int iw = 0; iw < sample_per_part; iw++) + { + T w = wstart + iw * sub_bin_size_w; + T h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) + { + continue; + } + w = min(max(w, 0.), width - 1.); + h = min(max(h, 0.), height - 1.); + int c = (ctop * group_size + gh) * group_size + gw; + T val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height); + sum += val; + count++; + } + } + top_data[index] = count == 0 ? static_cast(0) : sum / count; + top_count[index] = count; + } +} + +template +__global__ void DeformablePSROIPoolBackwardAccKernel( + const int count, + const T *top_diff, + const T *top_count, + const int num_rois, + const T spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int output_dim, + T *bottom_data_diff, T *bottom_trans_diff, + const T *bottom_data, + const T *bottom_rois, + const T *bottom_trans, + const int no_trans, + const T trans_std, + const int sample_per_part, + const int group_size, + const int part_size, + const int num_classes, + const int channels_each_class) +{ + CUDA_KERNEL_LOOP(index, count) + { + // The output is in order (n, ctop, ph, pw) + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const T *offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 + T roi_height = max(roi_end_h - roi_start_h, 0.1); + + // Compute w and h at bottom + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); + T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); + + int part_h = floor(static_cast(ph) / pooled_height * part_size); + int part_w = floor(static_cast(pw) / pooled_width * part_size); + int class_id = ctop / channels_each_class; + T trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; + T trans_y = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; + + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + if (top_count[index] <= 0) + { + continue; + } + T diff_val = top_diff[index] / top_count[index]; + const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; + T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; + int gw = floor(static_cast(pw) * group_size / pooled_width); + int gh = floor(static_cast(ph) * group_size / pooled_height); + gw = min(max(gw, 0), group_size - 1); + gh = min(max(gh, 0), group_size - 1); + + for (int ih = 0; ih < sample_per_part; ih++) + { + for (int iw = 0; iw < sample_per_part; iw++) + { + T w = wstart + iw * sub_bin_size_w; + T h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) + { + continue; + } + w = min(max(w, 0.), width - 1.); + h = min(max(h, 0.), height - 1.); + int c = (ctop * group_size + gh) * group_size + gw; + // backward on feature + int x0 = floor(w); + int x1 = ceil(w); + int y0 = floor(h); + int y1 = ceil(h); + T dist_x = w - x0, dist_y = h - y0; + T q00 = (1 - dist_x) * (1 - dist_y); + T q01 = (1 - dist_x) * dist_y; + T q10 = dist_x * (1 - dist_y); + T q11 = dist_x * dist_y; + int bottom_index_base = c * height * width; + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); + + if (no_trans) + { + continue; + } + T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; + T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; + T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; + T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; + T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; + diff_x *= roi_width; + T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val; + diff_y *= roi_height; + + atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); + atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y); + } + } + } +} + +std::tuple +dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor"); + AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + const int channels_trans = no_trans ? 2 : trans.size(1); + const int num_bbox = bbox.size(0); + + AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); + auto pooled_height = pooled_size; + auto pooled_width = pooled_size; + + auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); + long out_size = num_bbox * output_dim * pooled_height * pooled_width; + auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); + + const int num_classes = no_trans ? 1 : channels_trans / 2; + const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (out.numel() == 0) + { + THCudaCheck(cudaGetLastError()); + return std::make_tuple(out, top_count); + } + + dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); + dim3 block(512); + + AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cuda_forward", [&] { + DeformablePSROIPoolForwardKernel<<>>( + out_size, + input.contiguous().data(), + spatial_scale, + channels, + height, width, + pooled_height, + pooled_width, + bbox.contiguous().data(), + trans.contiguous().data(), + no_trans, + trans_std, + sample_per_part, + output_dim, + group_size, + part_size, + num_classes, + channels_each_class, + out.data(), + top_count.data()); + }); + THCudaCheck(cudaGetLastError()); + return std::make_tuple(out, top_count); +} + +std::tuple +dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, + const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const at::Tensor &top_count, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor"); + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor"); + AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor"); + AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + const int channels_trans = no_trans ? 2 : trans.size(1); + const int num_bbox = bbox.size(0); + + AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); + auto pooled_height = pooled_size; + auto pooled_width = pooled_size; + long out_size = num_bbox * output_dim * pooled_height * pooled_width; + const int num_classes = no_trans ? 1 : channels_trans / 2; + const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; + + auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options()); + auto trans_grad = at::zeros_like(trans); + + if (input_grad.numel() == 0) + { + THCudaCheck(cudaGetLastError()); + return std::make_tuple(input_grad, trans_grad); + } + + dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); + dim3 block(512); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cuda_backward", [&] { + DeformablePSROIPoolBackwardAccKernel<<>>( + out_size, + out_grad.contiguous().data(), + top_count.contiguous().data(), + num_bbox, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + output_dim, + input_grad.contiguous().data(), + trans_grad.contiguous().data(), + input.contiguous().data(), + bbox.contiguous().data(), + trans.contiguous().data(), + no_trans, + trans_std, + sample_per_part, + group_size, + part_size, + num_classes, + channels_each_class); + }); + THCudaCheck(cudaGetLastError()); + return std::make_tuple(input_grad, trans_grad); +} \ No newline at end of file diff --git a/external/DCNv2/src/cuda/vision.h b/external/DCNv2/src/cuda/vision.h new file mode 100644 index 000000000..e42a2a79a --- /dev/null +++ b/external/DCNv2/src/cuda/vision.h @@ -0,0 +1,60 @@ +#pragma once +#include + +at::Tensor +dcn_v2_cuda_forward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int deformable_group); + +std::vector +dcn_v2_cuda_backward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const at::Tensor &grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group); + + +std::tuple +dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std); + +std::tuple +dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, + const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const at::Tensor &top_count, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std); \ No newline at end of file diff --git a/external/DCNv2/src/dcn_v2.h b/external/DCNv2/src/dcn_v2.h new file mode 100644 index 000000000..23f5caf2f --- /dev/null +++ b/external/DCNv2/src/dcn_v2.h @@ -0,0 +1,145 @@ +#pragma once + +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + +at::Tensor +dcn_v2_forward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int deformable_group) +{ + if (input.type().is_cuda()) + { +#ifdef WITH_CUDA + return dcn_v2_cuda_forward(input, weight, bias, offset, mask, + kernel_h, kernel_w, + stride_h, stride_w, + pad_h, pad_w, + dilation_h, dilation_w, + deformable_group); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +dcn_v2_backward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const at::Tensor &grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + if (input.type().is_cuda()) + { +#ifdef WITH_CUDA + return dcn_v2_cuda_backward(input, + weight, + bias, + offset, + mask, + grad_output, + kernel_h, kernel_w, + stride_h, stride_w, + pad_h, pad_w, + dilation_h, dilation_w, + deformable_group); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::tuple +dcn_v2_psroi_pooling_forward(const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + if (input.type().is_cuda()) + { +#ifdef WITH_CUDA + return dcn_v2_psroi_pooling_cuda_forward(input, + bbox, + trans, + no_trans, + spatial_scale, + output_dim, + group_size, + pooled_size, + part_size, + sample_per_part, + trans_std); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::tuple +dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad, + const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const at::Tensor &top_count, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + if (input.type().is_cuda()) + { +#ifdef WITH_CUDA + return dcn_v2_psroi_pooling_cuda_backward(out_grad, + input, + bbox, + trans, + top_count, + no_trans, + spatial_scale, + output_dim, + group_size, + pooled_size, + part_size, + sample_per_part, + trans_std); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} \ No newline at end of file diff --git a/external/DCNv2/src/vision.cpp b/external/DCNv2/src/vision.cpp new file mode 100644 index 000000000..ff54233e0 --- /dev/null +++ b/external/DCNv2/src/vision.cpp @@ -0,0 +1,9 @@ + +#include "dcn_v2.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward"); + m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward"); + m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward"); + m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward"); +} diff --git a/external/DCNv2/test.py b/external/DCNv2/test.py new file mode 100644 index 000000000..3bd5bd223 --- /dev/null +++ b/external/DCNv2/test.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import time +import torch +import torch.nn as nn +from torch.autograd import gradcheck + +from dcn_v2 import dcn_v2_conv, DCNv2, DCN +from dcn_v2 import dcn_v2_pooling, DCNv2Pooling, DCNPooling + +deformable_groups = 1 +N, inC, inH, inW = 2, 2, 4, 4 +outC = 2 +kH, kW = 3, 3 + + +def conv_identify(weight, bias): + weight.data.zero_() + bias.data.zero_() + o, i, h, w = weight.shape + y = h//2 + x = w//2 + for p in range(i): + for q in range(o): + if p == q: + weight.data[q, p, y, x] = 1.0 + + +def check_zero_offset(): + conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW, + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True).cuda() + + conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW, + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True).cuda() + + dcn_v2 = DCNv2(inC, outC, (kH, kW), + stride=1, padding=1, dilation=1, + deformable_groups=deformable_groups).cuda() + + conv_offset.weight.data.zero_() + conv_offset.bias.data.zero_() + conv_mask.weight.data.zero_() + conv_mask.bias.data.zero_() + conv_identify(dcn_v2.weight, dcn_v2.bias) + + input = torch.randn(N, inC, inH, inW).cuda() + offset = conv_offset(input) + mask = conv_mask(input) + mask = torch.sigmoid(mask) + output = dcn_v2(input, offset, mask) + output *= 2 + d = (input - output).abs().max() + if d < 1e-10: + print('Zero offset passed') + else: + print('Zero offset failed') + print(input) + print(output) + +def check_gradient_dconv(): + + input = torch.rand(N, inC, inH, inW).cuda() * 0.01 + input.requires_grad = True + + offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2 + # offset.data.zero_() + # offset.data -= 0.5 + offset.requires_grad = True + + mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda() + # mask.data.zero_() + mask.requires_grad = True + mask = torch.sigmoid(mask) + + weight = torch.randn(outC, inC, kH, kW).cuda() + weight.requires_grad = True + + bias = torch.rand(outC).cuda() + bias.requires_grad = True + + stride = 1 + padding = 1 + dilation = 1 + + print('check_gradient_dconv: ', + gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias, + stride, padding, dilation, deformable_groups), + eps=1e-3, atol=1e-4, rtol=1e-2)) + + +def check_pooling_zero_offset(): + + input = torch.randn(2, 16, 64, 64).cuda().zero_() + input[0, :, 16:26, 16:26] = 1. + input[1, :, 10:20, 20:30] = 2. + rois = torch.tensor([ + [0, 65, 65, 103, 103], + [1, 81, 41, 119, 79], + ]).cuda().float() + pooling = DCNv2Pooling(spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=16, + no_trans=True, + group_size=1, + trans_std=0.0).cuda() + + out = pooling(input, rois, input.new()) + s = ', '.join(['%f' % out[i, :, :, :].mean().item() + for i in range(rois.shape[0])]) + print(s) + + dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=16, + no_trans=False, + group_size=1, + trans_std=0.0).cuda() + offset = torch.randn(20, 2, 7, 7).cuda().zero_() + dout = dpooling(input, rois, offset) + s = ', '.join(['%f' % dout[i, :, :, :].mean().item() + for i in range(rois.shape[0])]) + print(s) + + +def check_gradient_dpooling(): + input = torch.randn(2, 3, 5, 5).cuda() * 0.01 + N = 4 + batch_inds = torch.randint(2, (N, 1)).cuda().float() + x = torch.rand((N, 1)).cuda().float() * 15 + y = torch.rand((N, 1)).cuda().float() * 15 + w = torch.rand((N, 1)).cuda().float() * 10 + h = torch.rand((N, 1)).cuda().float() * 10 + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + offset = torch.randn(N, 2, 3, 3).cuda() + input.requires_grad = True + offset.requires_grad = True + + spatial_scale = 1.0 / 4 + pooled_size = 3 + output_dim = 3 + no_trans = 0 + group_size = 1 + trans_std = 0.0 + sample_per_part = 4 + part_size = pooled_size + + print('check_gradient_dpooling:', + gradcheck(dcn_v2_pooling, (input, rois, offset, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size, + part_size, + sample_per_part, + trans_std), + eps=1e-4)) + + +def example_dconv(): + input = torch.randn(2, 64, 128, 128).cuda() + # wrap all things (offset and mask) in DCN + dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, + padding=1, deformable_groups=2).cuda() + # print(dcn.weight.shape, input.shape) + output = dcn(input) + targert = output.new(*output.size()) + targert.data.uniform_(-0.01, 0.01) + error = (targert - output).mean() + error.backward() + print(output.shape) + + +def example_dpooling(): + input = torch.randn(2, 32, 64, 64).cuda() + batch_inds = torch.randint(2, (20, 1)).cuda().float() + x = torch.randint(256, (20, 1)).cuda().float() + y = torch.randint(256, (20, 1)).cuda().float() + w = torch.randint(64, (20, 1)).cuda().float() + h = torch.randint(64, (20, 1)).cuda().float() + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + offset = torch.randn(20, 2, 7, 7).cuda() + input.requires_grad = True + offset.requires_grad = True + + # normal roi_align + pooling = DCNv2Pooling(spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=True, + group_size=1, + trans_std=0.1).cuda() + + # deformable pooling + dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1).cuda() + + out = pooling(input, rois, offset) + dout = dpooling(input, rois, offset) + print(out.shape) + print(dout.shape) + + target_out = out.new(*out.size()) + target_out.data.uniform_(-0.01, 0.01) + target_dout = dout.new(*dout.size()) + target_dout.data.uniform_(-0.01, 0.01) + e = (target_out - out).mean() + e.backward() + e = (target_dout - dout).mean() + e.backward() + + +def example_mdpooling(): + input = torch.randn(2, 32, 64, 64).cuda() + input.requires_grad = True + batch_inds = torch.randint(2, (20, 1)).cuda().float() + x = torch.randint(256, (20, 1)).cuda().float() + y = torch.randint(256, (20, 1)).cuda().float() + w = torch.randint(64, (20, 1)).cuda().float() + h = torch.randint(64, (20, 1)).cuda().float() + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + + # mdformable pooling (V2) + dpooling = DCNPooling(spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1, + deform_fc_dim=1024).cuda() + + dout = dpooling(input, rois) + target = dout.new(*dout.size()) + target.data.uniform_(-0.1, 0.1) + error = (target - dout).mean() + error.backward() + print(dout.shape) + + +if __name__ == '__main__': + + example_dconv() + example_dpooling() + example_mdpooling() + + check_pooling_zero_offset() + # zero offset check + if inC == outC: + check_zero_offset() + + check_gradient_dpooling() + check_gradient_dconv() + # """ + # ****** Note: backward is not reentrant error may not be a serious problem, + # ****** since the max error is less than 1e-7, + # ****** Still looking for what trigger this problem + # """ diff --git a/layers/mask_score.py b/layers/mask_score.py new file mode 100644 index 000000000..ee912de42 --- /dev/null +++ b/layers/mask_score.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from data.config import cfg, mask_type +from torch.autograd import Variable +import torch.backends.cudnn as cudnn +from .box_utils import match, crop +import cv2 +from datetime import datetime +import os + +from utils.functions import make_net + +class FastMaskIoUNet(nn.Module): + + def __init__(self): + super(FastMaskIoUNet, self).__init__() + input_channels = 1 + self.maskiou_net, _ = make_net(input_channels, cfg.maskiou_net, include_last_relu=True) + + def forward(self, x, target=None): + cudnn.benchmark = False + x = self.maskiou_net(x) + cudnn.benchmark = True + # global pooling + maskiou_p = F.max_pool2d(x, kernel_size=x.size()[2:]).squeeze(-1).squeeze(-1) + + if self.training: + maskiou_t = target[0] + label_t = target[1] + label_t = label_t[:, None] + maskiou_p = torch.gather(maskiou_p, dim=1, index=label_t).squeeze() + loss_i = F.smooth_l1_loss(maskiou_p, maskiou_t, reduction='mean') + return loss_i * cfg.maskiou_alpha + else: + return maskiou_p \ No newline at end of file diff --git a/layers/modules/multibox_loss.py b/layers/modules/multibox_loss.py index a83ffade6..d5b7148a6 100644 --- a/layers/modules/multibox_loss.py +++ b/layers/modules/multibox_loss.py @@ -156,8 +156,13 @@ def forward(self, predictions, targets, masks, num_crowds): else: losses['M'] = self.direct_mask_loss(pos_idx, idx_t, loc_data, mask_data, priors, masks) elif cfg.mask_type == mask_type.lincomb: - losses.update(self.lincomb_mask_loss(pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data)) - + ret = self.lincomb_mask_loss(pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels) + if cfg.use_maskiou: + loss, maskiou_targets = ret + else: + loss = ret + losses.update(loss) + if cfg.mask_proto_loss is not None: if cfg.mask_proto_loss == 'l1': losses['P'] = torch.mean(torch.abs(proto_data)) / self.l1_expected_area * self.l1_alpha @@ -201,7 +206,10 @@ def forward(self, predictions, targets, masks, num_crowds): # - D: Coefficient Diversity Loss # - E: Class Existence Loss # - S: Semantic Segmentation Loss - return losses + if cfg.use_maskiou: + return losses, maskiou_targets + else: + return losses def class_existence_loss(self, class_data, class_existence_t): return cfg.class_existence_alpha * F.binary_cross_entropy_with_logits(class_data, class_existence_t, reduction='sum') @@ -487,7 +495,7 @@ def coeff_diversity_loss(self, coeffs, instance_t): return cfg.mask_proto_coeff_diversity_alpha * loss.sum() / num_pos - def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, interpolation_mode='bilinear'): + def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'): mask_h = proto_data.size(1) mask_w = proto_data.size(2) @@ -500,6 +508,10 @@ def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, loss_m = 0 loss_d = 0 # Coefficient diversity loss + maskiou_t_list = [] + maskiou_net_input_list = [] + label_t_list = [] + for idx in range(mask_data.size(0)): with torch.no_grad(): downsampled_masks = F.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w), @@ -570,7 +582,8 @@ def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, mask_scores = mask_scores[select, :] num_pos = proto_coef.size(0) - mask_t = downsampled_masks[:, :, pos_idx_t] + mask_t = downsampled_masks[:, :, pos_idx_t] + label_t = labels[idx][pos_idx_t] # Size: [mask_h, mask_w, num_pos] pred_masks = proto_masks @ proto_coef.t() @@ -611,10 +624,54 @@ def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, pre_loss *= old_num_pos / num_pos loss_m += torch.sum(pre_loss) + + if cfg.use_maskiou: + if cfg.remove_small_gt_mask > 0: + gt_mask_area = torch.sum(mask_t, dim=(0, 1)) + select = gt_mask_area > cfg.remove_small_gt_mask + + if torch.sum(select) < 1: + continue + + pos_gt_box_t = pos_gt_box_t[select, :] + pred_masks = pred_masks[:, :, select] + mask_t = mask_t[:, :, select] + label_t = label_t[select] + + maskiou_net_input = pred_masks.permute(2, 0, 1).contiguous().unsqueeze(1) + pred_masks = pred_masks.gt(0.5).float() + maskiou_t = self._mask_iou(pred_masks, mask_t) + + maskiou_net_input_list.append(maskiou_net_input) + maskiou_t_list.append(maskiou_t) + label_t_list.append(label_t) losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w} if cfg.mask_proto_coeff_diversity_loss: losses['D'] = loss_d + if cfg.use_maskiou: + maskiou_t = torch.cat(maskiou_t_list) + label_t = torch.cat(label_t_list) + maskiou_net_input = torch.cat(maskiou_net_input_list) + + num_samples = maskiou_t.size(0) + if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train: + perm = torch.randperm(num_samples) + select = perm[:cfg.masks_to_train] + maskiou_t = maskiou_t[select] + label_t = label_t[select] + maskiou_net_input = maskiou_net_input[select] + + return losses, [maskiou_net_input, maskiou_t, label_t] + return losses + + def _mask_iou(self, mask1, mask2): + intersection = torch.sum(mask1*mask2, dim=(0, 1)) + area1 = torch.sum(mask1, dim=(0, 1)) + area2 = torch.sum(mask2, dim=(0, 1)) + union = (area1 + area2) - intersection + ret = intersection / union + return ret \ No newline at end of file diff --git a/layers/output_utils.py b/layers/output_utils.py index a6342e266..1f7b2550c 100644 --- a/layers/output_utils.py +++ b/layers/output_utils.py @@ -13,7 +13,7 @@ from .box_utils import crop, sanitize_coordinates def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear', - visualize_lincomb=False, crop_masks=True, score_threshold=0): + visualize_lincomb=False, crop_masks=True, score_threshold=0, maskiou_net=None): """ Postprocesses the output of Yolact on testing mode into a format that makes sense, accounting for all the possible configuration settings. @@ -74,6 +74,17 @@ def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear', # Permute into the correct output shape [num_dets, proto_h, proto_w] masks = masks.permute(2, 0, 1).contiguous() + if cfg.use_maskiou: + with timer.env('maskiou_net'): + with torch.no_grad(): + maskiou_p = maskiou_net(masks.unsqueeze(1)) + maskiou_p = torch.gather(maskiou_p, dim=1, index=classes.unsqueeze(1)).squeeze(1) + if cfg.rescore_mask: + if cfg.rescore_bbox: + scores = scores * maskiou_p + else: + scores = [scores, scores * maskiou_p] + # Scale masks up to the full image masks = F.interpolate(masks.unsqueeze(0), (h, w), mode=interpolation_mode, align_corners=False).squeeze(0) diff --git a/train.py b/train.py index 87d0ee5a4..b664317e4 100644 --- a/train.py +++ b/train.py @@ -115,7 +115,7 @@ def replace(name): print('Per-GPU batch size is less than the recommended limit for batch norm. Disabling batch norm.') cfg.freeze_bn = True -loss_types = ['B', 'C', 'M', 'P', 'D', 'E', 'S'] +loss_types = ['B', 'C', 'M', 'P', 'D', 'E', 'S', 'I'] if torch.cuda.is_available(): if args.cuda: @@ -138,10 +138,20 @@ def __init__(self, net:Yolact, criterion:MultiBoxLoss): self.net = net self.criterion = criterion + if cfg.use_maskiou: + self.maskiou_net = net.get_maskiou_net() def forward(self, images, targets, masks, num_crowds): preds = self.net(images) - return self.criterion(preds, targets, masks, num_crowds) + + if cfg.use_maskiou: + losses, maskiou_targets = self.criterion(preds, targets, masks, num_crowds) + maskiou_net_input, maskiou_t, label_t = maskiou_targets + loss_i = self.maskiou_net(maskiou_net_input, target=[maskiou_t, label_t]) + losses['I'] = loss_i + else: + losses = self.criterion(preds, targets, masks, num_crowds) + return losses class CustomDataParallel(nn.DataParallel): """ diff --git a/utils/functions.py b/utils/functions.py index 809060db8..3b7a4e45a 100644 --- a/utils/functions.py +++ b/utils/functions.py @@ -1,9 +1,10 @@ - import torch +import torch.nn as nn import os import math from collections import deque from pathlib import Path +from layers.interpolate import InterpolateModule class MovingAverage(): """ Keeps an average window of the specified number of items. """ @@ -158,3 +159,55 @@ def get_latest(save_folder, config): max_name = path_name return max_name + +def make_net(in_channels, conf, include_last_relu=True): + """ + A helper function to take a config setting and turn it into a network. + Used by protonet and extrahead. Returns (network, out_channels) + """ + def make_layer(layer_cfg): + nonlocal in_channels + + # Possible patterns: + # ( 256, 3, {}) -> conv + # ( 256,-2, {}) -> deconv + # (None,-2, {}) -> bilinear interpolate + # ('cat',[],{}) -> concat the subnetworks in the list + # + # You know it would have probably been simpler just to adopt a 'c' 'd' 'u' naming scheme. + # Whatever, it's too late now. + if isinstance(layer_cfg[0], str): + layer_name = layer_cfg[0] + + if layer_name == 'cat': + nets = [make_net(in_channels, x) for x in layer_cfg[1]] + layer = Concat([net[0] for net in nets], layer_cfg[2]) + num_channels = sum([net[1] for net in nets]) + else: + num_channels = layer_cfg[0] + kernel_size = layer_cfg[1] + + if kernel_size > 0: + layer = nn.Conv2d(in_channels, num_channels, kernel_size, **layer_cfg[2]) + else: + if num_channels is None: + layer = InterpolateModule(scale_factor=-kernel_size, mode='bilinear', align_corners=False, **layer_cfg[2]) + else: + layer = nn.ConvTranspose2d(in_channels, num_channels, -kernel_size, **layer_cfg[2]) + + in_channels = num_channels if num_channels is not None else in_channels + + # Don't return a ReLU layer if we're doing an upsample. This probably doesn't affect anything + # output-wise, but there's no need to go through a ReLU here. + # Commented out for backwards compatibility with previous models + # if num_channels is None: + # return [layer] + # else: + return [layer, nn.ReLU(inplace=True)] + + # Use sum to concat together all the component layer lists + net = sum([make_layer(x) for x in conf], []) + if not include_last_relu: + net = net[:-1] + + return nn.Sequential(*(net)), in_channels \ No newline at end of file diff --git a/yolact.py b/yolact.py index a49d9a2d2..fb0d66c03 100644 --- a/yolact.py +++ b/yolact.py @@ -10,12 +10,13 @@ from data.config import cfg, mask_type from layers import Detect +from layers.mask_score import FastMaskIoUNet from layers.interpolate import InterpolateModule from backbone import construct_backbone import torch.backends.cudnn as cudnn from utils import timer -from utils.functions import MovingAverage +from utils.functions import MovingAverage, make_net # This is required for Pytorch 1.0.1 on Windows to initialize Cuda on some driver versions. # See the bug report here: https://github.com/pytorch/pytorch/issues/17108 @@ -42,60 +43,6 @@ def forward(self, x): # Concat each along the channel dimension return torch.cat([net(x) for net in self.nets], dim=1, **self.extra_params) - - -def make_net(in_channels, conf, include_last_relu=True): - """ - A helper function to take a config setting and turn it into a network. - Used by protonet and extrahead. Returns (network, out_channels) - """ - def make_layer(layer_cfg): - nonlocal in_channels - - # Possible patterns: - # ( 256, 3, {}) -> conv - # ( 256,-2, {}) -> deconv - # (None,-2, {}) -> bilinear interpolate - # ('cat',[],{}) -> concat the subnetworks in the list - # - # You know it would have probably been simpler just to adopt a 'c' 'd' 'u' naming scheme. - # Whatever, it's too late now. - if isinstance(layer_cfg[0], str): - layer_name = layer_cfg[0] - - if layer_name == 'cat': - nets = [make_net(in_channels, x) for x in layer_cfg[1]] - layer = Concat([net[0] for net in nets], layer_cfg[2]) - num_channels = sum([net[1] for net in nets]) - else: - num_channels = layer_cfg[0] - kernel_size = layer_cfg[1] - - if kernel_size > 0: - layer = nn.Conv2d(in_channels, num_channels, kernel_size, **layer_cfg[2]) - else: - if num_channels is None: - layer = InterpolateModule(scale_factor=-kernel_size, mode='bilinear', align_corners=False, **layer_cfg[2]) - else: - layer = nn.ConvTranspose2d(in_channels, num_channels, -kernel_size, **layer_cfg[2]) - - in_channels = num_channels if num_channels is not None else in_channels - - # Don't return a ReLU layer if we're doing an upsample. This probably doesn't affect anything - # output-wise, but there's no need to go through a ReLU here. - # Commented out for backwards compatibility with previous models - # if num_channels is None: - # return [layer] - # else: - return [layer, nn.ReLU(inplace=True)] - - # Use sum to concat together all the component layer lists - net = sum([make_layer(x) for x in conf], []) - if not include_last_relu: - net = net[:-1] - - return nn.Sequential(*(net)), in_channels - prior_cache = defaultdict(lambda: None) class PredictionModule(nn.Module): @@ -129,7 +76,7 @@ def __init__(self, in_channels, out_channels=1024, aspect_ratios=[[1]], scales=[ self.num_classes = cfg.num_classes self.mask_dim = cfg.mask_dim # Defined by Yolact - self.num_priors = sum(len(x) for x in aspect_ratios) + self.num_priors = sum(len(x)*len(scales) for x in aspect_ratios) self.parent = [parent] # Don't include this in the state dict self.index = index self.num_heads = cfg.num_heads # Defined by Yolact @@ -264,7 +211,7 @@ def forward(self, x): preds['inst'] = inst return preds - + def make_priors(self, conv_h, conv_w, device): """ Note that priors are [x,y,width,height] where (x,y) is the center of the box. """ global prior_cache @@ -280,24 +227,25 @@ def make_priors(self, conv_h, conv_w, device): x = (i + 0.5) / conv_w y = (j + 0.5) / conv_h - for scale, ars in zip(self.scales, self.aspect_ratios): - for ar in ars: - if not cfg.backbone.preapply_sqrt: - ar = sqrt(ar) - - if cfg.backbone.use_pixel_scales: - w = scale * ar / cfg._tmp_img_w # These are populated by - h = scale / ar / cfg._tmp_img_h # Yolact.forward - else: - w = scale * ar / conv_w - h = scale / ar / conv_h - - # This is for backward compatability with a bug where I made everything square by accident - if cfg.backbone.use_square_anchors: - h = w - - prior_data += [x, y, w, h] - + for ars in self.aspect_ratios: + for scale in self.scales: + for ar in ars: + if not cfg.backbone.preapply_sqrt: + ar = sqrt(ar) + + if cfg.backbone.use_pixel_scales: + w = scale * ar / cfg.max_size + h = scale / ar / cfg.max_size + else: + w = scale * ar / conv_w + h = scale / ar / conv_h + + # This is for backward compatability with a bug where I made everything square by accident + if cfg.backbone.use_square_anchors: + h = w + + prior_data += [x, y, w, h] + self.priors = torch.Tensor(prior_data, device=device).view(-1, 4).detach() self.priors.requires_grad = False self.last_img_size = (cfg._tmp_img_w, cfg._tmp_img_h) @@ -470,6 +418,9 @@ def __init__(self): self.selected_layers = cfg.backbone.selected_layers src_channels = self.backbone.channels + if cfg.use_maskiou: + self.maskiou_net = FastMaskIoUNet() + if cfg.fpn is not None: # Some hacky rewiring to accomodate the FPN self.fpn = FPN([src_channels[i] for i in self.selected_layers]) @@ -523,7 +474,6 @@ def load_weights(self, path): if key.startswith('fpn.downsample_layers.'): if cfg.fpn is not None and int(key.split('.')[2]) >= cfg.fpn.num_downsample: del state_dict[key] - self.load_state_dict(state_dict) def init_weights(self, backbone_path): @@ -575,6 +525,11 @@ def all_in(x, y): else: module.bias.data.zero_() + def get_maskiou_net(self): + if cfg.use_maskiou: + return self.maskiou_net + return None + def train(self, mode=True): super().train(mode) @@ -589,7 +544,7 @@ def freeze_bn(self, enable=False): module.weight.requires_grad = enable module.bias.requires_grad = enable - + def forward(self, x): """ The input should be of size [batch_size, 3, img_h, img_w] """ _, _, img_h, img_w = x.size()