This repository provides pre-trained Keypoint-Mask R-CNN models that predict instance masks, keypoints, and boxes. All models are trained with detectron2.
The pre-trained keypoint R-CNN models in the detectron2 model zoo do not have a mask head and therefore predict only keypoints and boxes (no instance masks).
Follow the detectron2 installation instructions. The following works for me:
```bash
pip install 'git+https://github.com/facebookresearch/detectron2.git'
```
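If you want to confirm the installation before running anything else, a check along these lines can help. This is a minimal sketch that is not part of the repository: the helper name `detectron2_installed` is hypothetical, and it only tests whether the package can be found on the import path, not whether its CUDA build matches your GPU.

```python
import importlib.util


def detectron2_installed() -> bool:
    """Return True if the `detectron2` package can be found on the import path."""
    # find_spec returns None when the top-level package is missing; it locates
    # the package without importing it, so torch is not loaded by this check.
    return importlib.util.find_spec("detectron2") is not None


if __name__ == "__main__":
    if detectron2_installed():
        print("detectron2 is installed")
    else:
        print("detectron2 not found; install it with the pip command above")
```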
Model | Pre-training | Inference time (s/img) | Box AP | Keypoint AP | Mask AP | Weights |
---|---|---|---|---|---|---|
R-50 FPN (detectron2 pre-trained) | IN1k | -- | 53.6 | 64.0 | --* | [weight] |
R-50 FPN-3x (detectron2 pre-trained) | IN1k | -- | 55.4 | 65.5 | --* | [weight] |
R-50 Mask R-CNN FPN-1x (ours) | IN1k | -- | 55.1 | 65.3 | 47.9 | [weight] [metrics] |
MViTv2-B Cascade Mask R-CNN (ours) [original config] | IN21K, sup, COCO | -- | 65.6 | 67.2 | 55.0 | [weight] [metrics] |
regnety_4gf_FPN [original config] | COCO | -- | 59.4 | 67.0 | 51.4 | [weight] [metrics] |
*: The pre-trained detectron2 Keypoint R-CNN has no mask head.
You can run detect.py directly:
```bash
python3 detect.py path_to_image.png
```
The script accepts options to change the detection model (config file and weights) and the score threshold:
```
python3 detect.py --help
Usage: detect.py [OPTIONS] IMPATH

Options:
  --config-file TEXT       Path to a config file
  --model-url TEXT         Path to model weight
  --score-threshold FLOAT
  --help                   Show this message and exit.
```
For example, to use the MViTv2 model, run:
```bash
python3 detect.py images/11_Meeting_Meeting_11_Meeting_Meeting_11_176.jpg \
    --model-url https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/mvitv2_b_keypoint_cascade_rcnn.pth \
    --config-file configs/mvitv2_b_keypoint_cascade_rcnn.py
```
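To call the CLI from another Python script (for example, over a folder of images), you could assemble the arguments programmatically. The sketch below is hypothetical: `build_detect_command` is not part of the repository, it assumes detect.py sits in the working directory, and it uses only the options listed by `--help`.

```python
import shlex
from typing import List, Optional


def build_detect_command(
    impath: str,
    config_file: Optional[str] = None,
    model_url: Optional[str] = None,
    score_threshold: Optional[float] = None,
) -> List[str]:
    """Build an argument list for detect.py using only the documented options."""
    cmd = ["python3", "detect.py", impath]
    # Options left as None are omitted, so detect.py uses whatever defaults it defines.
    if config_file is not None:
        cmd += ["--config-file", config_file]
    if model_url is not None:
        cmd += ["--model-url", model_url]
    if score_threshold is not None:
        cmd += ["--score-threshold", str(score_threshold)]
    return cmd


if __name__ == "__main__":
    # Reproduces the MViTv2 example above as a list of arguments.
    cmd = build_detect_command(
        "images/11_Meeting_Meeting_11_Meeting_Meeting_11_176.jpg",
        config_file="configs/mvitv2_b_keypoint_cascade_rcnn.py",
        model_url="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/mvitv2_b_keypoint_cascade_rcnn.pth",
    )
    print(shlex.join(cmd))
```

Passing the returned list to `subprocess.run(cmd, check=True)` would then run the detector; the list form avoids shell-quoting problems with paths that contain spaces.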
Alternatively, from Python:
```python
import numpy as np
import torch
from PIL import Image

from detect import KeypointDetector  # detect.py from this repository

config_file = "configs/keypoint_maskrcnn_R_50_FPN_1x.py"
model_url = "https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth"
score_threshold = 0.5
impath = "path_to_image.png"  # replace with the path to your image
# Run on the first GPU if one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

detector = KeypointDetector(config_file, model_url, score_threshold, device)
im = np.array(Image.open(impath).convert("RGB"))  # H x W x 3 uint8 RGB array
instances = detector.predict(im)
visualized_prediction = detector.visualize_prediction(im, instances)
Image.fromarray(visualized_prediction).show()
```
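The returned `instances` can also be post-processed directly. The sketch below assumes, without confirmation from this README, that `detector.predict` returns a detectron2 `Instances` object whose `pred_keypoints` field has shape (N, 17, 3), holding an (x, y, score) triple for each COCO person keypoint. `name_keypoints` and its `min_score` default are hypothetical choices for illustration.

```python
from typing import Dict, Sequence, Tuple

# The 17 COCO person keypoints, in COCO's standard order (assumed to match the models).
COCO_KEYPOINT_NAMES = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
]


def name_keypoints(
    keypoints: Sequence[Sequence[float]], min_score: float = 0.05
) -> Dict[str, Tuple[float, float]]:
    """Map keypoint names to (x, y) for one person, dropping low-confidence points.

    `keypoints` is assumed to be a sequence of 17 [x, y, score] triples.
    """
    if len(keypoints) != len(COCO_KEYPOINT_NAMES):
        raise ValueError(
            f"expected {len(COCO_KEYPOINT_NAMES)} keypoints, got {len(keypoints)}"
        )
    named: Dict[str, Tuple[float, float]] = {}
    for name, (x, y, score) in zip(COCO_KEYPOINT_NAMES, keypoints):
        # Keep only keypoints whose score reaches the (hypothetical) threshold.
        if score >= min_score:
            named[name] = (x, y)
    return named
```

Under those assumptions, `[name_keypoints(kps) for kps in instances.pred_keypoints.cpu().tolist()]` would give one dictionary per detected person.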
To train the network, use lazyconfig_train_net.py, found in the tools/ directory of the detectron2 repository.
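A training run would typically pass one of this repository's configs to that script. The sketch below only assembles the command line. It assumes detectron2's standard `--config-file` and `--num-gpus` flags plus trailing `key=value` config overrides; `build_train_command`, the GPU count, and the `train.output_dir` override in the usage note are illustrative assumptions, not settings taken from this repository.

```python
from typing import Dict, List, Optional


def build_train_command(
    config_file: str,
    num_gpus: int = 1,
    overrides: Optional[Dict[str, str]] = None,
    script: str = "tools/lazyconfig_train_net.py",
) -> List[str]:
    """Assemble an argument list for detectron2's lazyconfig_train_net.py."""
    cmd = ["python3", script, "--config-file", config_file, "--num-gpus", str(num_gpus)]
    # LazyConfig training accepts trailing KEY=VALUE pairs that override config fields.
    for key, value in (overrides or {}).items():
        cmd.append(f"{key}={value}")
    return cmd
```

For example, `build_train_command("configs/keypoint_maskrcnn_R_50_FPN_1x.py", num_gpus=8, overrides={"train.output_dir": "outputs/r50"})` yields an argument list that could be passed to `subprocess.run` from the detectron2 repository root, adjusting the config path to wherever this repository is checked out.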