Skip to content

Commit

Permalink
Release YOLACT++
Browse files Browse the repository at this point in the history
  • Loading branch information
chongzhou96 committed Dec 6, 2019
1 parent 02bde37 commit ef56a8d
Show file tree
Hide file tree
Showing 26 changed files with 2,706 additions and 118 deletions.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@

A simple, fully convolutional model for real-time instance segmentation. This is the code for [our paper](https://arxiv.org/abs/1904.02689).

#### YOLACT++ implementation and models released!
The YOLACT++ ResNet50 model runs at 33.5 fps on a Titan Xp and achieves 34.1 mAP on COCO's `test-dev`.

The related paper will be posted on arXiv soon.

In order to use YOLACT++, make sure you compile the DCNv2 code. (See [Installation](https://github.com/dbolya/yolact#installation))

#### ICCV update (v1.1) released! Check out the ICCV trailer here:
[![IMAGE ALT TEXT HERE](https://img.youtube.com/vi/0pMfmo8qfpQ/0.jpg)](https://www.youtube.com/watch?v=0pMfmo8qfpQ)

Expand Down Expand Up @@ -37,6 +44,11 @@ Some examples from our base model (33.5 fps on a Titan Xp and 29.8 mAP on COCO's
git clone https://github.com/dbolya/yolact.git
cd yolact
```
- Compile deformable convolutional layers (from [DCNv2](https://github.com/CharlesShang/DCNv2/tree/pytorch_1.0))
```Shell
cd external/DCNv2
./make.sh
```
- If you'd like to train YOLACT, download the COCO dataset and the 2014/2017 annotations. Note that this script will take a while and dump 21gb of files into `./data/coco`.
```Shell
sh data/scripts/COCO.sh
Expand All @@ -57,6 +69,13 @@ As of April 5th, 2019 here are our latest models along with their FPS on a Titan
| 550 | Resnet101-FPN | 33.0 | 29.8 | [yolact_base_54_800000.pth](https://drive.google.com/file/d/1UYy3dMapbH1BnmtZU4WH1zbYgOzzHHf_/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EYRWxBEoKU9DiblrWx2M89MBGFkVVB_drlRd_v5sdT3Hgg)
| 700 | Resnet101-FPN | 23.6 | 31.2 | [yolact_im700_54_800000.pth](https://drive.google.com/file/d/1lE4Lz5p25teiXV-6HdTiOJSnS7u7GBzg/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/Eagg5RSc5hFEhp7sPtvLNyoBjhlf2feog7t8OQzHKKphjw)

YOLACT++ models (released on Dec. 6th, 2019):

| Image Size | Backbone | FPS | mAP | Weights | |
|:----------:|:-------------:|:----:|:----:|----------------------------------------------------------------------------------------------------------------------|--------|
| 550 | Resnet50-FPN | 33.5 | 34.1 | [yolact_plus_resnet50_54_800000.pth](https://drive.google.com/file/d/1ZPu1YR2UzGHQD0o1rEqy-j5bmEm3lbyP/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EcJAtMiEFlhAnVsDf00yWRIBUC4m8iE9NEEiV05XwtEoGw) |
| 550 | Resnet101-FPN | 27.3 | 34.6 | [yolact_plus_base_54_800000.pth](https://drive.google.com/file/d/15id0Qq5eqRbkD-N3ZjDZXdCvRyIaHpFB/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EVQ62sF0SrJPrl_68onyHF8BpG7c05A8PavV4a849sZgEA)

To evaluate the model, put the corresponding weights file in the `./weights` directory and run one of the following commands.
## Quantitative Results on COCO
```Shell
Expand Down
34 changes: 22 additions & 12 deletions backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,25 @@

from collections import OrderedDict

from dcn_v2 import DCN

class Bottleneck(nn.Module):
""" Adapted from torchvision.models.resnet """
expansion = 4

def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=nn.BatchNorm2d, dilation=1):
def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=nn.BatchNorm2d, dilation=1, use_dcn=False):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, dilation=dilation)
self.bn1 = norm_layer(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=dilation, bias=False, dilation=dilation)
if use_dcn:
self.conv2 = DCN(planes, planes, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, deformable_groups=1)
self.conv2.bias.data.zero_()
self.conv2.conv_offset_mask.weight.data.zero_()
self.conv2.conv_offset_mask.bias.data.zero_()
else:
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=dilation, bias=False, dilation=dilation)
self.bn2 = norm_layer(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False, dilation=dilation)
self.bn3 = norm_layer(planes * 4)
Expand Down Expand Up @@ -47,7 +56,7 @@ def forward(self, x):
class ResNetBackbone(nn.Module):
""" Adapted from torchvision.models.resnet """

def __init__(self, layers, atrous_layers=[], block=Bottleneck, norm_layer=nn.BatchNorm2d):
def __init__(self, layers, dcn_layers=[0, 0, 0, 0], dcn_interval=1, atrous_layers=[], block=Bottleneck, norm_layer=nn.BatchNorm2d):
super().__init__()

# These will be populated by _make_layer
Expand All @@ -66,10 +75,10 @@ def __init__(self, layers, atrous_layers=[], block=Bottleneck, norm_layer=nn.Bat
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

self._make_layer(block, 64, layers[0])
self._make_layer(block, 128, layers[1], stride=2)
self._make_layer(block, 256, layers[2], stride=2)
self._make_layer(block, 512, layers[3], stride=2)
self._make_layer(block, 64, layers[0], dcn_layers=dcn_layers[0], dcn_interval=dcn_interval)
self._make_layer(block, 128, layers[1], stride=2, dcn_layers=dcn_layers[1], dcn_interval=dcn_interval)
self._make_layer(block, 256, layers[2], stride=2, dcn_layers=dcn_layers[2], dcn_interval=dcn_interval)
self._make_layer(block, 512, layers[3], stride=2, dcn_layers=dcn_layers[3], dcn_interval=dcn_interval)

# This contains every module that should be initialized by loading in pretrained weights.
# Any extra layers added onto this that won't be initialized by init_backbone will not be
Expand All @@ -78,7 +87,7 @@ def __init__(self, layers, atrous_layers=[], block=Bottleneck, norm_layer=nn.Bat
self.backbone_modules = [m for m in self.modules() if isinstance(m, nn.Conv2d)]


def _make_layer(self, block, planes, blocks, stride=1):
def _make_layer(self, block, planes, blocks, stride=1, dcn_layers=0, dcn_interval=1):
""" Here one layer means a string of n Bottleneck blocks. """
downsample = None

Expand All @@ -97,11 +106,12 @@ def _make_layer(self, block, planes, blocks, stride=1):
)

layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.norm_layer, self.dilation))
use_dcn = (dcn_layers >= blocks)
layers.append(block(self.inplanes, planes, stride, downsample, self.norm_layer, self.dilation, use_dcn=use_dcn))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, norm_layer=self.norm_layer))

use_dcn = ((i+dcn_layers) >= blocks) and (i % dcn_interval == 0)
layers.append(block(self.inplanes, planes, norm_layer=self.norm_layer, use_dcn=use_dcn))
layer = nn.Sequential(*layers)

self.channels.append(planes * block.expansion)
Expand Down
61 changes: 61 additions & 0 deletions data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@ def print(self):
'pred_aspect_ratios': [ [[0.66685089, 1.7073535, 0.87508774, 1.16524493, 0.49059086]] ] * 6,
})

# ResNet-101 backbone variant that uses deformable convolutions (DCN).
# Everything not overridden here is inherited from resnet101_backbone.
#
# 'args' holds three positional values. They line up with the parameters of
# ResNetBackbone.__init__(layers, dcn_layers, dcn_interval) in backbone.py
# (presumably 'args' is unpacked into that constructor -- confirm at the call site):
#   [3, 4, 23, 3] -> layers:       number of Bottleneck blocks in each of the 4 stages
#   [0, 4, 23, 3] -> dcn_layers:   per-stage DCN count; 0 for the first stage,
#                                  equal to the block count for the other three
#   3             -> dcn_interval: for blocks after the first in a stage, DCN is
#                                  only used where (block index % 3 == 0)
resnet101_dcn_inter3_backbone = resnet101_backbone.copy({
'name': 'ResNet101_DCN_Interval3',
'args': ([3, 4, 23, 3], [0, 4, 23, 3], 3),
})

resnet50_backbone = resnet101_backbone.copy({
'name': 'ResNet50',
'path': 'resnet50-19c8e357.pth',
Expand All @@ -255,6 +260,11 @@ def print(self):
'transform': resnet_transform,
})

# ResNet-50 backbone variant that uses deformable convolutions (DCN).
# Everything not overridden here is inherited from resnet50_backbone.
#
# 'args' holds two positional values. They line up with the first two parameters
# of ResNetBackbone.__init__(layers, dcn_layers, ...) in backbone.py
# (presumably 'args' is unpacked into that constructor -- confirm at the call site):
#   [3, 4, 6, 3] -> layers:     number of Bottleneck blocks in each of the 4 stages
#   [0, 4, 6, 3] -> dcn_layers: per-stage DCN count; 0 for the first stage,
#                               equal to the block count for the other three
# No third value is given, so dcn_interval would fall back to that constructor's
# default of 1.
resnet50_dcnv2_backbone = resnet50_backbone.copy({
'name': 'ResNet50_DCNv2',
'args': ([3, 4, 6, 3], [0, 4, 6, 3]),
})

darknet53_backbone = backbone_base.copy({
'name': 'DarkNet53',
'path': 'darknet53.pth',
Expand Down Expand Up @@ -614,6 +624,19 @@ def print(self):

'backbone': None,
'name': 'base_config',

# Fast Mask Re-scoring Network
# Inspired by Mask Scoring R-CNN (https://arxiv.org/abs/1903.00241)
# Do not crop out the mask with bbox but slide a convnet on the image-size mask,
# then use global pooling to get the final mask score
'use_maskiou': False,
'maskiou_net': [],
'remove_small_gt_mask': -1,

'maskiou_alpha': 1.0,
'rescore_mask': False,
'rescore_bbox': False,
'maskious_to_train': -1,
})


Expand Down Expand Up @@ -736,6 +759,44 @@ def print(self):
})
})

# ----------------------- YOLACT++ CONFIGS ----------------------- #

# YOLACT++ base configuration.
# Starts from yolact_base_config, swaps in the ResNet-101 DCN backbone and
# switches on the mask re-scoring ("maskiou") options.
yolact_plus_base_config = yolact_base_config.copy({
'name': 'yolact_plus_base',

# Backbone: ResNet101_DCN_Interval3 with its prediction settings overridden below.
'backbone': resnet101_dcn_inter3_backbone.copy({
# list(range(1, 4)) == [1, 2, 3]
'selected_layers': list(range(1, 4)),

# The same three aspect ratios (1, 1/2, 2), repeated for 5 entries.
'pred_aspect_ratios': [ [[1, 1/2, 2]] ]*5,
# One entry per base size in [24, 48, 96, 192, 384]; each entry holds three
# scales: base * 2**0, base * 2**(1/3) and base * 2**(2/3).
'pred_scales': [[i * 2 ** (j / 3.0) for j in range(3)] for i in [24, 48, 96, 192, 384]],
'use_pixel_scales': True,
'preapply_sqrt': False,
'use_square_anchors': False,
}),

# Enable the mask re-scoring network (these keys default to off/empty in the base config).
'use_maskiou': True,
# One tuple per layer of the re-scoring network.
# NOTE(review): tuples look like (out_channels, kernel_size, conv kwargs), and the
# final 80 presumably matches the number of classes -- confirm against the code
# that builds the network from this list.
'maskiou_net': [(8, 3, {'stride': 2}), (16, 3, {'stride': 2}), (32, 3, {'stride': 2}), (64, 3, {'stride': 2}), (128, 3, {'stride': 2}), (80, 1, {})],
# NOTE(review): presumably the weight of the re-scoring loss term -- confirm where it is read.
'maskiou_alpha': 25,
# Only the mask re-scoring flag is turned on; the bbox one is left off.
'rescore_bbox': False,
'rescore_mask': True,

# 5*5 == 25. NOTE(review): presumably a minimum ground-truth mask area for
# re-scoring training -- confirm units and comparison where it is read.
'remove_small_gt_mask': 5*5,
})

# YOLACT++ configuration with a ResNet-50 backbone.
# Copies yolact_plus_base_config and overrides only the name and the backbone.
yolact_plus_resnet50_config = yolact_plus_base_config.copy({
'name': 'yolact_plus_resnet50',

# Backbone: ResNet50_DCNv2 with the same prediction overrides that
# yolact_plus_base_config applies to its backbone.
'backbone': resnet50_dcnv2_backbone.copy({
# list(range(1, 4)) == [1, 2, 3]
'selected_layers': list(range(1, 4)),

# The same three aspect ratios (1, 1/2, 2), repeated for 5 entries.
'pred_aspect_ratios': [ [[1, 1/2, 2]] ]*5,
# One entry per base size in [24, 48, 96, 192, 384]; each entry holds three
# scales: base * 2**0, base * 2**(1/3) and base * 2**(2/3).
'pred_scales': [[i * 2 ** (j / 3.0) for j in range(3)] for i in [24, 48, 96, 192, 384]],
'use_pixel_scales': True,
'preapply_sqrt': False,
'use_square_anchors': False,
}),
})


# Default config
cfg = yolact_base_config.copy()
Expand Down
76 changes: 56 additions & 20 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def parse_args(argv=None):
coco_cats_inv = {}
color_cache = defaultdict(lambda: {})

def prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, mask_alpha=0.45, fps_str=''):
def prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, mask_alpha=0.45, fps_str='', maskiou_net=None):
"""
Note: If undo_transform=False then im_h and im_w are allowed to be None.
"""
Expand All @@ -146,14 +146,34 @@ def prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, ma
with timer.env('Postprocess'):
t = postprocess(dets_out, w, h, visualize_lincomb = args.display_lincomb,
crop_masks = args.crop,
score_threshold = args.score_threshold)
score_threshold = args.score_threshold,
maskiou_net = maskiou_net)
torch.cuda.synchronize()

# FIXME reduce copy
with timer.env('Copy'):
if cfg.eval_mask_branch:
# Masks are drawn on the GPU, so don't copy
masks = t[3][:args.top_k]
classes, scores, boxes = [x[:args.top_k].cpu().numpy() for x in t[:3]]
masks = t[3]
classes, scores, boxes = [x for x in t[:3]]
if isinstance(scores, list):
box_scores = scores[0].cpu().numpy()
mask_scores = scores[1].cpu().numpy()
# Re-rank predictions by mask scores
_scores = mask_scores * box_scores
idx = np.argsort(-_scores)
scores = box_scores[idx]
classes = classes.cpu().numpy()[idx]
boxes = boxes.cpu().numpy()[idx]
masks = masks[idx]
else:
scores = scores.cpu().numpy()
classes = classes.cpu().numpy()
boxes = boxes.cpu().numpy()
scores = scores[:args.top_k]
classes = classes[:args.top_k]
boxes = boxes[:args.top_k]
masks = masks[:args.top_k]

num_dets_to_consider = min(args.top_k, classes.shape[0])
for j in range(num_dets_to_consider):
Expand Down Expand Up @@ -257,12 +277,20 @@ def get_color(j, on_gpu=None):

return img_numpy

def prep_benchmark(dets_out, h, w):
def prep_benchmark(dets_out, h, w, maskiou_net=None):
with timer.env('Postprocess'):
t = postprocess(dets_out, w, h, crop_masks=args.crop, score_threshold=args.score_threshold)
t = postprocess(dets_out, w, h, crop_masks=args.crop, score_threshold=args.score_threshold, maskiou_net=maskiou_net)

with timer.env('Copy'):
classes, scores, boxes, masks = [x[:args.top_k].cpu().numpy() for x in t]
classes, scores, boxes, masks = [x[:args.top_k] for x in t]
if isinstance(scores, list):
box_scores = scores[0].cpu().numpy()
mask_scores = scores[1].cpu().numpy()
else:
scores = scores.cpu().numpy()
classes = classes.cpu().numpy()
boxes = boxes.cpu().numpy()
masks = masks.cpu().numpy()

with timer.env('Sync'):
# Just in case
Expand Down Expand Up @@ -371,7 +399,7 @@ def _bbox_iou(bbox1, bbox2, iscrowd=False):
ret = jaccard(bbox1, bbox2, iscrowd)
return ret.cpu()

def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, detections:Detections=None):
def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, detections:Detections=None, maskiou_net=None):
""" Returns a list of APs for this image, with each element being for a class """
if not args.output_coco_json:
with timer.env('Prepare gt'):
Expand All @@ -388,13 +416,19 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de
crowd_classes, gt_classes = split(gt_classes)

with timer.env('Postprocess'):
classes, scores, boxes, masks = postprocess(dets, w, h, crop_masks=args.crop, score_threshold=args.score_threshold)
classes, scores, boxes, masks = postprocess(dets, w, h, crop_masks=args.crop, score_threshold=args.score_threshold, maskiou_net=maskiou_net)

if classes.size(0) == 0:
return

classes = list(classes.cpu().numpy().astype(int))
scores = list(scores.cpu().numpy().astype(float))
if isinstance(scores, list):
box_scores = list(scores[0].cpu().numpy().astype(float))
mask_scores = list(scores[1].cpu().numpy().astype(float))
else:
scores = list(scores.cpu().numpy().astype(float))
box_scores = scores
mask_scores = scores
masks = masks.view(-1, h*w).cuda()
boxes = boxes.cuda()

Expand All @@ -406,8 +440,8 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de
for i in range(masks.shape[0]):
# Make sure that the bounding box actually makes sense and a mask was produced
if (boxes[i, 3] - boxes[i, 1]) * (boxes[i, 2] - boxes[i, 0]) > 0:
detections.add_bbox(image_id, classes[i], boxes[i,:], scores[i])
detections.add_mask(image_id, classes[i], masks[i,:,:], scores[i])
detections.add_bbox(image_id, classes[i], boxes[i,:], box_scores[i])
detections.add_mask(image_id, classes[i], masks[i,:,:], mask_scores[i])
return

with timer.env('Eval Setup'):
Expand All @@ -425,8 +459,8 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de
crowd_bbox_iou_cache = None

iou_types = [
('box', lambda i,j: bbox_iou_cache[i, j].item(), lambda i,j: crowd_bbox_iou_cache[i,j].item()),
('mask', lambda i,j: mask_iou_cache[i, j].item(), lambda i,j: crowd_mask_iou_cache[i,j].item())
('box', lambda i,j: bbox_iou_cache[i, j].item(), lambda i,j: crowd_bbox_iou_cache[i,j].item(), lambda i: box_scores[i]),
('mask', lambda i,j: mask_iou_cache[i, j].item(), lambda i,j: crowd_mask_iou_cache[i,j].item(), lambda i: mask_scores[i])
]

timer.start('Main loop')
Expand All @@ -437,7 +471,7 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de
for iouIdx in range(len(iou_thresholds)):
iou_threshold = iou_thresholds[iouIdx]

for iou_type, iou_func, crowd_func in iou_types:
for iou_type, iou_func, crowd_func, score_func in iou_types:
gt_used = [False] * len(gt_classes)

ap_obj = ap_data[iou_type][iouIdx][_class]
Expand All @@ -461,7 +495,7 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de

if max_match_idx >= 0:
gt_used[max_match_idx] = True
ap_obj.push(scores[i], True)
ap_obj.push(score_func(i), True)
else:
# If the detection matches a crowd, we can just ignore it
matched_crowd = False
Expand All @@ -481,7 +515,7 @@ def prep_metrics(ap_data, dets, img, gt, gt_masks, h, w, num_crowd, image_id, de
# same result as COCOEval. There aren't even that many crowd annotations to
# begin with, but accuracy is of the utmost importance.
if not matched_crowd:
ap_obj.push(scores[i], False)
ap_obj.push(score_func(i), False)
timer.stop('Main loop')


Expand Down Expand Up @@ -846,6 +880,7 @@ def evaluate(net:Yolact, dataset, train_mode=False):
net.detect.use_cross_class_nms = args.cross_class_nms
cfg.mask_proto_debug = args.mask_proto_debug

# TODO Currently we do not support Fast Mask Re-scoring in evalimage, evalimages, and evalvideo
if args.image is not None:
if ':' in args.image:
inp, out = args.image.split(':')
Expand Down Expand Up @@ -921,13 +956,14 @@ def evaluate(net:Yolact, dataset, train_mode=False):
with timer.env('Network Extra'):
preds = net(batch)

maskiou_net = net.get_maskiou_net()
# Perform the meat of the operation here depending on our mode.
if args.display:
img_numpy = prep_display(preds, img, h, w)
img_numpy = prep_display(preds, img, h, w, maskiou_net=maskiou_net)
elif args.benchmark:
prep_benchmark(preds, h, w)
prep_benchmark(preds, h, w, maskiou_net=maskiou_net)
else:
prep_metrics(ap_data, preds, img, gt, gt_masks, h, w, num_crowd, dataset.ids[image_idx], detections)
prep_metrics(ap_data, preds, img, gt, gt_masks, h, w, num_crowd, dataset.ids[image_idx], detections, maskiou_net=maskiou_net)

# First couple of images take longer because we're constructing the graph.
# Since that's technically initialization, don't include those in the FPS calculations.
Expand Down
Loading

0 comments on commit ef56a8d

Please sign in to comment.