Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some changes to use CPU for inference and local weights #34

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,20 @@ You can also install with the `setup.py`
python3 setup.py install
```

## Model weights

In this table you can see the urls for the different models implemented. We recommend to download them to local before do an inference:

|Model|Weights url|
|:-:|:-:|
|RetinaNetMobileNetV1|[Link](https://raw.githubusercontent.com/hukkelas/DSFD-Pytorch-Inference/master/RetinaFace_mobilenet025.pth)|
|RetinaNetResNet50|[Link](https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/8dd81669-eb84-4520-8173-dbe49d72f44cb2eef6da-3983-4a12-9085-d11555b93842c19bdf27-b924-4214-9381-e6cac30b87cf)|
|DSFDDetector|[Link 1](https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/61be4ec7-8c11-4a4a-a9f4-827144e4ab4f0c2764c1-80a0-4083-bbfa-68419f889b80e4692358-979b-458e-97da-c1a1660b3314) - [Link 2](https://drive.google.com/uc?id=1WeXlNYsM6dMP3xQQELI-4gxhwKUQxc3-&export=download)|



## Getting started

Run
```
python3 test.py
Expand Down Expand Up @@ -96,8 +109,6 @@ This is **very roughly** estimated on a 1024x687 image. The reported time is the

*Done over 100 forward passes on a MacOS Mid 2014, 15-Inch.



## Changelog
- September 1st 2020: added support for fp16/mixed precision inference
- September 24th 2020: added support for TensorRT.
Expand Down
6 changes: 3 additions & 3 deletions face_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .build import build_detector, available_detectors
from .dsfd import DSFDDetector
from .retinaface import RetinaNetMobileNetV1, RetinaNetResNet50
from face_detection.build import build_detector, available_detectors
from face_detection.dsfd import DSFDDetector
from face_detection.retinaface import RetinaNetMobileNetV1, RetinaNetResNet50
10 changes: 7 additions & 3 deletions face_detection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import typing
from abc import ABC, abstractmethod
from torchvision.ops import nms
from .box_utils import scale_boxes

from face_detection.box_utils import scale_boxes


def check_image(im: np.ndarray):
Expand All @@ -24,7 +25,9 @@ def __init__(
device: torch.device,
max_resolution: int,
fp16_inference: bool,
clip_boxes: bool):
clip_boxes: bool,
model_weights: str,
):
"""
Args:
confidence_threshold (float): Threshold to filter out bounding boxes
Expand All @@ -36,10 +39,11 @@ def __init__(
"""
self.confidence_threshold = confidence_threshold
self.nms_iou_threshold = nms_iou_threshold
self.device = device
self.device = torch.device(device)
self.max_resolution = max_resolution
self.fp16_inference = fp16_inference
self.clip_boxes = clip_boxes
self.model_weights = model_weights
self.mean = np.array(
[123, 117, 104], dtype=np.float32).reshape(1, 1, 1, 3)

Expand Down
22 changes: 14 additions & 8 deletions face_detection/build.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .registry import build_from_cfg, Registry
from .base import Detector
from .torch_utils import get_device
from typing import Optional

from face_detection.registry import build_from_cfg, Registry
from face_detection.base import Detector


available_detectors = [
"DSFDDetector",
Expand All @@ -14,22 +16,26 @@ def build_detector(
name: str = "DSFDDetector",
confidence_threshold: float = 0.5,
nms_iou_threshold: float = 0.3,
device=get_device(),
device: str = "cpu",
max_resolution: int = None,
fp16_inference: bool = False,
clip_boxes: bool = False
clip_boxes: bool = False,
model_weights: Optional[str] = None,
) -> Detector:
assert name in available_detectors,\
f"Detector not available. Chooce one of the following"+\
",".join(available_detectors)
f"""Detector not available.
Choose one of the following {','.join(available_detectors)}
"""

args = dict(
type=name,
confidence_threshold=confidence_threshold,
nms_iou_threshold=nms_iou_threshold,
device=device,
max_resolution=max_resolution,
fp16_inference=fp16_inference,
clip_boxes=clip_boxes
clip_boxes=clip_boxes,
model_weights=model_weights,
)
detector = build_from_cfg(args, DETECTOR_REGISTRY)
return detector
2 changes: 1 addition & 1 deletion face_detection/dsfd/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .detect import DSFDDetector
from face_detection.dsfd.detect import DSFDDetector
18 changes: 7 additions & 11 deletions face_detection/dsfd/detect.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import torch
import numpy as np
import typing
from .face_ssd import SSD
from .config import resnet152_model_config
from .. import torch_utils
from torch.hub import load_state_dict_from_url
from ..base import Detector
from ..build import DETECTOR_REGISTRY

model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/61be4ec7-8c11-4a4a-a9f4-827144e4ab4f0c2764c1-80a0-4083-bbfa-68419f889b80e4692358-979b-458e-97da-c1a1660b3314"
from face_detection import torch_utils
from face_detection.dsfd.face_ssd import SSD
from face_detection.dsfd.config import resnet152_model_config
from face_detection.base import Detector
from face_detection.build import DETECTOR_REGISTRY


@DETECTOR_REGISTRY.register_module
Expand All @@ -17,10 +15,8 @@ class DSFDDetector(Detector):
def __init__(
self, *args, **kwargs):
super().__init__(*args, **kwargs)
state_dict = load_state_dict_from_url(
model_url,
map_location=self.device,
progress=True)

state_dict = torch_utils.load_weights(self.model_weights)
self.net = SSD(resnet152_model_config)
self.net.load_state_dict(state_dict)
self.net.eval()
Expand Down
4 changes: 2 additions & 2 deletions face_detection/dsfd/face_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from .utils import PriorBox
from ..box_utils import batched_decode
from face_detection.dsfd.utils import PriorBox
from face_detection.box_utils import batched_decode


class FEM(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion face_detection/retinaface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Adapted from https://github.com/biubug6/Pytorch_Retinaface
# Original license: MIT
from .detect import RetinaNetMobileNetV1, RetinaNetResNet50
from face_detection.retinaface.detect import RetinaNetMobileNetV1, RetinaNetResNet50
35 changes: 15 additions & 20 deletions face_detection/retinaface/detect.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
# Adapted from https://github.com/biubug6/Pytorch_Retinaface
# Original license: MIT
import torch
import numpy as np
from .. import torch_utils
import typing
from .models.retinaface import RetinaFace
from ..box_utils import batched_decode
from .utils import decode_landm
from .config import cfg_mnet, cfg_re50
from .prior_box import PriorBox
from torch.hub import load_state_dict_from_url

import torch
from torchvision.ops import nms
from ..base import Detector
from ..build import DETECTOR_REGISTRY

from face_detection import torch_utils
from face_detection.base import Detector
from face_detection.build import DETECTOR_REGISTRY
from face_detection.box_utils import batched_decode
from face_detection.retinaface.models.retinaface import RetinaFace
from face_detection.retinaface.utils import decode_landm
from face_detection.retinaface.config import cfg_mnet, cfg_re50
from face_detection.retinaface.prior_box import PriorBox


class RetinaNetDetector(Detector):
Expand All @@ -23,20 +24,14 @@ def __init__(
*args,
**kwargs):
super().__init__(*args, **kwargs)

if model == "mobilenet":
cfg = cfg_mnet
state_dict = load_state_dict_from_url(
"https://raw.githubusercontent.com/hukkelas/DSFD-Pytorch-Inference/master/RetinaFace_mobilenet025.pth",
map_location=torch_utils.get_device()
)
else:
assert model == "resnet50"
cfg = cfg_re50
state_dict = load_state_dict_from_url(
"https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/8dd81669-eb84-4520-8173-dbe49d72f44cb2eef6da-3983-4a12-9085-d11555b93842c19bdf27-b924-4214-9381-e6cac30b87cf",
map_location=torch_utils.get_device()
)
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

state_dict = torch_utils.load_weights(self.model_weights)
net = RetinaFace(cfg=cfg)
net.eval()
net.load_state_dict(state_dict)
Expand Down Expand Up @@ -119,7 +114,7 @@ def _detect(
self.cfg, image_size=(height, width))
priors = priorbox.forward()
self.prior_box_cache[image.shape[2:]] = priors
priors = torch_utils.to_cuda(priors, self.device)
priors = priors.to(self.device)
prior_data = priors.data
boxes = batched_decode(loc, prior_data, self.cfg['variance'])
boxes = torch.cat((boxes, scores), dim=-1)
Expand Down
3 changes: 2 additions & 1 deletion face_detection/retinaface/models/retinaface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import torch
import torch.nn as nn
import torchvision.models._utils as _utils
from .net import MobileNetV1, SSH, FPN

from face_detection.retinaface.models.net import MobileNetV1, SSH, FPN


class ClassHead(nn.Module):
Expand Down
17 changes: 9 additions & 8 deletions face_detection/retinaface/onnx.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# Adapted from https://github.com/biubug6/Pytorch_Retinaface
# Original license: MIT
import torch
import cv2
import numpy as np
from .. import torch_utils
from .models.retinaface import RetinaFace
from ..box_utils import batched_decode
from .utils import decode_landm
from .config import cfg_re50
from .prior_box import PriorBox

import torch
from torch.hub import load_state_dict_from_url

from face_detection.retinaface.models.retinaface import RetinaFace
from face_detection.box_utils import batched_decode
from face_detection.retinaface.utils import decode_landm
from face_detection.retinaface.config import cfg_re50
from face_detection.retinaface.prior_box import PriorBox


class RetinaNetDetectorONNX(torch.nn.Module):

Expand All @@ -20,7 +21,7 @@ def __init__(self, input_imshape, inference_imshape):
cfg = cfg_re50
state_dict = load_state_dict_from_url(
"https://folk.ntnu.no/haakohu/RetinaFace_ResNet50.pth",
map_location=torch_utils.get_device()
map_location="cpu"
)
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
net = RetinaFace(cfg=cfg)
Expand Down
5 changes: 3 additions & 2 deletions face_detection/retinaface/tensorrt_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import cv2
import tensorrt as trt
import torch
from .onnx import RetinaNetDetectorONNX
from .utils import python_nms

from face_detection.retinaface.onnx import RetinaNetDetectorONNX
from face_detection.retinaface.utils import python_nms


cache_dir = torch.hub._get_torch_home()
Expand Down
28 changes: 14 additions & 14 deletions face_detection/torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
import numpy as np
import torch


def to_cuda(elements, device):
if torch.cuda.is_available():
if type(elements) == tuple or type(elements) == list:
return [x.to(device) for x in elements]
return elements.to(device)
return elements


def get_device():
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
import os


def image_to_torch(image, device):
Expand All @@ -25,3 +12,16 @@ def image_to_torch(image, device):
image = image[None, :, :, :]
image = torch.from_numpy(image).to(device)
return image


def load_weights(weights_path):
if os.path.isfile(weights_path):
state_dict = torch.load(weights_path, map_location="cpu")
else:
state_dict = torch.hub.load_state_dict_from_url(
weights_path,
map_location="cpu"
)
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

return state_dict
7 changes: 0 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
import setuptools
import torch
import torchvision

torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
assert torch_ver >= [1, 6], "Requires PyTorch >= 1.6"
torchvision_ver = [int(x) for x in torchvision.__version__.split(".")[:2]]
assert torchvision_ver >= [0, 3], "Requires torchvision >= 0.3"

setuptools.setup(
name="face_detection",
Expand Down
Loading