diff --git a/README.md b/README.md
index fb12f00..6093e37 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,20 @@ You can also install with the `setup.py`
 python3 setup.py install
 ```
+## Model weights
+
+The table below lists the weight URLs for each of the implemented models. We recommend downloading the weights locally before running inference:
+
+|Model|Weights URL|
+|:-:|:-:|
+|RetinaNetMobileNetV1|[Link](https://raw.githubusercontent.com/hukkelas/DSFD-Pytorch-Inference/master/RetinaFace_mobilenet025.pth)|
+|RetinaNetResNet50|[Link](https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/8dd81669-eb84-4520-8173-dbe49d72f44cb2eef6da-3983-4a12-9085-d11555b93842c19bdf27-b924-4214-9381-e6cac30b87cf)|
+|DSFDDetector|[Link 1](https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/61be4ec7-8c11-4a4a-a9f4-827144e4ab4f0c2764c1-80a0-4083-bbfa-68419f889b80e4692358-979b-458e-97da-c1a1660b3314) - [Link 2](https://drive.google.com/uc?id=1WeXlNYsM6dMP3xQQELI-4gxhwKUQxc3-&export=download)|
+
+
+
 ## Getting started
+
 Run
 ```
 python3 test.py
 ```
@@ -96,8 +109,6 @@ This is **very roughly** estimated on a 1024x687 image. The reported time is the
 *Done over 100 forward passes on a MacOS Mid 2014, 15-Inch.
 
-
-
 ## Changelog
 - September 1st 2020: added support for fp16/mixed precision inference
 - September 24th 2020: added support for TensorRT.
diff --git a/face_detection/__init__.py b/face_detection/__init__.py
index e589bff..531a1ce 100644
--- a/face_detection/__init__.py
+++ b/face_detection/__init__.py
@@ -1,3 +1,3 @@
-from .build import build_detector, available_detectors
-from .dsfd import DSFDDetector
-from .retinaface import RetinaNetMobileNetV1, RetinaNetResNet50
\ No newline at end of file
+from face_detection.build import build_detector, available_detectors
+from face_detection.dsfd import DSFDDetector
+from face_detection.retinaface import RetinaNetMobileNetV1, RetinaNetResNet50
\ No newline at end of file
diff --git a/face_detection/base.py b/face_detection/base.py
index b0ec89c..3dd1fcb 100644
--- a/face_detection/base.py
+++ b/face_detection/base.py
@@ -3,7 +3,8 @@
 import typing
 from abc import ABC, abstractmethod
 from torchvision.ops import nms
-from .box_utils import scale_boxes
+
+from face_detection.box_utils import scale_boxes
 
 
 def check_image(im: np.ndarray):
@@ -24,7 +25,9 @@ def __init__(
             device: torch.device,
             max_resolution: int,
             fp16_inference: bool,
-            clip_boxes: bool):
+            clip_boxes: bool,
+            model_weights: str,
+            ):
         """
         Args:
             confidence_threshold (float): Threshold to filter out bounding boxes
@@ -36,10 +39,11 @@ def __init__(
         """
         self.confidence_threshold = confidence_threshold
         self.nms_iou_threshold = nms_iou_threshold
-        self.device = device
+        self.device = torch.device(device)
         self.max_resolution = max_resolution
         self.fp16_inference = fp16_inference
         self.clip_boxes = clip_boxes
+        self.model_weights = model_weights
         self.mean = np.array(
             [123, 117, 104], dtype=np.float32).reshape(1, 1, 1, 3)
diff --git a/face_detection/build.py b/face_detection/build.py
index ebf84ac..306b3be 100644
--- a/face_detection/build.py
+++ b/face_detection/build.py
@@ -1,6 +1,8 @@
-from .registry import build_from_cfg, Registry
-from .base import Detector
-from .torch_utils import get_device
+from typing import Optional
+
+from face_detection.registry import build_from_cfg, Registry
+from face_detection.base import Detector
+
 
 available_detectors = [
     "DSFDDetector",
@@ -14,14 +16,17 @@ def build_detector(
         name: str = "DSFDDetector",
         confidence_threshold: float = 0.5,
         nms_iou_threshold: float = 0.3,
-        device=get_device(),
+        device: str = "cpu",
         max_resolution: int = None,
         fp16_inference: bool = False,
-        clip_boxes: bool = False
+        clip_boxes: bool = False,
+        model_weights: Optional[str] = None,
 ) -> Detector:
     assert name in available_detectors,\
-        f"Detector not available. Chooce one of the following"+\
-        ",".join(available_detectors)
+        f"""Detector not available.
+        Choose one of the following {','.join(available_detectors)}
+        """
+
     args = dict(
         type=name,
         confidence_threshold=confidence_threshold,
@@ -29,7 +34,8 @@ def build_detector(
         device=device,
         max_resolution=max_resolution,
         fp16_inference=fp16_inference,
-        clip_boxes=clip_boxes
+        clip_boxes=clip_boxes,
+        model_weights=model_weights,
     )
     detector = build_from_cfg(args, DETECTOR_REGISTRY)
     return detector
diff --git a/face_detection/dsfd/__init__.py b/face_detection/dsfd/__init__.py
index 9121fca..e4327ff 100644
--- a/face_detection/dsfd/__init__.py
+++ b/face_detection/dsfd/__init__.py
@@ -1 +1 @@
-from .detect import DSFDDetector
\ No newline at end of file
+from face_detection.dsfd.detect import DSFDDetector
\ No newline at end of file
diff --git a/face_detection/dsfd/detect.py b/face_detection/dsfd/detect.py
index c25f5f1..bec509f 100644
--- a/face_detection/dsfd/detect.py
+++ b/face_detection/dsfd/detect.py
@@ -1,14 +1,12 @@
 import torch
 import numpy as np
 import typing
-from .face_ssd import SSD
-from .config import resnet152_model_config
-from .. import torch_utils
-from torch.hub import load_state_dict_from_url
-from ..base import Detector
-from ..build import DETECTOR_REGISTRY
 
-model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/61be4ec7-8c11-4a4a-a9f4-827144e4ab4f0c2764c1-80a0-4083-bbfa-68419f889b80e4692358-979b-458e-97da-c1a1660b3314"
+from face_detection import torch_utils
+from face_detection.dsfd.face_ssd import SSD
+from face_detection.dsfd.config import resnet152_model_config
+from face_detection.base import Detector
+from face_detection.build import DETECTOR_REGISTRY
 
 
 @DETECTOR_REGISTRY.register_module
@@ -17,10 +15,8 @@ class DSFDDetector(Detector):
     def __init__(
             self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        state_dict = load_state_dict_from_url(
-            model_url,
-            map_location=self.device,
-            progress=True)
+
+        state_dict = torch_utils.load_weights(self.model_weights)
         self.net = SSD(resnet152_model_config)
         self.net.load_state_dict(state_dict)
         self.net.eval()
diff --git a/face_detection/dsfd/face_ssd.py b/face_detection/dsfd/face_ssd.py
index 7f3bec9..f24b28e 100644
--- a/face_detection/dsfd/face_ssd.py
+++ b/face_detection/dsfd/face_ssd.py
@@ -2,8 +2,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
-from .utils import PriorBox
-from ..box_utils import batched_decode
+from face_detection.dsfd.utils import PriorBox
+from face_detection.box_utils import batched_decode
 
 
 class FEM(nn.Module):
diff --git a/face_detection/retinaface/__init__.py b/face_detection/retinaface/__init__.py
index 8c1a129..00dbc63 100644
--- a/face_detection/retinaface/__init__.py
+++ b/face_detection/retinaface/__init__.py
@@ -1,3 +1,3 @@
 # Adapted from https://github.com/biubug6/Pytorch_Retinaface
 # Original license: MIT
-from .detect import RetinaNetMobileNetV1, RetinaNetResNet50
\ No newline at end of file
+from face_detection.retinaface.detect import RetinaNetMobileNetV1, RetinaNetResNet50
\ No newline at end of file
diff --git a/face_detection/retinaface/detect.py b/face_detection/retinaface/detect.py
index b3fe62f..5964f8c 100644
--- a/face_detection/retinaface/detect.py
+++ b/face_detection/retinaface/detect.py
@@ -1,18 +1,19 @@
 # Adapted from https://github.com/biubug6/Pytorch_Retinaface
 # Original license: MIT
-import torch
 import numpy as np
-from .. import torch_utils
 import typing
-from .models.retinaface import RetinaFace
-from ..box_utils import batched_decode
-from .utils import decode_landm
-from .config import cfg_mnet, cfg_re50
-from .prior_box import PriorBox
-from torch.hub import load_state_dict_from_url
+
+import torch
 from torchvision.ops import nms
-from ..base import Detector
-from ..build import DETECTOR_REGISTRY
+
+from face_detection import torch_utils
+from face_detection.base import Detector
+from face_detection.build import DETECTOR_REGISTRY
+from face_detection.box_utils import batched_decode
+from face_detection.retinaface.models.retinaface import RetinaFace
+from face_detection.retinaface.utils import decode_landm
+from face_detection.retinaface.config import cfg_mnet, cfg_re50
+from face_detection.retinaface.prior_box import PriorBox
 
 
 class RetinaNetDetector(Detector):
@@ -23,20 +24,14 @@ def __init__(
             *args,
             **kwargs):
         super().__init__(*args, **kwargs)
+
         if model == "mobilenet":
             cfg = cfg_mnet
-            state_dict = load_state_dict_from_url(
-                "https://raw.githubusercontent.com/hukkelas/DSFD-Pytorch-Inference/master/RetinaFace_mobilenet025.pth",
-                map_location=torch_utils.get_device()
-            )
         else:
             assert model == "resnet50"
             cfg = cfg_re50
-            state_dict = load_state_dict_from_url(
-                "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/8dd81669-eb84-4520-8173-dbe49d72f44cb2eef6da-3983-4a12-9085-d11555b93842c19bdf27-b924-4214-9381-e6cac30b87cf",
-                map_location=torch_utils.get_device()
-            )
-        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+
+        state_dict = torch_utils.load_weights(self.model_weights)
         net = RetinaFace(cfg=cfg)
         net.eval()
         net.load_state_dict(state_dict)
@@ -119,7 +114,7 @@ def _detect(
                 self.cfg, image_size=(height, width))
             priors = priorbox.forward()
             self.prior_box_cache[image.shape[2:]] = priors
-        priors = torch_utils.to_cuda(priors, self.device)
+        priors = priors.to(self.device)
         prior_data = priors.data
         boxes = batched_decode(loc, prior_data, self.cfg['variance'])
         boxes = torch.cat((boxes, scores), dim=-1)
diff --git a/face_detection/retinaface/models/retinaface.py b/face_detection/retinaface/models/retinaface.py
index 77fb358..7f99bab 100644
--- a/face_detection/retinaface/models/retinaface.py
+++ b/face_detection/retinaface/models/retinaface.py
@@ -3,7 +3,8 @@
 import torch
 import torch.nn as nn
 import torchvision.models._utils as _utils
-from .net import MobileNetV1, SSH, FPN
+
+from face_detection.retinaface.models.net import MobileNetV1, SSH, FPN
 
 
 class ClassHead(nn.Module):
diff --git a/face_detection/retinaface/onnx.py b/face_detection/retinaface/onnx.py
index 1f53c04..27b76d0 100644
--- a/face_detection/retinaface/onnx.py
+++ b/face_detection/retinaface/onnx.py
@@ -1,16 +1,17 @@
 # Adapted from https://github.com/biubug6/Pytorch_Retinaface
 # Original license: MIT
-import torch
 import cv2
 import numpy as np
-from .. import torch_utils
-from .models.retinaface import RetinaFace
-from ..box_utils import batched_decode
-from .utils import decode_landm
-from .config import cfg_re50
-from .prior_box import PriorBox
+
+import torch
 from torch.hub import load_state_dict_from_url
 
+from face_detection.retinaface.models.retinaface import RetinaFace
+from face_detection.box_utils import batched_decode
+from face_detection.retinaface.utils import decode_landm
+from face_detection.retinaface.config import cfg_re50
+from face_detection.retinaface.prior_box import PriorBox
+
 
 class RetinaNetDetectorONNX(torch.nn.Module):
@@ -20,7 +21,7 @@ def __init__(self, input_imshape, inference_imshape):
         cfg = cfg_re50
         state_dict = load_state_dict_from_url(
             "https://folk.ntnu.no/haakohu/RetinaFace_ResNet50.pth",
-            map_location=torch_utils.get_device()
+            map_location="cpu"
         )
         state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
         net = RetinaFace(cfg=cfg)
diff --git a/face_detection/retinaface/tensorrt_wrap.py b/face_detection/retinaface/tensorrt_wrap.py
index 106052c..be52d97 100644
--- a/face_detection/retinaface/tensorrt_wrap.py
+++ b/face_detection/retinaface/tensorrt_wrap.py
@@ -5,8 +5,9 @@
 import cv2
 import tensorrt as trt
 import torch
-from .onnx import RetinaNetDetectorONNX
-from .utils import python_nms
+
+from face_detection.retinaface.onnx import RetinaNetDetectorONNX
+from face_detection.retinaface.utils import python_nms
 
 
 cache_dir = torch.hub._get_torch_home()
diff --git a/face_detection/torch_utils.py b/face_detection/torch_utils.py
index 2f891af..4f282a4 100644
--- a/face_detection/torch_utils.py
+++ b/face_detection/torch_utils.py
@@ -1,19 +1,6 @@
 import numpy as np
 import torch
-
-
-def to_cuda(elements, device):
-    if torch.cuda.is_available():
-        if type(elements) == tuple or type(elements) == list:
-            return [x.to(device) for x in elements]
-        return elements.to(device)
-    return elements
-
-
-def get_device():
-    if torch.cuda.is_available():
-        return torch.device("cuda")
-    return torch.device("cpu")
+import os
 
 
 def image_to_torch(image, device):
@@ -25,3 +12,16 @@ def image_to_torch(image, device):
         image = image[None, :, :, :]
     image = torch.from_numpy(image).to(device)
     return image
+
+
+def load_weights(weights_path):
+    if os.path.isfile(weights_path):
+        state_dict = torch.load(weights_path, map_location="cpu")
+    else:
+        state_dict = torch.hub.load_state_dict_from_url(
+            weights_path,
+            map_location="cpu"
+        )
+    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+
+    return state_dict
diff --git a/setup.py b/setup.py
index 61dfd4c..7eacf85 100644
--- a/setup.py
+++ b/setup.py
@@ -1,11 +1,4 @@
 import setuptools
-import torch
-import torchvision
-
-torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
-assert torch_ver >= [1, 6], "Requires PyTorch >= 1.6"
-torchvision_ver = [int(x) for x in torchvision.__version__.split(".")[:2]]
-assert torchvision_ver >= [0, 3], "Requires torchvision >= 0.3"
 
 setuptools.setup(
     name="face_detection",
diff --git a/test.py b/test.py
index b64c959..2f6874d 100644
--- a/test.py
+++ b/test.py
@@ -2,6 +2,9 @@
 import os
 import cv2
 import time
+import argparse
+import logging
+
 import face_detection
 
 
@@ -12,27 +15,65 @@ def draw_faces(im, bboxes):
 
 
 if __name__ == "__main__":
-    impaths = "images"
-    impaths = glob.glob(os.path.join(impaths, "*.jpg"))
+
+    logging.basicConfig(level=logging.INFO)
+
+    parser = argparse.ArgumentParser(
+        prog="DSFD face detector",
+        description="Run face detection on a single image or a folder of images"
+    )
+
parser.add_argument("--img_path", type=str, required=True, + help="path to single image or a folder where many images are stored") + parser.add_argument("--model", type=str, required=True, + choices=["DSFDDetector", "RetinaNetResNet50", "RetinaNetMobileNetV1"], + default="DSFDDetector", + help="Model to use") + parser.add_argument("--model_weights", type=str, required=True, + help="Path to the downloaded model weights") + parser.add_argument("--confidence_threshold", type=float, default=0.3) + parser.add_argument("--nms_iou_threshold", type=float, default=0.5) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--max_resolution", type=int, default=1080) + parser.add_argument("--fp16_inference", type=bool, default=True) + parser.add_argument("--clip_boxes", type=bool, default=False) + parser.add_argument("--out_folder", type=str, default="Folder where the output images will be saved") + + args = parser.parse_args() + + if os.path.isfile(args.img_path): + logging.info("Single image detected") + impaths = [args.img_path,] + else: + impaths = glob.glob(os.path.join(args.img_path, "*")) + logging.info(f"Many images detected (total={len(impaths)})") + detector = face_detection.build_detector( - "DSFDDetector", - max_resolution=1080 + name=args.model, + confidence_threshold=args.confidence_threshold, + nms_iou_threshold=args.nms_iou_threshold, + device=args.device, + max_resolution=args.max_resolution, + fp16_inference=args.fp16_inference, + clip_boxes=args.clip_boxes, + model_weights=args.model_weights, ) + logging.info(f"Model {args.model} loaded with weights {args.model_weights}") + + if not os.path.isdir(args.out_folder): + os.makedirs(args.out_folder) + for impath in impaths: - if impath.endswith("out.jpg"): continue im = cv2.imread(impath) - print("Processing:", impath) + logging.info(f"Processing: {impath}") t = time.time() dets = detector.detect( im[:, :, ::-1] )[:, :4] - print(f"Detection time: {time.time()- t:.3f}") + logging.info(f"Detection time: {time.time()- t:.3f}") draw_faces(im, dets) + imname = os.path.basename(impath).split(".")[0] - output_path = os.path.join( - os.path.dirname(impath), - f"{imname}_out.jpg" - ) + output_path = os.path.join(args.out_folder,f"{imname}.jpg") cv2.imwrite(output_path, im) \ No newline at end of file