-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
146 lines (119 loc) · 4.84 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import cv2
import argparse
import warnings
import numpy as np
from models import SCRFD, Attribute
from utils.helpers import Face, draw_face_info
# Suppress every Python warning for the whole process (presumably to hide noisy
# third-party library warnings). NOTE(review): this also hides useful warnings;
# comment it out when debugging.
warnings.filterwarnings("ignore")
def load_models(detection_model_path: str, attribute_model_path: str):
    """Construct the face-detection and face-attribute models.

    Args:
        detection_model_path (str): Filesystem path of the SCRFD detector weights.
        attribute_model_path (str): Filesystem path of the attribute model weights.

    Returns:
        tuple: ``(detection_model, attribute_model)``.

    Raises:
        Exception: Anything raised by either model constructor; the error is
            printed first and then re-raised unchanged.
    """
    try:
        # Tuple elements are evaluated left to right, so the detector is still
        # built before the attribute model.
        models = (
            SCRFD(model_path=detection_model_path),
            Attribute(model_path=attribute_model_path),
        )
    except Exception as err:
        print(f"Error loading models: {err}")
        raise
    return models
def inference_image(detection_model, attribute_model, image_path, save_output):
    """Run face detection and attribute estimation on one image and display it.

    The image is annotated in place by ``process_frame``. It is then optionally
    written to ``save_output`` and shown in a window until a key is pressed.

    Args:
        detection_model (SCRFD): The face detection model.
        attribute_model (Attribute): The attribute detection model.
        image_path (str): Path to the input image.
        save_output (str | None): Path to save the annotated image; falsy to
            skip saving.
    """
    frame = cv2.imread(image_path)
    if frame is None:
        # cv2.imread reports a missing or unreadable file by returning None
        # rather than raising.
        print("Failed to load image")
        return

    process_frame(detection_model, attribute_model, frame)

    # cv2.imwrite can report failure by returning False (e.g. an unwritable
    # directory); previously that was silently ignored.
    if save_output and not cv2.imwrite(save_output, frame):
        print(f"Failed to save output image: {save_output}")

    cv2.imshow("FaceDetection", frame)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def _open_capture(video_source):
    """Open ``video_source`` with ``cv2.VideoCapture``.

    A camera index may be given as an ``int`` or as a string of digits
    (e.g. ``"0"``). Any other value is passed through unchanged as a file
    path or URL.
    """
    if isinstance(video_source, str) and video_source.isdigit():
        video_source = int(video_source)
    return cv2.VideoCapture(video_source)


def _open_writer(cap, save_output, default_fps=30.0):
    """Create an XVID ``VideoWriter`` sized like ``cap``'s frames, or None on failure.

    The source's reported FPS is used so the saved video plays at the original
    speed. ``default_fps`` is the fallback when the backend reports a
    non-positive or NaN rate.
    """
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # written this way so NaN also falls back
        fps = default_fps
    size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(save_output, cv2.VideoWriter_fourcc(*'XVID'), fps, size)
    if not writer.isOpened():
        print(f"Failed to open video writer: {save_output}")
        return None
    return writer


def inference_video(detection_model, attribute_model, video_source, save_output):
    """Process a video file, stream, or camera for face detection and attributes.

    Each frame is annotated in place, optionally written to ``save_output``,
    and shown in a window. Pressing ``q`` stops early.

    Args:
        detection_model (SCRFD): The face detection model.
        attribute_model (Attribute): The attribute detection model.
        video_source (str or int): Path/URL of the input video, or a camera
            index given as an int or a digit string.
        save_output (str | None): Path to save the output video; falsy to skip.
    """
    cap = _open_capture(video_source)
    if not cap.isOpened():
        print("Failed to open video source")
        return

    out = _open_writer(cap, save_output) if save_output else None
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            process_frame(detection_model, attribute_model, frame)
            if out is not None:
                out.write(frame)
            cv2.imshow("FaceDetection", frame)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        # Release even if inference raises, so the capture device and the
        # output file are not left open.
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()
def process_frame(detection_model, attribute_model, frame):
    """Annotate ``frame`` in place with every detected face and its gender/age.

    Args:
        detection_model (SCRFD): Detector whose ``detect(frame)`` returns
            per-face boxes and keypoints.
        attribute_model (Attribute): Model whose ``get(frame, bbox)`` returns
            ``(gender, age)``.
        frame (np.ndarray): The image frame to process; modified in place.
    """
    detections, landmarks = detection_model.detect(frame)
    for det, kps in zip(detections, landmarks):
        # Each detection ends with a confidence score, which is unused here;
        # the leading values are the box.
        *bbox, _score = det
        gender, age = attribute_model.get(frame, bbox)
        draw_face_info(frame, Face(kps=kps, bbox=bbox, age=age, gender=gender))
def run_face_analysis(detection_weights, attribute_weights, input_source, save_output=None):
    """Load both models and dispatch ``input_source`` to image or video inference.

    Args:
        detection_weights (str): Path to the detection model weights.
        attribute_weights (str): Path to the attribute model weights.
        input_source (str | int): Image path, video path/URL, or camera index.
        save_output (str | None): Where to write the annotated result, if anywhere.
    """
    # Sources with these (case-insensitive) extensions are read with cv2.imread.
    # Everything else (video files, streams, camera indices) goes to the video
    # path. The list was extended beyond jpg/jpeg/png so that other still-image
    # formats OpenCV can decode are not misrouted to VideoCapture.
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff')

    detection_model, attribute_model = load_models(detection_weights, attribute_weights)

    if isinstance(input_source, str) and input_source.lower().endswith(image_extensions):
        inference_image(detection_model, attribute_model, input_source, save_output)
    else:
        inference_video(detection_model, attribute_model, input_source, save_output)
def main():
    """Parse command-line options and run face analysis on the chosen source."""
    parser = argparse.ArgumentParser(description="Run face detection on an image or video")

    # (flag, default, help) for every string option that has a default value;
    # registered in this exact order.
    defaulted_options = (
        ('--detection-weights', "weights/det_10g.onnx",
         'Path to the detection model weights file'),
        ('--attribute-weights', "weights/genderage.onnx",
         'Path to the attribute model weights file'),
        ('--source', "assets/in_image.jpg",
         'Path to the input image or video file or camera index (0, 1, ...)'),
    )
    for flag, default_value, help_text in defaulted_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    parser.add_argument('--output', type=str, help='Path to save the output image or video')

    opts = parser.parse_args()
    run_face_analysis(
        opts.detection_weights,
        opts.attribute_weights,
        opts.source,
        opts.output,
    )
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()