diff --git a/README.md b/README.md
index 7d49cb232..b8dae1ea6 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,13 @@
A simple, fully convolutional model for real-time instance segmentation. This is the code for [our paper](https://arxiv.org/abs/1904.02689).
+#### YOLACT++ implementation and models released!
+YOLACT++'s ResNet-50 model runs at 33.5 fps on a Titan Xp and achieves 34.1 mAP on COCO's `test-dev`.
+
+The related paper will be posted on arXiv soon.
+
+To use YOLACT++, make sure you compile the DCNv2 code (see [Installation](https://github.com/dbolya/yolact#installation)).
+
#### ICCV update (v1.1) released! Check out the ICCV trailer here:
[Watch the ICCV trailer on YouTube](https://www.youtube.com/watch?v=0pMfmo8qfpQ)
@@ -37,6 +44,11 @@ Some examples from our base model (33.5 fps on a Titan Xp and 29.8 mAP on COCO's
git clone https://github.com/dbolya/yolact.git
cd yolact
```
+ - Compile deformable convolutional layers (from [DCNv2](https://github.com/CharlesShang/DCNv2/tree/pytorch_1.0))
+ ```Shell
+ cd external/DCNv2
+ ./make.sh
+ ```
- If you'd like to train YOLACT, download the COCO dataset and the 2014/2017 annotations. Note that this script will take a while and dump 21gb of files into `./data/coco`.
```Shell
sh data/scripts/COCO.sh
@@ -57,6 +69,13 @@ As of April 5th, 2019 here are our latest models along with their FPS on a Titan
| 550 | Resnet101-FPN | 33.0 | 29.8 | [yolact_base_54_800000.pth](https://drive.google.com/file/d/1UYy3dMapbH1BnmtZU4WH1zbYgOzzHHf_/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EYRWxBEoKU9DiblrWx2M89MBGFkVVB_drlRd_v5sdT3Hgg)
| 700 | Resnet101-FPN | 23.6 | 31.2 | [yolact_im700_54_800000.pth](https://drive.google.com/file/d/1lE4Lz5p25teiXV-6HdTiOJSnS7u7GBzg/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/Eagg5RSc5hFEhp7sPtvLNyoBjhlf2feog7t8OQzHKKphjw)
+YOLACT++ models (released on Dec. 6th, 2019):
+
+| Image Size | Backbone | FPS | mAP | Weights | |
+|:----------:|:-------------:|:----:|:----:|----------------------------------------------------------------------------------------------------------------------|--------|
+| 550 | Resnet50-FPN | 33.5 | 34.1 | [yolact_plus_resnet50_54_800000.pth](https://drive.google.com/file/d/1ZPu1YR2UzGHQD0o1rEqy-j5bmEm3lbyP/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EcJAtMiEFlhAnVsDf00yWRIBUC4m8iE9NEEiV05XwtEoGw) |
+| 550 | Resnet101-FPN | 27.3 | 34.6 | [yolact_plus_base_54_800000.pth](https://drive.google.com/file/d/15id0Qq5eqRbkD-N3ZjDZXdCvRyIaHpFB/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EVQ62sF0SrJPrl_68onyHF8BpG7c05A8PavV4a849sZgEA) |
+
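+For a quick check of the new models, the same `eval.py` commands listed below work with these weights once they are placed in `./weights`, e.g.:
+
+```Shell
+python eval.py --trained_model=weights/yolact_plus_base_54_800000.pth
+```
+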
To evaluate the model, put the corresponding weights file in the `./weights` directory and run one of the following commands.
## Quantitative Results on COCO
```Shell
diff --git a/backbone.py b/backbone.py
index e7360423e..100e8ff02 100644
--- a/backbone.py
+++ b/backbone.py
@@ -4,16 +4,25 @@
from collections import OrderedDict
+from dcn_v2 import DCN
+
class Bottleneck(nn.Module):
""" Adapted from torchvision.models.resnet """
expansion = 4
- def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=nn.BatchNorm2d, dilation=1):
+ def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=nn.BatchNorm2d, dilation=1, use_dcn=False):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, dilation=dilation)
self.bn1 = norm_layer(planes)
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
- padding=dilation, bias=False, dilation=dilation)
+ if use_dcn:
+ self.conv2 = DCN(planes, planes, kernel_size=3, stride=stride,
+ padding=dilation, dilation=dilation, deformable_groups=1)
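+                # Zero-init the bias and the offset/mask predictor so the deformable conv
+                # starts with zero offsets and a uniform modulation mask.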
+ self.conv2.bias.data.zero_()
+ self.conv2.conv_offset_mask.weight.data.zero_()
+ self.conv2.conv_offset_mask.bias.data.zero_()
+ else:
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=dilation, bias=False, dilation=dilation)
self.bn2 = norm_layer(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False, dilation=dilation)
self.bn3 = norm_layer(planes * 4)
@@ -47,7 +56,7 @@ def forward(self, x):
class ResNetBackbone(nn.Module):
""" Adapted from torchvision.models.resnet """
- def __init__(self, layers, atrous_layers=[], block=Bottleneck, norm_layer=nn.BatchNorm2d):
+ def __init__(self, layers, dcn_layers=[0, 0, 0, 0], dcn_interval=1, atrous_layers=[], block=Bottleneck, norm_layer=nn.BatchNorm2d):
super().__init__()
# These will be populated by _make_layer
@@ -66,10 +75,10 @@ def __init__(self, layers, atrous_layers=[], block=Bottleneck, norm_layer=nn.Bat
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
- self._make_layer(block, 64, layers[0])
- self._make_layer(block, 128, layers[1], stride=2)
- self._make_layer(block, 256, layers[2], stride=2)
- self._make_layer(block, 512, layers[3], stride=2)
+ self._make_layer(block, 64, layers[0], dcn_layers=dcn_layers[0], dcn_interval=dcn_interval)
+ self._make_layer(block, 128, layers[1], stride=2, dcn_layers=dcn_layers[1], dcn_interval=dcn_interval)
+ self._make_layer(block, 256, layers[2], stride=2, dcn_layers=dcn_layers[2], dcn_interval=dcn_interval)
+ self._make_layer(block, 512, layers[3], stride=2, dcn_layers=dcn_layers[3], dcn_interval=dcn_interval)
# This contains every module that should be initialized by loading in pretrained weights.
# Any extra layers added onto this that won't be initialized by init_backbone will not be
@@ -78,7 +87,7 @@ def __init__(self, layers, atrous_layers=[], block=Bottleneck, norm_layer=nn.Bat
self.backbone_modules = [m for m in self.modules() if isinstance(m, nn.Conv2d)]
- def _make_layer(self, block, planes, blocks, stride=1):
+ def _make_layer(self, block, planes, blocks, stride=1, dcn_layers=0, dcn_interval=1):
""" Here one layer means a string of n Bottleneck blocks. """
downsample = None
@@ -97,11 +106,12 @@ def _make_layer(self, block, planes, blocks, stride=1):
)
layers = []
- layers.append(block(self.inplanes, planes, stride, downsample, self.norm_layer, self.dilation))
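+        # Use DCN only in the last `dcn_layers` blocks of this stage, and of those only
+        # every `dcn_interval`-th block.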
+ use_dcn = (dcn_layers >= blocks)
+ layers.append(block(self.inplanes, planes, stride, downsample, self.norm_layer, self.dilation, use_dcn=use_dcn))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
- layers.append(block(self.inplanes, planes, norm_layer=self.norm_layer))
-
+ use_dcn = ((i+dcn_layers) >= blocks) and (i % dcn_interval == 0)
+ layers.append(block(self.inplanes, planes, norm_layer=self.norm_layer, use_dcn=use_dcn))
layer = nn.Sequential(*layers)
self.channels.append(planes * block.expansion)
diff --git a/data/config.py b/data/config.py
index 6ebacab45..b945dd0fb 100644
--- a/data/config.py
+++ b/data/config.py
@@ -247,6 +247,11 @@ def print(self):
'pred_aspect_ratios': [ [[0.66685089, 1.7073535, 0.87508774, 1.16524493, 0.49059086]] ] * 6,
})
+resnet101_dcn_inter3_backbone = resnet101_backbone.copy({
+ 'name': 'ResNet101_DCN_Interval3',
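+    # args = (blocks per stage, number of DCN blocks per stage counted from the end, DCN interval);
+    # these map to the (layers, dcn_layers, dcn_interval) arguments of ResNetBackbone in backbone.py.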
+ 'args': ([3, 4, 23, 3], [0, 4, 23, 3], 3),
+})
+
resnet50_backbone = resnet101_backbone.copy({
'name': 'ResNet50',
'path': 'resnet50-19c8e357.pth',
@@ -255,6 +260,11 @@ def print(self):
'transform': resnet_transform,
})
+resnet50_dcnv2_backbone = resnet50_backbone.copy({
+ 'name': 'ResNet50_DCNv2',
+ 'args': ([3, 4, 6, 3], [0, 4, 6, 3]),
+})
+
darknet53_backbone = backbone_base.copy({
'name': 'DarkNet53',
'path': 'darknet53.pth',
@@ -614,6 +624,19 @@ def print(self):
'backbone': None,
'name': 'base_config',
+
+ # Fast Mask Re-scoring Network
+    # Inspired by Mask Scoring R-CNN (https://arxiv.org/abs/1903.00241)
+    # Instead of cropping out the mask with the bbox, slide a convnet over the image-size mask,
+ # then use global pooling to get the final mask score
+ 'use_maskiou': False,
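+    # Layer spec for the maskiou net as (out_channels, kernel_size, conv kwargs) tuples
+    # (see the YOLACT++ configs below for an example).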
+ 'maskiou_net': [],
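+    # Ignore ground-truth masks with pixel area below this threshold when training the
+    # re-scoring head; -1 disables the filtering.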
+ 'remove_small_gt_mask': -1,
+
+ 'maskiou_alpha': 1.0,
+ 'rescore_mask': False,
+ 'rescore_bbox': False,
+ 'maskious_to_train': -1,
})
@@ -736,6 +759,44 @@ def print(self):
})
})
+# ----------------------- YOLACT++ CONFIGS ----------------------- #
+
+yolact_plus_base_config = yolact_base_config.copy({
+ 'name': 'yolact_plus_base',
+
+ 'backbone': resnet101_dcn_inter3_backbone.copy({
+ 'selected_layers': list(range(1, 4)),
+
+ 'pred_aspect_ratios': [ [[1, 1/2, 2]] ]*5,
+ 'pred_scales': [[i * 2 ** (j / 3.0) for j in range(3)] for i in [24, 48, 96, 192, 384]],
+ 'use_pixel_scales': True,
+ 'preapply_sqrt': False,
+ 'use_square_anchors': False,
+ }),
+
+ 'use_maskiou': True,
+ 'maskiou_net': [(8, 3, {'stride': 2}), (16, 3, {'stride': 2}), (32, 3, {'stride': 2}), (64, 3, {'stride': 2}), (128, 3, {'stride': 2}), (80, 1, {})],
+ 'maskiou_alpha': 25,
+ 'rescore_bbox': False,
+ 'rescore_mask': True,
+
+ 'remove_small_gt_mask': 5*5,
+})
+
+yolact_plus_resnet50_config = yolact_plus_base_config.copy({
+ 'name': 'yolact_plus_resnet50',
+
+ 'backbone': resnet50_dcnv2_backbone.copy({
+ 'selected_layers': list(range(1, 4)),
+
+ 'pred_aspect_ratios': [ [[1, 1/2, 2]] ]*5,
+ 'pred_scales': [[i * 2 ** (j / 3.0) for j in range(3)] for i in [24, 48, 96, 192, 384]],
+ 'use_pixel_scales': True,
+ 'preapply_sqrt': False,
+ 'use_square_anchors': False,
+ }),
+})
+
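+# As with the other configs, these can be selected by name, e.g. `python train.py --config=yolact_plus_base_config`.
+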
# Default config
cfg = yolact_base_config.copy()
diff --git a/eval.py b/eval.py
index d26400856..f569b690e 100644
--- a/eval.py
+++ b/eval.py
@@ -132,7 +132,7 @@ def parse_args(argv=None):
coco_cats_inv = {}
color_cache = defaultdict(lambda: {})
-def prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, mask_alpha=0.45, fps_str=''):
+def prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, mask_alpha=0.45, fps_str='', maskiou_net=None):
"""
Note: If undo_transform=False then im_h and im_w are allowed to be None.
"""
@@ -146,14 +146,34 @@ def prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, ma
with timer.env('Postprocess'):
t = postprocess(dets_out, w, h, visualize_lincomb = args.display_lincomb,
crop_masks = args.crop,
- score_threshold = args.score_threshold)
+ score_threshold = args.score_threshold,
+ maskiou_net = maskiou_net)
torch.cuda.synchronize()
+ # FIXME reduce copy
with timer.env('Copy'):
if cfg.eval_mask_branch:
# Masks are drawn on the GPU, so don't copy
- masks = t[3][:args.top_k]
- classes, scores, boxes = [x[:args.top_k].cpu().numpy() for x in t[:3]]
+ masks = t[3]
+ classes, scores, boxes = [x for x in t[:3]]
+ if isinstance(scores, list):
+ box_scores = scores[0].cpu().numpy()
+ mask_scores = scores[1].cpu().numpy()
+ # Re-rank predictions by mask scores
+ _scores = mask_scores * box_scores
+ idx = np.argsort(-_scores)
+ scores = box_scores[idx]
+ classes = classes.cpu().numpy()[idx]
+ boxes = boxes.cpu().numpy()[idx]
+ masks = masks[idx]
+ else:
+ scores = scores.cpu().numpy()
+ classes = classes.cpu().numpy()
+ boxes = boxes.cpu().numpy()
+ scores = scores[:args.top_k]
+ classes = classes[:args.top_k]
+ boxes = boxes[:args.top_k]
+ masks = masks[:args.top_k]
num_dets_to_consider = min(args.top_k, classes.shape[0])
for j in range(num_dets_to_consider):
@@ -257,12 +277,20 @@ def get_color(j, on_gpu=None):
return img_numpy
-def prep_benchmark(dets_out, h, w):
+def prep_benchmark(dets_out, h, w, maskiou_net=None):
with timer.env('Postprocess'):
- t = postprocess(dets_out, w, h, crop_masks=args.crop, score_threshold=args.score_threshold)
+ t = postprocess(dets_out, w, h, crop_masks=args.crop, score_threshold=args.score_threshold, maskiou_net=maskiou_net)
with timer.env('Copy'):
- classes, scores, boxes, masks = [x[:args.top_k].cpu().numpy() for x in t]
+ classes, scores, boxes, masks = [x[:args.top_k] for x in t]
+ if isinstance(scores, list):
+ box_scores = scores[0].cpu().numpy()
+ mask_scores = scores[1].cpu().numpy()
+ else:
+ scores = scores.cpu().numpy()
+ classes = classes.cpu().numpy()
+ boxes = boxes.cpu().numpy()
+ masks = masks.cpu().numpy()
with timer.env('Sync'):
# Just in case
@@ -371,7 +399,7 @@ def _bbox_iou(bbox1, bbox2, iscrowd=False):
ret = jaccard(bbox1, bbox2, iscrowd)
return ret.cpu()
-def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, detections:Detections=None):
+def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, detections:Detections=None, maskiou_net=None):
""" Returns a list of APs for this image, with each element being for a class """
if not args.output_coco_json:
with timer.env('Prepare gt'):
@@ -388,13 +416,19 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de
crowd_classes, gt_classes = split(gt_classes)
with timer.env('Postprocess'):
- classes, scores, boxes, masks = postprocess(dets, w, h, crop_masks=args.crop, score_threshold=args.score_threshold)
+ classes, scores, boxes, masks = postprocess(dets, w, h, crop_masks=args.crop, score_threshold=args.score_threshold, maskiou_net=maskiou_net)
if classes.size(0) == 0:
return
classes = list(classes.cpu().numpy().astype(int))
- scores = list(scores.cpu().numpy().astype(float))
+ if isinstance(scores, list):
+ box_scores = list(scores[0].cpu().numpy().astype(float))
+ mask_scores = list(scores[1].cpu().numpy().astype(float))
+ else:
+ scores = list(scores.cpu().numpy().astype(float))
+ box_scores = scores
+ mask_scores = scores
masks = masks.view(-1, h*w).cuda()
boxes = boxes.cuda()
@@ -406,8 +440,8 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de
for i in range(masks.shape[0]):
# Make sure that the bounding box actually makes sense and a mask was produced
if (boxes[i, 3] - boxes[i, 1]) * (boxes[i, 2] - boxes[i, 0]) > 0:
- detections.add_bbox(image_id, classes[i], boxes[i,:], scores[i])
- detections.add_mask(image_id, classes[i], masks[i,:,:], scores[i])
+ detections.add_bbox(image_id, classes[i], boxes[i,:], box_scores[i])
+ detections.add_mask(image_id, classes[i], masks[i,:,:], mask_scores[i])
return
with timer.env('Eval Setup'):
@@ -425,8 +459,8 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de
crowd_bbox_iou_cache = None
iou_types = [
- ('box', lambda i,j: bbox_iou_cache[i, j].item(), lambda i,j: crowd_bbox_iou_cache[i,j].item()),
- ('mask', lambda i,j: mask_iou_cache[i, j].item(), lambda i,j: crowd_mask_iou_cache[i,j].item())
+ ('box', lambda i,j: bbox_iou_cache[i, j].item(), lambda i,j: crowd_bbox_iou_cache[i,j].item(), lambda i: box_scores[i]),
+ ('mask', lambda i,j: mask_iou_cache[i, j].item(), lambda i,j: crowd_mask_iou_cache[i,j].item(), lambda i: mask_scores[i])
]
timer.start('Main loop')
@@ -437,7 +471,7 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de
for iouIdx in range(len(iou_thresholds)):
iou_threshold = iou_thresholds[iouIdx]
- for iou_type, iou_func, crowd_func in iou_types:
+ for iou_type, iou_func, crowd_func, score_func in iou_types:
gt_used = [False] * len(gt_classes)
ap_obj = ap_data[iou_type][iouIdx][_class]
@@ -461,7 +495,7 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de
if max_match_idx >= 0:
gt_used[max_match_idx] = True
- ap_obj.push(scores[i], True)
+ ap_obj.push(score_func(i), True)
else:
# If the detection matches a crowd, we can just ignore it
matched_crowd = False
@@ -481,7 +515,7 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de
# same result as COCOEval. There aren't even that many crowd annotations to
# begin with, but accuracy is of the utmost importance.
if not matched_crowd:
- ap_obj.push(scores[i], False)
+ ap_obj.push(score_func(i), False)
timer.stop('Main loop')
@@ -846,6 +880,7 @@ def evaluate(net:Yolact, dataset, train_mode=False):
net.detect.use_cross_class_nms = args.cross_class_nms
cfg.mask_proto_debug = args.mask_proto_debug
+    # TODO: Currently we do not support Fast Mask Re-scoring in evalimage, evalimages, and evalvideo
if args.image is not None:
if ':' in args.image:
inp, out = args.image.split(':')
@@ -921,13 +956,14 @@ def evaluate(net:Yolact, dataset, train_mode=False):
with timer.env('Network Extra'):
preds = net(batch)
+ maskiou_net = net.get_maskiou_net()
# Perform the meat of the operation here depending on our mode.
if args.display:
- img_numpy = prep_display(preds, img, h, w)
+ img_numpy = prep_display(preds, img, h, w, maskiou_net=maskiou_net)
elif args.benchmark:
- prep_benchmark(preds, h, w)
+ prep_benchmark(preds, h, w, maskiou_net=maskiou_net)
else:
- prep_metrics(ap_data, preds, img, gt, gt_masks, h, w, num_crowd, dataset.ids[image_idx], detections)
+ prep_metrics(ap_data, preds, img, gt, gt_masks, h, w, num_crowd, dataset.ids[image_idx], detections, maskiou_net=maskiou_net)
# First couple of images take longer because we're constructing the graph.
# Since that's technically initialization, don't include those in the FPS calculations.
diff --git a/external/DCNv2/LICENSE b/external/DCNv2/LICENSE
new file mode 100644
index 000000000..b2e3b5207
--- /dev/null
+++ b/external/DCNv2/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2019, Charles Shang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/external/DCNv2/README.md b/external/DCNv2/README.md
new file mode 100644
index 000000000..9787cfa2c
--- /dev/null
+++ b/external/DCNv2/README.md
@@ -0,0 +1,65 @@
+## Deformable Convolutional Networks V2 with PyTorch 1.0
+
+### Build
+```bash
+ ./make.sh # build
+ python test.py # run examples and gradient check
+```
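+
+Note that the extension is effectively CUDA-only: `setup.py` raises an error when a CUDA-enabled PyTorch build and `CUDA_HOME` are not found, and the CPU kernels are not implemented.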
+
+### An Example
+- deformable conv
+```python
+ from dcn_v2 import DCN
+ input = torch.randn(2, 64, 128, 128).cuda()
+ # wrap all things (offset and mask) in DCN
+ dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda()
+ output = dcn(input)
+ print(output.shape)
+```
+- deformable roi pooling
+```python
+ from dcn_v2 import DCNPooling
+ input = torch.randn(2, 32, 64, 64).cuda()
+ batch_inds = torch.randint(2, (20, 1)).cuda().float()
+ x = torch.randint(256, (20, 1)).cuda().float()
+ y = torch.randint(256, (20, 1)).cuda().float()
+ w = torch.randint(64, (20, 1)).cuda().float()
+ h = torch.randint(64, (20, 1)).cuda().float()
+ rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
+
+    # modulated deformable pooling (V2)
+ # wrap all things (offset and mask) in DCNPooling
+ dpooling = DCNPooling(spatial_scale=1.0 / 4,
+ pooled_size=7,
+ output_dim=32,
+ no_trans=False,
+ group_size=1,
+ trans_std=0.1).cuda()
+
+ dout = dpooling(input, rois)
+```
+### Note
+The master branch now targets PyTorch 1.0 (the new ATen API); you can switch back to PyTorch 0.4 with:
+```bash
+git checkout pytorch_0.4
+```
+
+### Known Issues:
+
+- [x] Gradient check w.r.t offset (solved)
+- [ ] Backward is not reentrant (minor)
+
+This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).
+
+I have run the gradient check many times with DOUBLE type. Every tensor **except offset** passes.
+However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Is it because of some
+non-differentiable points?
+
+Update: all gradient checks pass with double precision.
+
+Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for
+float, `<1e-15` for double), so it may not be a serious problem (?).
+
+Please post an issue or PR if you have any comments.
+
\ No newline at end of file
diff --git a/external/DCNv2/__init__.py b/external/DCNv2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/external/DCNv2/dcn_v2.py b/external/DCNv2/dcn_v2.py
new file mode 100644
index 000000000..982bef512
--- /dev/null
+++ b/external/DCNv2/dcn_v2.py
@@ -0,0 +1,303 @@
+#!/usr/bin/env python
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import math
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+from torch.autograd.function import once_differentiable
+
+import _ext as _backend
+
+
+class _DCNv2(Function):
+ @staticmethod
+ def forward(ctx, input, offset, mask, weight, bias,
+ stride, padding, dilation, deformable_groups):
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.kernel_size = _pair(weight.shape[2:4])
+ ctx.deformable_groups = deformable_groups
+ output = _backend.dcn_v2_forward(input, weight, bias,
+ offset, mask,
+ ctx.kernel_size[0], ctx.kernel_size[1],
+ ctx.stride[0], ctx.stride[1],
+ ctx.padding[0], ctx.padding[1],
+ ctx.dilation[0], ctx.dilation[1],
+ ctx.deformable_groups)
+ ctx.save_for_backward(input, offset, mask, weight, bias)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, mask, weight, bias = ctx.saved_tensors
+ grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \
+ _backend.dcn_v2_backward(input, weight,
+ bias,
+ offset, mask,
+ grad_output,
+ ctx.kernel_size[0], ctx.kernel_size[1],
+ ctx.stride[0], ctx.stride[1],
+ ctx.padding[0], ctx.padding[1],
+ ctx.dilation[0], ctx.dilation[1],
+ ctx.deformable_groups)
+
+ return grad_input, grad_offset, grad_mask, grad_weight, grad_bias,\
+ None, None, None, None,
+
+
+dcn_v2_conv = _DCNv2.apply
+
+
+class DCNv2(nn.Module):
+
+ def __init__(self, in_channels, out_channels,
+ kernel_size, stride, padding, dilation=1, deformable_groups=1):
+ super(DCNv2, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.deformable_groups = deformable_groups
+
+ self.weight = nn.Parameter(torch.Tensor(
+ out_channels, in_channels, *self.kernel_size))
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ self.bias.data.zero_()
+
+ def forward(self, input, offset, mask):
+ assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \
+ offset.shape[1]
+ assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \
+ mask.shape[1]
+ return dcn_v2_conv(input, offset, mask,
+ self.weight,
+ self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.deformable_groups)
+
+
+class DCN(DCNv2):
+
+ def __init__(self, in_channels, out_channels,
+ kernel_size, stride, padding,
+ dilation=1, deformable_groups=1):
+ super(DCN, self).__init__(in_channels, out_channels,
+ kernel_size, stride, padding, dilation, deformable_groups)
+
+ channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
+ self.conv_offset_mask = nn.Conv2d(self.in_channels,
+ channels_,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset_mask.weight.data.zero_()
+ self.conv_offset_mask.bias.data.zero_()
+
+ def forward(self, input):
+ out = self.conv_offset_mask(input)
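+        # out has 3 * deformable_groups * kH * kW channels: two thirds are the x/y offsets,
+        # the last third is the modulation mask.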
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ return dcn_v2_conv(input, offset, mask,
+ self.weight, self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.deformable_groups)
+
+
+
+class _DCNv2Pooling(Function):
+ @staticmethod
+ def forward(ctx, input, rois, offset,
+ spatial_scale,
+ pooled_size,
+ output_dim,
+ no_trans,
+ group_size=1,
+ part_size=None,
+ sample_per_part=4,
+ trans_std=.0):
+ ctx.spatial_scale = spatial_scale
+ ctx.no_trans = int(no_trans)
+ ctx.output_dim = output_dim
+ ctx.group_size = group_size
+ ctx.pooled_size = pooled_size
+ ctx.part_size = pooled_size if part_size is None else part_size
+ ctx.sample_per_part = sample_per_part
+ ctx.trans_std = trans_std
+
+ output, output_count = \
+ _backend.dcn_v2_psroi_pooling_forward(input, rois, offset,
+ ctx.no_trans, ctx.spatial_scale,
+ ctx.output_dim, ctx.group_size,
+ ctx.pooled_size, ctx.part_size,
+ ctx.sample_per_part, ctx.trans_std)
+ ctx.save_for_backward(input, rois, offset, output_count)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, rois, offset, output_count = ctx.saved_tensors
+ grad_input, grad_offset = \
+ _backend.dcn_v2_psroi_pooling_backward(grad_output,
+ input,
+ rois,
+ offset,
+ output_count,
+ ctx.no_trans,
+ ctx.spatial_scale,
+ ctx.output_dim,
+ ctx.group_size,
+ ctx.pooled_size,
+ ctx.part_size,
+ ctx.sample_per_part,
+ ctx.trans_std)
+
+ return grad_input, None, grad_offset, \
+ None, None, None, None, None, None, None, None
+
+
+dcn_v2_pooling = _DCNv2Pooling.apply
+
+
+class DCNv2Pooling(nn.Module):
+
+ def __init__(self,
+ spatial_scale,
+ pooled_size,
+ output_dim,
+ no_trans,
+ group_size=1,
+ part_size=None,
+ sample_per_part=4,
+ trans_std=.0):
+ super(DCNv2Pooling, self).__init__()
+ self.spatial_scale = spatial_scale
+ self.pooled_size = pooled_size
+ self.output_dim = output_dim
+ self.no_trans = no_trans
+ self.group_size = group_size
+ self.part_size = pooled_size if part_size is None else part_size
+ self.sample_per_part = sample_per_part
+ self.trans_std = trans_std
+
+ def forward(self, input, rois, offset):
+ assert input.shape[1] == self.output_dim
+ if self.no_trans:
+ offset = input.new()
+ return dcn_v2_pooling(input, rois, offset,
+ self.spatial_scale,
+ self.pooled_size,
+ self.output_dim,
+ self.no_trans,
+ self.group_size,
+ self.part_size,
+ self.sample_per_part,
+ self.trans_std)
+
+
+class DCNPooling(DCNv2Pooling):
+
+ def __init__(self,
+ spatial_scale,
+ pooled_size,
+ output_dim,
+ no_trans,
+ group_size=1,
+ part_size=None,
+ sample_per_part=4,
+ trans_std=.0,
+ deform_fc_dim=1024):
+ super(DCNPooling, self).__init__(spatial_scale,
+ pooled_size,
+ output_dim,
+ no_trans,
+ group_size,
+ part_size,
+ sample_per_part,
+ trans_std)
+
+ self.deform_fc_dim = deform_fc_dim
+
+ if not no_trans:
+ self.offset_mask_fc = nn.Sequential(
+ nn.Linear(self.pooled_size * self.pooled_size *
+ self.output_dim, self.deform_fc_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_dim, self.pooled_size *
+ self.pooled_size * 3)
+ )
+ self.offset_mask_fc[4].weight.data.zero_()
+ self.offset_mask_fc[4].bias.data.zero_()
+
+ def forward(self, input, rois):
+ offset = input.new()
+
+ if not self.no_trans:
+
+ # do roi_align first
+ n = rois.shape[0]
+ roi = dcn_v2_pooling(input, rois, offset,
+ self.spatial_scale,
+ self.pooled_size,
+ self.output_dim,
+ True, # no trans
+ self.group_size,
+ self.part_size,
+ self.sample_per_part,
+ self.trans_std)
+
+ # build mask and offset
+ offset_mask = self.offset_mask_fc(roi.view(n, -1))
+ offset_mask = offset_mask.view(
+ n, 3, self.pooled_size, self.pooled_size)
+ o1, o2, mask = torch.chunk(offset_mask, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+
+ # do pooling with offset and mask
+ return dcn_v2_pooling(input, rois, offset,
+ self.spatial_scale,
+ self.pooled_size,
+ self.output_dim,
+ self.no_trans,
+ self.group_size,
+ self.part_size,
+ self.sample_per_part,
+ self.trans_std) * mask
+ # only roi_align
+ return dcn_v2_pooling(input, rois, offset,
+ self.spatial_scale,
+ self.pooled_size,
+ self.output_dim,
+ self.no_trans,
+ self.group_size,
+ self.part_size,
+ self.sample_per_part,
+ self.trans_std)
diff --git a/external/DCNv2/make.sh b/external/DCNv2/make.sh
new file mode 100755
index 000000000..f1f15c0e3
--- /dev/null
+++ b/external/DCNv2/make.sh
@@ -0,0 +1,2 @@
+#!/usr/bin/env bash
+python setup.py build develop
diff --git a/external/DCNv2/setup.py b/external/DCNv2/setup.py
new file mode 100644
index 000000000..108249428
--- /dev/null
+++ b/external/DCNv2/setup.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python
+
+import os
+import glob
+
+import torch
+
+from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+from setuptools import find_packages
+from setuptools import setup
+
+requirements = ["torch", "torchvision"]
+
+def get_extensions():
+ this_dir = os.path.dirname(os.path.abspath(__file__))
+ extensions_dir = os.path.join(this_dir, "src")
+
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+ sources = main_file + source_cpu
+ extension = CppExtension
+ extra_compile_args = {"cxx": []}
+ define_macros = []
+
+ if torch.cuda.is_available() and CUDA_HOME is not None:
+ extension = CUDAExtension
+ sources += source_cuda
+ define_macros += [("WITH_CUDA", None)]
+ extra_compile_args["nvcc"] = [
+ "-DCUDA_HAS_FP16=1",
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ ]
+ else:
+        raise NotImplementedError('CUDA is not available')
+
+ sources = [os.path.join(extensions_dir, s) for s in sources]
+ include_dirs = [extensions_dir]
+ ext_modules = [
+ extension(
+ "_ext",
+ sources,
+ include_dirs=include_dirs,
+ define_macros=define_macros,
+ extra_compile_args=extra_compile_args,
+ )
+ ]
+ return ext_modules
+
+setup(
+ name="DCNv2",
+ version="0.1",
+ author="charlesshang",
+ url="https://github.com/charlesshang/DCNv2",
+ description="deformable convolutional networks",
+ packages=find_packages(exclude=("configs", "tests",)),
+ # install_requires=requirements,
+ ext_modules=get_extensions(),
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
\ No newline at end of file
diff --git a/external/DCNv2/src/cpu/dcn_v2_cpu.cpp b/external/DCNv2/src/cpu/dcn_v2_cpu.cpp
new file mode 100644
index 000000000..a68ccef8e
--- /dev/null
+++ b/external/DCNv2/src/cpu/dcn_v2_cpu.cpp
@@ -0,0 +1,74 @@
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+at::Tensor
+dcn_v2_cpu_forward(const at::Tensor &input,
+ const at::Tensor &weight,
+ const at::Tensor &bias,
+ const at::Tensor &offset,
+ const at::Tensor &mask,
+ const int kernel_h,
+ const int kernel_w,
+ const int stride_h,
+ const int stride_w,
+ const int pad_h,
+ const int pad_w,
+ const int dilation_h,
+ const int dilation_w,
+ const int deformable_group)
+{
+    AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+dcn_v2_cpu_backward(const at::Tensor &input,
+ const at::Tensor &weight,
+ const at::Tensor &bias,
+ const at::Tensor &offset,
+ const at::Tensor &mask,
+ const at::Tensor &grad_output,
+ int kernel_h, int kernel_w,
+ int stride_h, int stride_w,
+ int pad_h, int pad_w,
+ int dilation_h, int dilation_w,
+ int deformable_group)
+{
+    AT_ERROR("Not implemented on the CPU");
+}
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
+ const at::Tensor &bbox,
+ const at::Tensor &trans,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std)
+{
+    AT_ERROR("Not implemented on the CPU");
+}
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
+ const at::Tensor &input,
+ const at::Tensor &bbox,
+ const at::Tensor &trans,
+ const at::Tensor &top_count,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std)
+{
+    AT_ERROR("Not implemented on the CPU");
+}
\ No newline at end of file
diff --git a/external/DCNv2/src/cpu/vision.h b/external/DCNv2/src/cpu/vision.h
new file mode 100644
index 000000000..d5fbf1f07
--- /dev/null
+++ b/external/DCNv2/src/cpu/vision.h
@@ -0,0 +1,60 @@
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor
+dcn_v2_cpu_forward(const at::Tensor &input,
+ const at::Tensor &weight,
+ const at::Tensor &bias,
+ const at::Tensor &offset,
+ const at::Tensor &mask,
+ const int kernel_h,
+ const int kernel_w,
+ const int stride_h,
+ const int stride_w,
+ const int pad_h,
+ const int pad_w,
+ const int dilation_h,
+ const int dilation_w,
+ const int deformable_group);
+
+std::vector<at::Tensor>
+dcn_v2_cpu_backward(const at::Tensor &input,
+ const at::Tensor &weight,
+ const at::Tensor &bias,
+ const at::Tensor &offset,
+ const at::Tensor &mask,
+ const at::Tensor &grad_output,
+ int kernel_h, int kernel_w,
+ int stride_h, int stride_w,
+ int pad_h, int pad_w,
+ int dilation_h, int dilation_w,
+ int deformable_group);
+
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
+ const at::Tensor &bbox,
+ const at::Tensor &trans,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std);
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
+ const at::Tensor &input,
+ const at::Tensor &bbox,
+ const at::Tensor &trans,
+ const at::Tensor &top_count,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std);
\ No newline at end of file
diff --git a/external/DCNv2/src/cuda/dcn_v2_cuda.cu b/external/DCNv2/src/cuda/dcn_v2_cuda.cu
new file mode 100644
index 000000000..767ed8fb1
--- /dev/null
+++ b/external/DCNv2/src/cuda/dcn_v2_cuda.cu
@@ -0,0 +1,335 @@
+#include <vector>
+#include "cuda/dcn_v2_im2col_cuda.h"
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+#include <THC/THCDeviceUtils.cuh>
+
+extern THCState *state;
+
+// author: Charles Shang
+// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
+
+// [batch gemm]
+// https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu
+
+__global__ void createBatchGemmBuffer(const float **input_b, float **output_b,
+ float **columns_b, const float **ones_b,
+ const float **weight_b, const float **bias_b,
+ float *input, float *output,
+ float *columns, float *ones,
+ float *weight, float *bias,
+ const int input_stride, const int output_stride,
+ const int columns_stride, const int ones_stride,
+ const int num_batches)
+{
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx < num_batches)
+ {
+ input_b[idx] = input + idx * input_stride;
+ output_b[idx] = output + idx * output_stride;
+ columns_b[idx] = columns + idx * columns_stride;
+ ones_b[idx] = ones + idx * ones_stride;
+ // share weights and bias within a Mini-Batch
+ weight_b[idx] = weight;
+ bias_b[idx] = bias;
+ }
+}
+
+at::Tensor
+dcn_v2_cuda_forward(const at::Tensor &input,
+ const at::Tensor &weight,
+ const at::Tensor &bias,
+ const at::Tensor &offset,
+ const at::Tensor &mask,
+ const int kernel_h,
+ const int kernel_w,
+ const int stride_h,
+ const int stride_w,
+ const int pad_h,
+ const int pad_w,
+ const int dilation_h,
+ const int dilation_w,
+ const int deformable_group)
+{
+ using scalar_t = float;
+ // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
+ AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+ AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
+ AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
+ AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
+ AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_out = weight.size(0);
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+
+ // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h);
+ // printf("Channels: %d %d\n", channels, channels_kernel);
+ // printf("Channels: %d %d\n", channels_out, channels_kernel);
+
+    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
+               "Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h, kernel_w, kernel_h_, kernel_w_);
+
+    AT_ASSERTM(channels == channels_kernel,
+               "Input shape and kernel channels won't match: (%d vs %d).", channels, channels_kernel);
+
+ const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ auto ones = at::ones({batch, height_out, width_out}, input.options());
+ auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
+ auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
+
+ // prepare for batch-wise computing, which is significantly faster than instance-wise computing
+ // when batch size is large.
+ // launch batch threads
+ int matrices_size = batch * sizeof(float *);
+    auto input_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
+    auto output_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
+    auto columns_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
+    auto ones_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
+    auto weight_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
+    auto bias_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
+
+ const int block = 128;
+ const int grid = (batch + block - 1) / block;
+
+    createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+        input_b, output_b,
+        columns_b, ones_b,
+        weight_b, bias_b,
+        input.data<scalar_t>(),
+        output.data<scalar_t>(),
+        columns.data<scalar_t>(),
+        ones.data<scalar_t>(),
+        weight.data<scalar_t>(),
+        bias.data<scalar_t>(),
+ channels * width * height,
+ channels_out * width_out * height_out,
+ channels * kernel_h * kernel_w * height_out * width_out,
+ height_out * width_out,
+ batch);
+
+ long m_ = channels_out;
+ long n_ = height_out * width_out;
+ long k_ = 1;
+ THCudaBlas_SgemmBatched(state,
+ 't',
+ 'n',
+ n_,
+ m_,
+ k_,
+ 1.0f,
+ ones_b, k_,
+ bias_b, k_,
+ 0.0f,
+ output_b, n_,
+ batch);
+
+    modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
+                                     input.data<scalar_t>(),
+                                     offset.data<scalar_t>(),
+                                     mask.data<scalar_t>(),
+                                     batch, channels, height, width,
+                                     height_out, width_out, kernel_h, kernel_w,
+                                     pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                     deformable_group,
+                                     columns.data<scalar_t>());
+
+ long m = channels_out;
+ long n = height_out * width_out;
+ long k = channels * kernel_h * kernel_w;
+ THCudaBlas_SgemmBatched(state,
+ 'n',
+ 'n',
+ n,
+ m,
+ k,
+ 1.0f,
+ (const float **)columns_b, n,
+ weight_b, k,
+ 1.0f,
+ output_b, n,
+ batch);
+
+ THCudaFree(state, input_b);
+ THCudaFree(state, output_b);
+ THCudaFree(state, columns_b);
+ THCudaFree(state, ones_b);
+ THCudaFree(state, weight_b);
+ THCudaFree(state, bias_b);
+ return output;
+}
+
+__global__ void createBatchGemmBufferBackward(
+ float **grad_output_b,
+ float **columns_b,
+ float **ones_b,
+ float **weight_b,
+ float **grad_weight_b,
+ float **grad_bias_b,
+ float *grad_output,
+ float *columns,
+ float *ones,
+ float *weight,
+ float *grad_weight,
+ float *grad_bias,
+ const int grad_output_stride,
+ const int columns_stride,
+ const int ones_stride,
+ const int num_batches)
+{
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx < num_batches)
+ {
+ grad_output_b[idx] = grad_output + idx * grad_output_stride;
+ columns_b[idx] = columns + idx * columns_stride;
+ ones_b[idx] = ones + idx * ones_stride;
+
+ // share weights and bias within a Mini-Batch
+ weight_b[idx] = weight;
+ grad_weight_b[idx] = grad_weight;
+ grad_bias_b[idx] = grad_bias;
+ }
+}
+
+std::vector dcn_v2_cuda_backward(const at::Tensor &input,
+ const at::Tensor &weight,
+ const at::Tensor &bias,
+ const at::Tensor &offset,
+ const at::Tensor &mask,
+ const at::Tensor &grad_output,
+ int kernel_h, int kernel_w,
+ int stride_h, int stride_w,
+ int pad_h, int pad_w,
+ int dilation_h, int dilation_w,
+ int deformable_group)
+{
+
+ THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
+ THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");
+
+ AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+ AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
+ AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
+ AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
+ AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_out = weight.size(0);
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+
+    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
+               "Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h, kernel_w, kernel_h_, kernel_w_);
+
+    AT_ASSERTM(channels == channels_kernel,
+               "Input shape and kernel channels won't match: (%d vs %d).", channels, channels_kernel);
+
+ const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ auto ones = at::ones({height_out, width_out}, input.options());
+ auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
+ auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
+
+ auto grad_input = at::zeros_like(input);
+ auto grad_weight = at::zeros_like(weight);
+ auto grad_bias = at::zeros_like(bias);
+ auto grad_offset = at::zeros_like(offset);
+ auto grad_mask = at::zeros_like(mask);
+
+ using scalar_t = float;
+
+ for (int b = 0; b < batch; b++)
+ {
+ auto input_n = input.select(0, b);
+ auto offset_n = offset.select(0, b);
+ auto mask_n = mask.select(0, b);
+ auto grad_output_n = grad_output.select(0, b);
+ auto grad_input_n = grad_input.select(0, b);
+ auto grad_offset_n = grad_offset.select(0, b);
+ auto grad_mask_n = grad_mask.select(0, b);
+
+ long m = channels * kernel_h * kernel_w;
+ long n = height_out * width_out;
+ long k = channels_out;
+
+ THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
+                         grad_output_n.data<scalar_t>(), n,
+                         weight.data<scalar_t>(), m, 0.0f,
+                         columns.data<scalar_t>(), n);
+
+ // gradient w.r.t. input coordinate data
+ modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state),
+                                               columns.data<scalar_t>(),
+                                               input_n.data<scalar_t>(),
+                                               offset_n.data<scalar_t>(),
+                                               mask_n.data<scalar_t>(),
+                                               1, channels, height, width,
+                                               height_out, width_out, kernel_h, kernel_w,
+                                               pad_h, pad_w, stride_h, stride_w,
+                                               dilation_h, dilation_w, deformable_group,
+                                               grad_offset_n.data<scalar_t>(),
+                                               grad_mask_n.data<scalar_t>());
+ // gradient w.r.t. input data
+ modulated_deformable_col2im_cuda(THCState_getCurrentStream(state),
+                                         columns.data<scalar_t>(),
+                                         offset_n.data<scalar_t>(),
+                                         mask_n.data<scalar_t>(),
+                                         1, channels, height, width,
+                                         height_out, width_out, kernel_h, kernel_w,
+                                         pad_h, pad_w, stride_h, stride_w,
+                                         dilation_h, dilation_w, deformable_group,
+                                         grad_input_n.data<scalar_t>());
+
+ // gradient w.r.t. weight, dWeight should accumulate across the batch and group
+ modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
+                                         input_n.data<scalar_t>(),
+                                         offset_n.data<scalar_t>(),
+                                         mask_n.data<scalar_t>(),
+                                         1, channels, height, width,
+                                         height_out, width_out, kernel_h, kernel_w,
+                                         pad_h, pad_w, stride_h, stride_w,
+                                         dilation_h, dilation_w, deformable_group,
+                                         columns.data<scalar_t>());
+
+ long m_ = channels_out;
+ long n_ = channels * kernel_h * kernel_w;
+ long k_ = height_out * width_out;
+
+ THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
+                         columns.data<scalar_t>(), k_,
+                         grad_output_n.data<scalar_t>(), k_, 1.0f,
+                         grad_weight.data<scalar_t>(), n_);
+
+ // gradient w.r.t. bias
+ // long m_ = channels_out;
+ // long k__ = height_out * width_out;
+ THCudaBlas_Sgemv(state,
+ 't',
+ k_, m_, 1.0f,
+                         grad_output_n.data<scalar_t>(), k_,
+                         ones.data<scalar_t>(), 1, 1.0f,
+                         grad_bias.data<scalar_t>(), 1);
+ }
+
+ return {
+ grad_input, grad_offset, grad_mask, grad_weight, grad_bias
+ };
+}
\ No newline at end of file
diff --git a/external/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu b/external/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu
new file mode 100644
index 000000000..4183793ba
--- /dev/null
+++ b/external/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu
@@ -0,0 +1,402 @@
+#include "dcn_v2_im2col_cuda.h"
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+#include <THC/THCDeviceUtils.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N)
+{
+ return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
+
+__device__ float dmcn_im2col_bilinear(const float *bottom_data, const int data_width,
+ const int height, const int width, float h, float w)
+{
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ float lh = h - h_low;
+ float lw = w - w_low;
+ float hh = 1 - lh, hw = 1 - lw;
+
+ float v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ float v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ float v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ float v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+__device__ float dmcn_get_gradient_weight(float argmax_h, float argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ float weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+__device__ float dmcn_get_coordinate_weight(float argmax_h, float argmax_w,
+ const int height, const int width, const float *im_data,
+ const int data_width, const int bp_dir)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ float weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
+__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
+ const float *data_im, const float *data_offset, const float *data_mask,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ float *data_col)
+{
+ // launch channels * batch_size * height_col * width_col cores
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)
+ // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis
+
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ // const int b_col = (index / width_col / height_col) % batch_size;
+ const int b_col = (index / width_col / height_col / num_channels) % batch_size;
+ // const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_im = (index / width_col / height_col) % num_channels;
+ // const int c_col = c_im * kernel_h * kernel_w;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+
+ // float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col;
+ //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+ const float offset_h = data_offset_ptr[data_offset_h_ptr];
+ const float offset_w = data_offset_ptr[data_offset_w_ptr];
+ const float mask = data_mask_ptr[data_mask_hw_ptr];
+                float val = static_cast<float>(0);
+ const float h_im = h_in + i * dilation_h + offset_h;
+ const float w_im = w_in + j * dilation_w + offset_w;
+ //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const float map_h = i * dilation_h + offset_h;
+ //const float map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
+ *data_col_ptr = val * mask;
+ // data_col_ptr += batch_size * height_col * width_col;
+ data_col_ptr += height_col * width_col;
+ }
+ }
+ }
+}
+
+__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
+ const float *data_col, const float *data_offset, const float *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ float *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+ const float offset_h = data_offset_ptr[data_offset_h_ptr];
+ const float offset_w = data_offset_ptr[data_offset_w_ptr];
+ const float mask = data_mask_ptr[data_mask_hw_ptr];
+ const float cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const float cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const float cur_top_grad = data_col[index] * mask;
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ float weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
+__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
+ const float *data_col, const float *data_im,
+ const float *data_offset, const float *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ float *grad_offset, float *grad_mask)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ float val = 0, mval = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
+ const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
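+    // Each thread owns one offset (or mask) value: accumulate its gradient over every
+    // column row in this deformable group sampled with that offset; mval gathers the
+    // corresponding mask gradient from the bilinearly sampled input.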
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+ const float offset_h = data_offset_ptr[data_offset_h_ptr];
+ const float offset_w = data_offset_ptr[data_offset_w_ptr];
+ const float mask = data_mask_ptr[data_mask_hw_ptr];
+ float inv_h = h_in + i * dilation_h + offset_h;
+ float inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ else
+ {
+ mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
+ }
+ const float weight = dmcn_get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos] * mask;
+ cnt += 1;
+ }
+ // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+ grad_offset[index] = val;
+ if (offset_c % 2 == 0)
+ // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
+ grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
+ }
+}
+
+void modulated_deformable_im2col_cuda(cudaStream_t stream,
+ const float* data_im, const float* data_offset, const float* data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, float* data_col) {
+ // num_axes should be smaller than block size
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * batch_size * height_col * width_col;
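+  // One thread per (channel, batch item, output location); each thread writes the
+  // kernel_h * kernel_w column entries for its location.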
+ modulated_deformable_im2col_gpu_kernel
+      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
+ num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, channels, deformable_group, height_col, width_col, data_col);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+void modulated_deformable_col2im_cuda(cudaStream_t stream,
+ const float* data_col, const float* data_offset, const float* data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, float* grad_im){
+
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
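+  // One thread per column entry (channel x kernel position x batch x output location);
+  // overlapping contributions are accumulated into grad_im with atomicAdd.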
+ modulated_deformable_col2im_gpu_kernel
+      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
+          num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,
+          kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, deformable_group, height_col, width_col, grad_im);
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,
+ const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group,
+ float* grad_offset, float* grad_mask) {
+ const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+ const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
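+  // One thread per offset/mask value; channel_per_deformable_group here counts column
+  // rows (channels * kernel_h * kernel_w) per group, unlike the other two launchers.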
+ modulated_deformable_col2im_coord_gpu_kernel
+      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
+ num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+ grad_offset, grad_mask);
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
\ No newline at end of file
diff --git a/external/DCNv2/src/cuda/dcn_v2_im2col_cuda.h b/external/DCNv2/src/cuda/dcn_v2_im2col_cuda.h
new file mode 100644
index 000000000..c85683198
--- /dev/null
+++ b/external/DCNv2/src/cuda/dcn_v2_im2col_cuda.h
@@ -0,0 +1,101 @@
+
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.h
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1811.11168
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
+ */
+
+/***************** Adapted by Charles Shang *********************/
+
+#ifndef DCN_V2_IM2COL_CUDA
+#define DCN_V2_IM2COL_CUDA
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif
+
+ void modulated_deformable_im2col_cuda(cudaStream_t stream,
+ const float *data_im, const float *data_offset, const float *data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+                                        const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, float *data_col);
+
+ void modulated_deformable_col2im_cuda(cudaStream_t stream,
+ const float *data_col, const float *data_offset, const float *data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+                                        const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, float *grad_im);
+
+ void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,
+ const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+                                              const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group,
+ float *grad_offset, float *grad_mask);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
\ No newline at end of file
diff --git a/external/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu b/external/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu
new file mode 100644
index 000000000..07b438e19
--- /dev/null
+++ b/external/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu
@@ -0,0 +1,419 @@
+/*!
+ * Copyright (c) 2017 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file deformable_psroi_pooling.cu
+ * \brief
+ * \author Yi Li, Guodong Zhang, Jifeng Dai
+*/
+/***************** Adapted by Charles Shang *********************/
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+#include <THC/THCDeviceUtils.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N)
+{
+ return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
+template <typename T>
+__device__ T bilinear_interp(
+ const T *data,
+ const T x,
+ const T y,
+ const int width,
+ const int height)
+{
+ int x1 = floor(x);
+ int x2 = ceil(x);
+ int y1 = floor(y);
+ int y2 = ceil(y);
+  T dist_x = static_cast<T>(x - x1);
+  T dist_y = static_cast<T>(y - y1);
+ T value11 = data[y1 * width + x1];
+ T value12 = data[y2 * width + x1];
+ T value21 = data[y1 * width + x2];
+ T value22 = data[y2 * width + x2];
+ T value = (1 - dist_x) * (1 - dist_y) * value11 +
+ (1 - dist_x) * dist_y * value12 +
+ dist_x * (1 - dist_y) * value21 +
+ dist_x * dist_y * value22;
+ return value;
+}
+
+template <typename T>
+__global__ void DeformablePSROIPoolForwardKernel(
+ const int count,
+ const T *bottom_data,
+ const T spatial_scale,
+ const int channels,
+ const int height, const int width,
+ const int pooled_height, const int pooled_width,
+ const T *bottom_rois, const T *bottom_trans,
+ const int no_trans,
+ const T trans_std,
+ const int sample_per_part,
+ const int output_dim,
+ const int group_size,
+ const int part_size,
+ const int num_classes,
+ const int channels_each_class,
+ T *top_data,
+ T *top_count)
+{
+ CUDA_KERNEL_LOOP(index, count)
+ {
+ // The output is in order (n, ctop, ph, pw)
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int ctop = (index / pooled_width / pooled_height) % output_dim;
+ int n = index / pooled_width / pooled_height / output_dim;
+
+ // [start, end) interval for spatial sampling
+ const T *offset_bottom_rois = bottom_rois + n * 5;
+ int roi_batch_ind = offset_bottom_rois[0];
+    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+ // Force too small ROIs to be 1x1
+ T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
+ T roi_height = max(roi_end_h - roi_start_h, 0.1);
+
+ // Compute w and h at bottom
+    T bin_size_h = roi_height / static_cast<T>(pooled_height);
+    T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
+    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
+
+    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
+    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
+    int class_id = ctop / channels_each_class;
+    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
+    T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
+
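+    // The learned part offsets shift this bin by a fraction of the RoI size scaled by
+    // trans_std; with no_trans this reduces to plain position-sensitive RoI pooling.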
+    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
+    wstart += trans_x * roi_width;
+    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
+    hstart += trans_y * roi_height;
+
+    T sum = 0;
+    int count = 0;
+    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
+    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
+ gw = min(max(gw, 0), group_size - 1);
+ gh = min(max(gh, 0), group_size - 1);
+
+ const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
+ for (int ih = 0; ih < sample_per_part; ih++)
+ {
+ for (int iw = 0; iw < sample_per_part; iw++)
+ {
+ T w = wstart + iw * sub_bin_size_w;
+ T h = hstart + ih * sub_bin_size_h;
+ // bilinear interpolation
+ if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
+ {
+ continue;
+ }
+ w = min(max(w, 0.), width - 1.);
+ h = min(max(h, 0.), height - 1.);
+ int c = (ctop * group_size + gh) * group_size + gw;
+ T val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
+ sum += val;
+ count++;
+ }
+ }
+    top_data[index] = count == 0 ? static_cast<T>(0) : sum / count;
+ top_count[index] = count;
+ }
+}
+
+template <typename T>
+__global__ void DeformablePSROIPoolBackwardAccKernel(
+ const int count,
+ const T *top_diff,
+ const T *top_count,
+ const int num_rois,
+ const T spatial_scale,
+ const int channels,
+ const int height, const int width,
+ const int pooled_height, const int pooled_width,
+ const int output_dim,
+ T *bottom_data_diff, T *bottom_trans_diff,
+ const T *bottom_data,
+ const T *bottom_rois,
+ const T *bottom_trans,
+ const int no_trans,
+ const T trans_std,
+ const int sample_per_part,
+ const int group_size,
+ const int part_size,
+ const int num_classes,
+ const int channels_each_class)
+{
+ CUDA_KERNEL_LOOP(index, count)
+ {
+ // The output is in order (n, ctop, ph, pw)
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int ctop = (index / pooled_width / pooled_height) % output_dim;
+ int n = index / pooled_width / pooled_height / output_dim;
+
+ // [start, end) interval for spatial sampling
+ const T *offset_bottom_rois = bottom_rois + n * 5;
+ int roi_batch_ind = offset_bottom_rois[0];
+    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+ // Force too small ROIs to be 1x1
+ T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
+ T roi_height = max(roi_end_h - roi_start_h, 0.1);
+
+ // Compute w and h at bottom
+    T bin_size_h = roi_height / static_cast<T>(pooled_height);
+    T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
+    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
+
+    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
+    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
+    int class_id = ctop / channels_each_class;
+    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
+    T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
+
+    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
+    wstart += trans_x * roi_width;
+    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
+ hstart += trans_y * roi_height;
+
+ if (top_count[index] <= 0)
+ {
+ continue;
+ }
+ T diff_val = top_diff[index] / top_count[index];
+ const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
+ T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
+    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
+    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
+ gw = min(max(gw, 0), group_size - 1);
+ gh = min(max(gh, 0), group_size - 1);
+
+ for (int ih = 0; ih < sample_per_part; ih++)
+ {
+ for (int iw = 0; iw < sample_per_part; iw++)
+ {
+ T w = wstart + iw * sub_bin_size_w;
+ T h = hstart + ih * sub_bin_size_h;
+ // bilinear interpolation
+ if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
+ {
+ continue;
+ }
+ w = min(max(w, 0.), width - 1.);
+ h = min(max(h, 0.), height - 1.);
+ int c = (ctop * group_size + gh) * group_size + gw;
+ // backward on feature
+ int x0 = floor(w);
+ int x1 = ceil(w);
+ int y0 = floor(h);
+ int y1 = ceil(h);
+ T dist_x = w - x0, dist_y = h - y0;
+ T q00 = (1 - dist_x) * (1 - dist_y);
+ T q01 = (1 - dist_x) * dist_y;
+ T q10 = dist_x * (1 - dist_y);
+ T q11 = dist_x * dist_y;
+ int bottom_index_base = c * height * width;
+ atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
+ atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
+ atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
+ atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
+
+ if (no_trans)
+ {
+ continue;
+ }
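+        // Gradient w.r.t. the learned part offsets: differentiate the bilinear sample in x
+        // and y, then scale by the RoI width/height and trans_std as in the forward pass.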
+ T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
+ T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
+ T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
+ T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
+ T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
+ diff_x *= roi_width;
+ T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
+ diff_y *= roi_height;
+
+ atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
+ atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
+ }
+ }
+ }
+}
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
+ const at::Tensor &bbox,
+ const at::Tensor &trans,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std)
+{
+ AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+ AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor");
+ AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+ const int channels_trans = no_trans ? 2 : trans.size(1);
+ const int num_bbox = bbox.size(0);
+
+ AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
+ auto pooled_height = pooled_size;
+ auto pooled_width = pooled_size;
+
+ auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
+ long out_size = num_bbox * output_dim * pooled_height * pooled_width;
+ auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
+
+ const int num_classes = no_trans ? 1 : channels_trans / 2;
+ const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ if (out.numel() == 0)
+ {
+ THCudaCheck(cudaGetLastError());
+ return std::make_tuple(out, top_count);
+ }
+
+ dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
+ dim3 block(512);
+
+ AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cuda_forward", [&] {
+    DeformablePSROIPoolForwardKernel<scalar_t><<<grid, block, 0, stream>>>(
+        out_size,
+        input.contiguous().data<scalar_t>(),
+ spatial_scale,
+ channels,
+ height, width,
+ pooled_height,
+ pooled_width,
+        bbox.contiguous().data<scalar_t>(),
+        trans.contiguous().data<scalar_t>(),
+ no_trans,
+ trans_std,
+ sample_per_part,
+ output_dim,
+ group_size,
+ part_size,
+ num_classes,
+ channels_each_class,
+        out.data<scalar_t>(),
+        top_count.data<scalar_t>());
+ });
+ THCudaCheck(cudaGetLastError());
+ return std::make_tuple(out, top_count);
+}
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
+ const at::Tensor &input,
+ const at::Tensor &bbox,
+ const at::Tensor &trans,
+ const at::Tensor &top_count,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std)
+{
+ AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor");
+ AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+ AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor");
+ AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
+ AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor");
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+ const int channels_trans = no_trans ? 2 : trans.size(1);
+ const int num_bbox = bbox.size(0);
+
+ AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
+ auto pooled_height = pooled_size;
+ auto pooled_width = pooled_size;
+ long out_size = num_bbox * output_dim * pooled_height * pooled_width;
+ const int num_classes = no_trans ? 1 : channels_trans / 2;
+ const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+ auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());
+ auto trans_grad = at::zeros_like(trans);
+
+ if (input_grad.numel() == 0)
+ {
+ THCudaCheck(cudaGetLastError());
+ return std::make_tuple(input_grad, trans_grad);
+ }
+
+ dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
+ dim3 block(512);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cuda_backward", [&] {
+    DeformablePSROIPoolBackwardAccKernel<scalar_t><<<grid, block, 0, stream>>>(
+        out_size,
+        out_grad.contiguous().data<scalar_t>(),
+        top_count.contiguous().data<scalar_t>(),
+ num_bbox,
+ spatial_scale,
+ channels,
+ height,
+ width,
+ pooled_height,
+ pooled_width,
+ output_dim,
+        input_grad.contiguous().data<scalar_t>(),
+        trans_grad.contiguous().data<scalar_t>(),
+        input.contiguous().data<scalar_t>(),
+        bbox.contiguous().data<scalar_t>(),
+        trans.contiguous().data<scalar_t>(),
+ no_trans,
+ trans_std,
+ sample_per_part,
+ group_size,
+ part_size,
+ num_classes,
+ channels_each_class);
+ });
+ THCudaCheck(cudaGetLastError());
+ return std::make_tuple(input_grad, trans_grad);
+}
\ No newline at end of file
diff --git a/external/DCNv2/src/cuda/vision.h b/external/DCNv2/src/cuda/vision.h
new file mode 100644
index 000000000..e42a2a79a
--- /dev/null
+++ b/external/DCNv2/src/cuda/vision.h
@@ -0,0 +1,60 @@
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor
+dcn_v2_cuda_forward(const at::Tensor &input,
+ const at::Tensor &weight,
+ const at::Tensor &bias,
+ const at::Tensor &offset,
+ const at::Tensor &mask,
+ const int kernel_h,
+ const int kernel_w,
+ const int stride_h,
+ const int stride_w,
+ const int pad_h,
+ const int pad_w,
+ const int dilation_h,
+ const int dilation_w,
+ const int deformable_group);
+
+std::vector<at::Tensor>
+dcn_v2_cuda_backward(const at::Tensor &input,
+ const at::Tensor &weight,
+ const at::Tensor &bias,
+ const at::Tensor &offset,
+ const at::Tensor &mask,
+ const at::Tensor &grad_output,
+ int kernel_h, int kernel_w,
+ int stride_h, int stride_w,
+ int pad_h, int pad_w,
+ int dilation_h, int dilation_w,
+ int deformable_group);
+
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
+ const at::Tensor &bbox,
+ const at::Tensor &trans,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std);
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
+ const at::Tensor &input,
+ const at::Tensor &bbox,
+ const at::Tensor &trans,
+ const at::Tensor &top_count,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std);
\ No newline at end of file
diff --git a/external/DCNv2/src/dcn_v2.h b/external/DCNv2/src/dcn_v2.h
new file mode 100644
index 000000000..23f5caf2f
--- /dev/null
+++ b/external/DCNv2/src/dcn_v2.h
@@ -0,0 +1,145 @@
+#pragma once
+
+#include "cpu/vision.h"
+
+#ifdef WITH_CUDA
+#include "cuda/vision.h"
+#endif
+
+at::Tensor
+dcn_v2_forward(const at::Tensor &input,
+ const at::Tensor &weight,
+ const at::Tensor &bias,
+ const at::Tensor &offset,
+ const at::Tensor &mask,
+ const int kernel_h,
+ const int kernel_w,
+ const int stride_h,
+ const int stride_w,
+ const int pad_h,
+ const int pad_w,
+ const int dilation_h,
+ const int dilation_w,
+ const int deformable_group)
+{
+ if (input.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return dcn_v2_cuda_forward(input, weight, bias, offset, mask,
+ kernel_h, kernel_w,
+ stride_h, stride_w,
+ pad_h, pad_w,
+ dilation_h, dilation_w,
+ deformable_group);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+dcn_v2_backward(const at::Tensor &input,
+ const at::Tensor &weight,
+ const at::Tensor &bias,
+ const at::Tensor &offset,
+ const at::Tensor &mask,
+ const at::Tensor &grad_output,
+ int kernel_h, int kernel_w,
+ int stride_h, int stride_w,
+ int pad_h, int pad_w,
+ int dilation_h, int dilation_w,
+ int deformable_group)
+{
+ if (input.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return dcn_v2_cuda_backward(input,
+ weight,
+ bias,
+ offset,
+ mask,
+ grad_output,
+ kernel_h, kernel_w,
+ stride_h, stride_w,
+ pad_h, pad_w,
+ dilation_h, dilation_w,
+ deformable_group);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_forward(const at::Tensor &input,
+ const at::Tensor &bbox,
+ const at::Tensor &trans,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std)
+{
+ if (input.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return dcn_v2_psroi_pooling_cuda_forward(input,
+ bbox,
+ trans,
+ no_trans,
+ spatial_scale,
+ output_dim,
+ group_size,
+ pooled_size,
+ part_size,
+ sample_per_part,
+ trans_std);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad,
+ const at::Tensor &input,
+ const at::Tensor &bbox,
+ const at::Tensor &trans,
+ const at::Tensor &top_count,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std)
+{
+ if (input.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return dcn_v2_psroi_pooling_cuda_backward(out_grad,
+ input,
+ bbox,
+ trans,
+ top_count,
+ no_trans,
+ spatial_scale,
+ output_dim,
+ group_size,
+ pooled_size,
+ part_size,
+ sample_per_part,
+ trans_std);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
\ No newline at end of file
diff --git a/external/DCNv2/src/vision.cpp b/external/DCNv2/src/vision.cpp
new file mode 100644
index 000000000..ff54233e0
--- /dev/null
+++ b/external/DCNv2/src/vision.cpp
@@ -0,0 +1,9 @@
+
+#include "dcn_v2.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward");
+ m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward");
+ m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward");
+ m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward");
+}
diff --git a/external/DCNv2/test.py b/external/DCNv2/test.py
new file mode 100644
index 000000000..3bd5bd223
--- /dev/null
+++ b/external/DCNv2/test.py
@@ -0,0 +1,270 @@
+#!/usr/bin/env python
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import time
+import torch
+import torch.nn as nn
+from torch.autograd import gradcheck
+
+from dcn_v2 import dcn_v2_conv, DCNv2, DCN
+from dcn_v2 import dcn_v2_pooling, DCNv2Pooling, DCNPooling
+
+deformable_groups = 1
+N, inC, inH, inW = 2, 2, 4, 4
+outC = 2
+kH, kW = 3, 3
+
+
+def conv_identify(weight, bias):
+ weight.data.zero_()
+ bias.data.zero_()
+ o, i, h, w = weight.shape
+ y = h//2
+ x = w//2
+ for p in range(i):
+ for q in range(o):
+ if p == q:
+ weight.data[q, p, y, x] = 1.0
+
+
+def check_zero_offset():
+ conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW,
+ kernel_size=(kH, kW),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=True).cuda()
+
+ conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW,
+ kernel_size=(kH, kW),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=True).cuda()
+
+ dcn_v2 = DCNv2(inC, outC, (kH, kW),
+ stride=1, padding=1, dilation=1,
+ deformable_groups=deformable_groups).cuda()
+
+ conv_offset.weight.data.zero_()
+ conv_offset.bias.data.zero_()
+ conv_mask.weight.data.zero_()
+ conv_mask.bias.data.zero_()
+ conv_identify(dcn_v2.weight, dcn_v2.bias)
+
+ input = torch.randn(N, inC, inH, inW).cuda()
+ offset = conv_offset(input)
+ mask = conv_mask(input)
+ mask = torch.sigmoid(mask)
+ output = dcn_v2(input, offset, mask)
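+    # With zero offset/mask weights the modulation mask is sigmoid(0) = 0.5 everywhere,
+    # so the identity kernel returns 0.5 * input; doubling below should recover the input.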
+ output *= 2
+ d = (input - output).abs().max()
+ if d < 1e-10:
+ print('Zero offset passed')
+ else:
+ print('Zero offset failed')
+ print(input)
+ print(output)
+
+def check_gradient_dconv():
+
+ input = torch.rand(N, inC, inH, inW).cuda() * 0.01
+ input.requires_grad = True
+
+ offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2
+ # offset.data.zero_()
+ # offset.data -= 0.5
+ offset.requires_grad = True
+
+ mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda()
+ # mask.data.zero_()
+ mask.requires_grad = True
+ mask = torch.sigmoid(mask)
+
+ weight = torch.randn(outC, inC, kH, kW).cuda()
+ weight.requires_grad = True
+
+ bias = torch.rand(outC).cuda()
+ bias.requires_grad = True
+
+ stride = 1
+ padding = 1
+ dilation = 1
+
+ print('check_gradient_dconv: ',
+ gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias,
+ stride, padding, dilation, deformable_groups),
+ eps=1e-3, atol=1e-4, rtol=1e-2))
+
+
+def check_pooling_zero_offset():
+
+ input = torch.randn(2, 16, 64, 64).cuda().zero_()
+ input[0, :, 16:26, 16:26] = 1.
+ input[1, :, 10:20, 20:30] = 2.
+ rois = torch.tensor([
+ [0, 65, 65, 103, 103],
+ [1, 81, 41, 119, 79],
+ ]).cuda().float()
+ pooling = DCNv2Pooling(spatial_scale=1.0 / 4,
+ pooled_size=7,
+ output_dim=16,
+ no_trans=True,
+ group_size=1,
+ trans_std=0.0).cuda()
+
+ out = pooling(input, rois, input.new())
+ s = ', '.join(['%f' % out[i, :, :, :].mean().item()
+ for i in range(rois.shape[0])])
+ print(s)
+
+ dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,
+ pooled_size=7,
+ output_dim=16,
+ no_trans=False,
+ group_size=1,
+ trans_std=0.0).cuda()
+ offset = torch.randn(20, 2, 7, 7).cuda().zero_()
+ dout = dpooling(input, rois, offset)
+ s = ', '.join(['%f' % dout[i, :, :, :].mean().item()
+ for i in range(rois.shape[0])])
+ print(s)
+
+
+def check_gradient_dpooling():
+ input = torch.randn(2, 3, 5, 5).cuda() * 0.01
+ N = 4
+ batch_inds = torch.randint(2, (N, 1)).cuda().float()
+ x = torch.rand((N, 1)).cuda().float() * 15
+ y = torch.rand((N, 1)).cuda().float() * 15
+ w = torch.rand((N, 1)).cuda().float() * 10
+ h = torch.rand((N, 1)).cuda().float() * 10
+ rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
+ offset = torch.randn(N, 2, 3, 3).cuda()
+ input.requires_grad = True
+ offset.requires_grad = True
+
+ spatial_scale = 1.0 / 4
+ pooled_size = 3
+ output_dim = 3
+ no_trans = 0
+ group_size = 1
+ trans_std = 0.0
+ sample_per_part = 4
+ part_size = pooled_size
+
+ print('check_gradient_dpooling:',
+ gradcheck(dcn_v2_pooling, (input, rois, offset,
+ spatial_scale,
+ pooled_size,
+ output_dim,
+ no_trans,
+ group_size,
+ part_size,
+ sample_per_part,
+ trans_std),
+ eps=1e-4))
+
+
+def example_dconv():
+ input = torch.randn(2, 64, 128, 128).cuda()
+ # wrap all things (offset and mask) in DCN
+ dcn = DCN(64, 64, kernel_size=(3, 3), stride=1,
+ padding=1, deformable_groups=2).cuda()
+ # print(dcn.weight.shape, input.shape)
+ output = dcn(input)
+    target = output.new(*output.size())
+    target.data.uniform_(-0.01, 0.01)
+    error = (target - output).mean()
+ error.backward()
+ print(output.shape)
+
+
+def example_dpooling():
+ input = torch.randn(2, 32, 64, 64).cuda()
+ batch_inds = torch.randint(2, (20, 1)).cuda().float()
+ x = torch.randint(256, (20, 1)).cuda().float()
+ y = torch.randint(256, (20, 1)).cuda().float()
+ w = torch.randint(64, (20, 1)).cuda().float()
+ h = torch.randint(64, (20, 1)).cuda().float()
+ rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
+ offset = torch.randn(20, 2, 7, 7).cuda()
+ input.requires_grad = True
+ offset.requires_grad = True
+
+ # normal roi_align
+ pooling = DCNv2Pooling(spatial_scale=1.0 / 4,
+ pooled_size=7,
+ output_dim=32,
+ no_trans=True,
+ group_size=1,
+ trans_std=0.1).cuda()
+
+ # deformable pooling
+ dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,
+ pooled_size=7,
+ output_dim=32,
+ no_trans=False,
+ group_size=1,
+ trans_std=0.1).cuda()
+
+ out = pooling(input, rois, offset)
+ dout = dpooling(input, rois, offset)
+ print(out.shape)
+ print(dout.shape)
+
+ target_out = out.new(*out.size())
+ target_out.data.uniform_(-0.01, 0.01)
+ target_dout = dout.new(*dout.size())
+ target_dout.data.uniform_(-0.01, 0.01)
+ e = (target_out - out).mean()
+ e.backward()
+ e = (target_dout - dout).mean()
+ e.backward()
+
+
+def example_mdpooling():
+ input = torch.randn(2, 32, 64, 64).cuda()
+ input.requires_grad = True
+ batch_inds = torch.randint(2, (20, 1)).cuda().float()
+ x = torch.randint(256, (20, 1)).cuda().float()
+ y = torch.randint(256, (20, 1)).cuda().float()
+ w = torch.randint(64, (20, 1)).cuda().float()
+ h = torch.randint(64, (20, 1)).cuda().float()
+ rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
+
+    # modulated deformable pooling (v2)
+ dpooling = DCNPooling(spatial_scale=1.0 / 4,
+ pooled_size=7,
+ output_dim=32,
+ no_trans=False,
+ group_size=1,
+ trans_std=0.1,
+ deform_fc_dim=1024).cuda()
+
+ dout = dpooling(input, rois)
+ target = dout.new(*dout.size())
+ target.data.uniform_(-0.1, 0.1)
+ error = (target - dout).mean()
+ error.backward()
+ print(dout.shape)
+
+
+if __name__ == '__main__':
+
+ example_dconv()
+ example_dpooling()
+ example_mdpooling()
+
+ check_pooling_zero_offset()
+ # zero offset check
+ if inC == outC:
+ check_zero_offset()
+
+ check_gradient_dpooling()
+ check_gradient_dconv()
+    # """
+    # ****** Note: the "backward is not reentrant" error from gradcheck may not be a
+    # ****** serious problem, since the max error is less than 1e-7.
+    # ****** Still looking into what triggers it.
+    # """
diff --git a/layers/mask_score.py b/layers/mask_score.py
new file mode 100644
index 000000000..ee912de42
--- /dev/null
+++ b/layers/mask_score.py
@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from data.config import cfg, mask_type
+from torch.autograd import Variable
+import torch.backends.cudnn as cudnn
+from .box_utils import match, crop
+import cv2
+from datetime import datetime
+import os
+
+from utils.functions import make_net
+
+class FastMaskIoUNet(nn.Module):
+
+ def __init__(self):
+ super(FastMaskIoUNet, self).__init__()
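+        # Each predicted mask is fed as a single-channel map (hence input_channels = 1);
+        # cfg.maskiou_net describes the conv stack in the same layer-config format as protonet.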
+ input_channels = 1
+ self.maskiou_net, _ = make_net(input_channels, cfg.maskiou_net, include_last_relu=True)
+
+ def forward(self, x, target=None):
+ cudnn.benchmark = False
+ x = self.maskiou_net(x)
+ cudnn.benchmark = True
+ # global pooling
+ maskiou_p = F.max_pool2d(x, kernel_size=x.size()[2:]).squeeze(-1).squeeze(-1)
+
+ if self.training:
+ maskiou_t = target[0]
+ label_t = target[1]
+ label_t = label_t[:, None]
+ maskiou_p = torch.gather(maskiou_p, dim=1, index=label_t).squeeze()
+ loss_i = F.smooth_l1_loss(maskiou_p, maskiou_t, reduction='mean')
+ return loss_i * cfg.maskiou_alpha
+ else:
+ return maskiou_p
\ No newline at end of file
diff --git a/layers/modules/multibox_loss.py b/layers/modules/multibox_loss.py
index a83ffade6..d5b7148a6 100644
--- a/layers/modules/multibox_loss.py
+++ b/layers/modules/multibox_loss.py
@@ -156,8 +156,13 @@ def forward(self, predictions, targets, masks, num_crowds):
else:
losses['M'] = self.direct_mask_loss(pos_idx, idx_t, loc_data, mask_data, priors, masks)
elif cfg.mask_type == mask_type.lincomb:
- losses.update(self.lincomb_mask_loss(pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data))
-
+ ret = self.lincomb_mask_loss(pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels)
+ if cfg.use_maskiou:
+ loss, maskiou_targets = ret
+ else:
+ loss = ret
+ losses.update(loss)
+
if cfg.mask_proto_loss is not None:
if cfg.mask_proto_loss == 'l1':
losses['P'] = torch.mean(torch.abs(proto_data)) / self.l1_expected_area * self.l1_alpha
@@ -201,7 +206,10 @@ def forward(self, predictions, targets, masks, num_crowds):
# - D: Coefficient Diversity Loss
# - E: Class Existence Loss
# - S: Semantic Segmentation Loss
- return losses
+ if cfg.use_maskiou:
+ return losses, maskiou_targets
+ else:
+ return losses
def class_existence_loss(self, class_data, class_existence_t):
return cfg.class_existence_alpha * F.binary_cross_entropy_with_logits(class_data, class_existence_t, reduction='sum')
@@ -487,7 +495,7 @@ def coeff_diversity_loss(self, coeffs, instance_t):
return cfg.mask_proto_coeff_diversity_alpha * loss.sum() / num_pos
- def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, interpolation_mode='bilinear'):
+ def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
mask_h = proto_data.size(1)
mask_w = proto_data.size(2)
@@ -500,6 +508,10 @@ def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data,
loss_m = 0
loss_d = 0 # Coefficient diversity loss
+ maskiou_t_list = []
+ maskiou_net_input_list = []
+ label_t_list = []
+
for idx in range(mask_data.size(0)):
with torch.no_grad():
downsampled_masks = F.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w),
@@ -570,7 +582,8 @@ def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data,
mask_scores = mask_scores[select, :]
num_pos = proto_coef.size(0)
- mask_t = downsampled_masks[:, :, pos_idx_t]
+ mask_t = downsampled_masks[:, :, pos_idx_t]
+ label_t = labels[idx][pos_idx_t]
# Size: [mask_h, mask_w, num_pos]
pred_masks = proto_masks @ proto_coef.t()
@@ -611,10 +624,54 @@ def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data,
pre_loss *= old_num_pos / num_pos
loss_m += torch.sum(pre_loss)
+
+ if cfg.use_maskiou:
+ if cfg.remove_small_gt_mask > 0:
+ gt_mask_area = torch.sum(mask_t, dim=(0, 1))
+ select = gt_mask_area > cfg.remove_small_gt_mask
+
+ if torch.sum(select) < 1:
+ continue
+
+ pos_gt_box_t = pos_gt_box_t[select, :]
+ pred_masks = pred_masks[:, :, select]
+ mask_t = mask_t[:, :, select]
+ label_t = label_t[select]
+
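+                    # MaskIoU training targets: the soft predicted masks are the net input, and the
+                    # IoU between the binarized predictions and the GT masks is the regression target.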
+ maskiou_net_input = pred_masks.permute(2, 0, 1).contiguous().unsqueeze(1)
+ pred_masks = pred_masks.gt(0.5).float()
+ maskiou_t = self._mask_iou(pred_masks, mask_t)
+
+ maskiou_net_input_list.append(maskiou_net_input)
+ maskiou_t_list.append(maskiou_t)
+ label_t_list.append(label_t)
losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}
if cfg.mask_proto_coeff_diversity_loss:
losses['D'] = loss_d
+ if cfg.use_maskiou:
+ maskiou_t = torch.cat(maskiou_t_list)
+ label_t = torch.cat(label_t_list)
+ maskiou_net_input = torch.cat(maskiou_net_input_list)
+
+ num_samples = maskiou_t.size(0)
+ if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
+ perm = torch.randperm(num_samples)
+                select = perm[:cfg.maskious_to_train]
+ maskiou_t = maskiou_t[select]
+ label_t = label_t[select]
+ maskiou_net_input = maskiou_net_input[select]
+
+ return losses, [maskiou_net_input, maskiou_t, label_t]
+
return losses
+
+ def _mask_iou(self, mask1, mask2):
+ intersection = torch.sum(mask1*mask2, dim=(0, 1))
+ area1 = torch.sum(mask1, dim=(0, 1))
+ area2 = torch.sum(mask2, dim=(0, 1))
+ union = (area1 + area2) - intersection
+ ret = intersection / union
+ return ret
\ No newline at end of file
diff --git a/layers/output_utils.py b/layers/output_utils.py
index a6342e266..1f7b2550c 100644
--- a/layers/output_utils.py
+++ b/layers/output_utils.py
@@ -13,7 +13,7 @@
from .box_utils import crop, sanitize_coordinates
def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
- visualize_lincomb=False, crop_masks=True, score_threshold=0):
+ visualize_lincomb=False, crop_masks=True, score_threshold=0, maskiou_net=None):
"""
Postprocesses the output of Yolact on testing mode into a format that makes sense,
accounting for all the possible configuration settings.
@@ -74,6 +74,17 @@ def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
# Permute into the correct output shape [num_dets, proto_h, proto_w]
masks = masks.permute(2, 0, 1).contiguous()
+ if cfg.use_maskiou:
+ with timer.env('maskiou_net'):
+ with torch.no_grad():
+ maskiou_p = maskiou_net(masks.unsqueeze(1))
+ maskiou_p = torch.gather(maskiou_p, dim=1, index=classes.unsqueeze(1)).squeeze(1)
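+                # Rescore each detection with its predicted mask IoU for the detected class;
+                # when rescore_bbox is off, keep the original box scores alongside the rescored ones.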
+ if cfg.rescore_mask:
+ if cfg.rescore_bbox:
+ scores = scores * maskiou_p
+ else:
+ scores = [scores, scores * maskiou_p]
+
# Scale masks up to the full image
masks = F.interpolate(masks.unsqueeze(0), (h, w), mode=interpolation_mode, align_corners=False).squeeze(0)
diff --git a/train.py b/train.py
index 87d0ee5a4..b664317e4 100644
--- a/train.py
+++ b/train.py
@@ -115,7 +115,7 @@ def replace(name):
print('Per-GPU batch size is less than the recommended limit for batch norm. Disabling batch norm.')
cfg.freeze_bn = True
-loss_types = ['B', 'C', 'M', 'P', 'D', 'E', 'S']
+loss_types = ['B', 'C', 'M', 'P', 'D', 'E', 'S', 'I']
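+# 'I' is the MaskIoU (mask re-scoring) loss produced by FastMaskIoUNet when cfg.use_maskiou is set.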
if torch.cuda.is_available():
if args.cuda:
@@ -138,10 +138,20 @@ def __init__(self, net:Yolact, criterion:MultiBoxLoss):
self.net = net
self.criterion = criterion
+ if cfg.use_maskiou:
+ self.maskiou_net = net.get_maskiou_net()
def forward(self, images, targets, masks, num_crowds):
preds = self.net(images)
- return self.criterion(preds, targets, masks, num_crowds)
+
+ if cfg.use_maskiou:
+ losses, maskiou_targets = self.criterion(preds, targets, masks, num_crowds)
+ maskiou_net_input, maskiou_t, label_t = maskiou_targets
+ loss_i = self.maskiou_net(maskiou_net_input, target=[maskiou_t, label_t])
+ losses['I'] = loss_i
+ else:
+ losses = self.criterion(preds, targets, masks, num_crowds)
+ return losses
class CustomDataParallel(nn.DataParallel):
"""
diff --git a/utils/functions.py b/utils/functions.py
index 809060db8..3b7a4e45a 100644
--- a/utils/functions.py
+++ b/utils/functions.py
@@ -1,9 +1,10 @@
-
import torch
+import torch.nn as nn
import os
import math
from collections import deque
from pathlib import Path
+from layers.interpolate import InterpolateModule
class MovingAverage():
""" Keeps an average window of the specified number of items. """
@@ -158,3 +159,55 @@ def get_latest(save_folder, config):
max_name = path_name
return max_name
+
+def make_net(in_channels, conf, include_last_relu=True):
+ """
+ A helper function to take a config setting and turn it into a network.
+ Used by protonet and extrahead. Returns (network, out_channels)
+ """
+ def make_layer(layer_cfg):
+ nonlocal in_channels
+
+ # Possible patterns:
+ # ( 256, 3, {}) -> conv
+ # ( 256,-2, {}) -> deconv
+ # (None,-2, {}) -> bilinear interpolate
+ # ('cat',[],{}) -> concat the subnetworks in the list
+ #
+ # You know it would have probably been simpler just to adopt a 'c' 'd' 'u' naming scheme.
+ # Whatever, it's too late now.
+ if isinstance(layer_cfg[0], str):
+ layer_name = layer_cfg[0]
+
+ if layer_name == 'cat':
+ nets = [make_net(in_channels, x) for x in layer_cfg[1]]
+ layer = Concat([net[0] for net in nets], layer_cfg[2])
+ num_channels = sum([net[1] for net in nets])
+ else:
+ num_channels = layer_cfg[0]
+ kernel_size = layer_cfg[1]
+
+ if kernel_size > 0:
+ layer = nn.Conv2d(in_channels, num_channels, kernel_size, **layer_cfg[2])
+ else:
+ if num_channels is None:
+ layer = InterpolateModule(scale_factor=-kernel_size, mode='bilinear', align_corners=False, **layer_cfg[2])
+ else:
+ layer = nn.ConvTranspose2d(in_channels, num_channels, -kernel_size, **layer_cfg[2])
+
+ in_channels = num_channels if num_channels is not None else in_channels
+
+ # Don't return a ReLU layer if we're doing an upsample. This probably doesn't affect anything
+ # output-wise, but there's no need to go through a ReLU here.
+ # Commented out for backwards compatibility with previous models
+ # if num_channels is None:
+ # return [layer]
+ # else:
+ return [layer, nn.ReLU(inplace=True)]
+
+ # Use sum to concat together all the component layer lists
+ net = sum([make_layer(x) for x in conf], [])
+ if not include_last_relu:
+ net = net[:-1]
+
+ return nn.Sequential(*(net)), in_channels
\ No newline at end of file
diff --git a/yolact.py b/yolact.py
index a49d9a2d2..fb0d66c03 100644
--- a/yolact.py
+++ b/yolact.py
@@ -10,12 +10,13 @@
from data.config import cfg, mask_type
from layers import Detect
+from layers.mask_score import FastMaskIoUNet
from layers.interpolate import InterpolateModule
from backbone import construct_backbone
import torch.backends.cudnn as cudnn
from utils import timer
-from utils.functions import MovingAverage
+from utils.functions import MovingAverage, make_net
# This is required for Pytorch 1.0.1 on Windows to initialize Cuda on some driver versions.
# See the bug report here: https://github.com/pytorch/pytorch/issues/17108
@@ -42,60 +43,6 @@ def forward(self, x):
# Concat each along the channel dimension
return torch.cat([net(x) for net in self.nets], dim=1, **self.extra_params)
-
-
-def make_net(in_channels, conf, include_last_relu=True):
- """
- A helper function to take a config setting and turn it into a network.
- Used by protonet and extrahead. Returns (network, out_channels)
- """
- def make_layer(layer_cfg):
- nonlocal in_channels
-
- # Possible patterns:
- # ( 256, 3, {}) -> conv
- # ( 256,-2, {}) -> deconv
- # (None,-2, {}) -> bilinear interpolate
- # ('cat',[],{}) -> concat the subnetworks in the list
- #
- # You know it would have probably been simpler just to adopt a 'c' 'd' 'u' naming scheme.
- # Whatever, it's too late now.
- if isinstance(layer_cfg[0], str):
- layer_name = layer_cfg[0]
-
- if layer_name == 'cat':
- nets = [make_net(in_channels, x) for x in layer_cfg[1]]
- layer = Concat([net[0] for net in nets], layer_cfg[2])
- num_channels = sum([net[1] for net in nets])
- else:
- num_channels = layer_cfg[0]
- kernel_size = layer_cfg[1]
-
- if kernel_size > 0:
- layer = nn.Conv2d(in_channels, num_channels, kernel_size, **layer_cfg[2])
- else:
- if num_channels is None:
- layer = InterpolateModule(scale_factor=-kernel_size, mode='bilinear', align_corners=False, **layer_cfg[2])
- else:
- layer = nn.ConvTranspose2d(in_channels, num_channels, -kernel_size, **layer_cfg[2])
-
- in_channels = num_channels if num_channels is not None else in_channels
-
- # Don't return a ReLU layer if we're doing an upsample. This probably doesn't affect anything
- # output-wise, but there's no need to go through a ReLU here.
- # Commented out for backwards compatibility with previous models
- # if num_channels is None:
- # return [layer]
- # else:
- return [layer, nn.ReLU(inplace=True)]
-
- # Use sum to concat together all the component layer lists
- net = sum([make_layer(x) for x in conf], [])
- if not include_last_relu:
- net = net[:-1]
-
- return nn.Sequential(*(net)), in_channels
-
prior_cache = defaultdict(lambda: None)
class PredictionModule(nn.Module):
@@ -129,7 +76,7 @@ def __init__(self, in_channels, out_channels=1024, aspect_ratios=[[1]], scales=[
self.num_classes = cfg.num_classes
self.mask_dim = cfg.mask_dim # Defined by Yolact
- self.num_priors = sum(len(x) for x in aspect_ratios)
+ self.num_priors = sum(len(x)*len(scales) for x in aspect_ratios)
self.parent = [parent] # Don't include this in the state dict
self.index = index
self.num_heads = cfg.num_heads # Defined by Yolact
@@ -264,7 +211,7 @@ def forward(self, x):
preds['inst'] = inst
return preds
-
+
def make_priors(self, conv_h, conv_w, device):
""" Note that priors are [x,y,width,height] where (x,y) is the center of the box. """
global prior_cache
@@ -280,24 +227,25 @@ def make_priors(self, conv_h, conv_w, device):
x = (i + 0.5) / conv_w
y = (j + 0.5) / conv_h
- for scale, ars in zip(self.scales, self.aspect_ratios):
- for ar in ars:
- if not cfg.backbone.preapply_sqrt:
- ar = sqrt(ar)
-
- if cfg.backbone.use_pixel_scales:
- w = scale * ar / cfg._tmp_img_w # These are populated by
- h = scale / ar / cfg._tmp_img_h # Yolact.forward
- else:
- w = scale * ar / conv_w
- h = scale / ar / conv_h
-
- # This is for backward compatability with a bug where I made everything square by accident
- if cfg.backbone.use_square_anchors:
- h = w
-
- prior_data += [x, y, w, h]
-
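+                # Anchors now iterate aspect-ratio groups x scales per location (matching the
+                # updated num_priors), and pixel scales are normalized by cfg.max_size.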
+ for ars in self.aspect_ratios:
+ for scale in self.scales:
+ for ar in ars:
+ if not cfg.backbone.preapply_sqrt:
+ ar = sqrt(ar)
+
+ if cfg.backbone.use_pixel_scales:
+ w = scale * ar / cfg.max_size
+ h = scale / ar / cfg.max_size
+ else:
+ w = scale * ar / conv_w
+ h = scale / ar / conv_h
+
+                            # This is for backward compatibility with a bug where I made everything square by accident
+ if cfg.backbone.use_square_anchors:
+ h = w
+
+ prior_data += [x, y, w, h]
+
self.priors = torch.Tensor(prior_data, device=device).view(-1, 4).detach()
self.priors.requires_grad = False
self.last_img_size = (cfg._tmp_img_w, cfg._tmp_img_h)
@@ -470,6 +418,9 @@ def __init__(self):
self.selected_layers = cfg.backbone.selected_layers
src_channels = self.backbone.channels
+ if cfg.use_maskiou:
+ self.maskiou_net = FastMaskIoUNet()
+
if cfg.fpn is not None:
# Some hacky rewiring to accomodate the FPN
self.fpn = FPN([src_channels[i] for i in self.selected_layers])
@@ -523,7 +474,6 @@ def load_weights(self, path):
if key.startswith('fpn.downsample_layers.'):
if cfg.fpn is not None and int(key.split('.')[2]) >= cfg.fpn.num_downsample:
del state_dict[key]
-
self.load_state_dict(state_dict)
def init_weights(self, backbone_path):
@@ -575,6 +525,11 @@ def all_in(x, y):
else:
module.bias.data.zero_()
+ def get_maskiou_net(self):
+ if cfg.use_maskiou:
+ return self.maskiou_net
+ return None
+
def train(self, mode=True):
super().train(mode)
@@ -589,7 +544,7 @@ def freeze_bn(self, enable=False):
module.weight.requires_grad = enable
module.bias.requires_grad = enable
-
+
def forward(self, x):
""" The input should be of size [batch_size, 3, img_h, img_w] """
_, _, img_h, img_w = x.size()