Skip to content

Commit

Permalink
feat(WIP): create a dataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
jws-1 committed Jan 24, 2024
1 parent faace13 commit 684151d
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 9 deletions.
1 change: 1 addition & 0 deletions common/vision/lasr_face_recognition/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ include_directories(
## Mark executable scripts (Python etc.) for installation
## in contrast to setup.py, you can choose the destination
catkin_install_python(PROGRAMS
scripts/create_dataset
examples/relay
examples/greet
nodes/service
Expand Down
57 changes: 57 additions & 0 deletions common/vision/lasr_face_recognition/examples/greet
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python3

import sys
import rospy
from copy import deepcopy

from sensor_msgs.msg import Image

from lasr_vision_msgs.srv import Recognise, RecogniseRequest

# Both the image topic and the dataset name are required positional arguments.
if len(sys.argv) < 3:
    print('Usage: rosrun lasr_face_recognition greet <source_topic> <dataset>')
    # sys.exit is always available (the bare exit() helper is injected by the
    # site module), and a non-zero status reports the usage error to callers.
    sys.exit(1)

# Image topic this node subscribes to.
listen_topic = sys.argv[1]
# Dataset name forwarded to the /recognise service with every request.
dataset = sys.argv[2]
# Names returned by the most recent recognition call.
people_in_frame = []
# Time at which the last image was processed; None until the first one.
last_received_time = None


def detect(image):
    """Ask the /recognise service who is visible in ``image``.

    On success the module-level ``people_in_frame`` list is replaced with the
    names found in the response. If the service call fails, the error is
    logged and ``people_in_frame`` is left untouched, so that a transient
    failure is not mistaken for everybody having left the frame.

    Args:
        image: the image message to forward to the service as ``image_raw``.
    """
    global people_in_frame
    rospy.loginfo("Received image message")

    req = RecogniseRequest()
    req.image_raw = image
    req.dataset = dataset
    req.confidence = 0.5

    try:
        detect_service = rospy.ServiceProxy('/recognise', Recognise)
        resp = detect_service(req)
    except rospy.ServiceException as e:
        rospy.logerr("Service call failed: %s" % e)
        return

    # Only publish the new list once the call has succeeded.
    people_in_frame = [detection.name for detection in resp.detections]
    print(resp)

def greet():
    """Print a greeting naming everyone currently in ``people_in_frame``.

    Names are separated by single spaces. Nothing is printed when the list is
    empty, so an empty recognition result does not produce a dangling
    "Hello, " with no names after it.
    """
    if not people_in_frame:
        return
    print(f"Hello, {' '.join(people_in_frame)}")

def image_callback(image):
    """Subscriber callback that runs recognition at most once per 5 seconds.

    Images arriving less than five seconds after the previously processed one
    are dropped. Otherwise recognition is run on ``image`` and, if the list of
    recognised people differs from the one before the call, a greeting is
    issued. The processing time is then recorded for the next throttle check.

    Args:
        image: the incoming image message.
    """
    global last_received_time

    too_soon = (
        last_received_time is not None
        and rospy.Time.now() - last_received_time < rospy.Duration(5.0)
    )
    if too_soon:
        return

    # Snapshot the current names so a change can be detected afterwards.
    previously_seen = deepcopy(people_in_frame)
    detect(image)
    if previously_seen != people_in_frame:
        greet()
    last_received_time = rospy.Time.now()

def listener():
    """Start the node, wait for /recognise, then process images until shutdown."""
    rospy.init_node('image_listener', anonymous=True)
    # Block until the recognition service exists so the callback can call it.
    rospy.wait_for_service(service='/recognise')
    rospy.Subscriber(name=listen_topic, data_class=Image, callback=image_callback)
    # Hand control to rospy; callbacks run until the node is shut down.
    rospy.spin()


# Only start the node when executed as a script, not when imported.
if __name__ == '__main__':
    listener()
2 changes: 1 addition & 1 deletion common/vision/lasr_face_recognition/launch/camera.launch
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
</include>

<!-- show debug topic -->
<node name="image_view" pkg="rqt_image_view" type="rqt_image_view" respawn="false" output="screen" args="/yolov8/debug" />
<node name="image_view" pkg="rqt_image_view" type="rqt_image_view" respawn="false" output="screen" args="/recognise/debug" />

<!-- start relay service -->
<node name="relay" pkg="lasr_face_recognition" type="relay" respawn="false" output="screen" args="/camera/image_raw $(arg dataset)" />
Expand Down
1 change: 1 addition & 0 deletions common/vision/lasr_face_recognition/requirements.in
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
deepface==0.0.81
numpy>=1.2.1
13 changes: 8 additions & 5 deletions common/vision/lasr_face_recognition/scripts/create_dataset
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#!/usr/bin/env python3

import sys
import lasr_face_recognition as face_recognition

if len(sys.argv) < 3:
print("usage: rosrun lasr_face_recognition create_dataset.py <dataset> <name> [size=20]")
print("usage: rosrun lasr_face_recognition create_dataset.py <dataset> <name> [size=50]")
exit(0)

dataset = sys.argv[1]
Expand All @@ -12,7 +13,7 @@ name = sys.argv[2]
if len(sys.argv) > 3:
# sys.argv entries are strings; convert so that range(size) below works.
size = int(sys.argv[3])
else:
size = 20
size = 50

import rospy
import rospkg
Expand All @@ -32,6 +33,8 @@ rospy.loginfo(f"Taking {size} pictures of {name} and saving to {DATASET_PATH}")
for i in range(size):
img_msg = rospy.wait_for_message("/xtion/rgb/image_raw", Image)
cv_im = cv2_img.msg_to_cv2_img(img_msg)
cv2.imwrite(os.path.join(DATASET_PATH), f"{name}_{i+1}.png", cv_im)
rospy.loginfo(f"Took picutre {i+1}")
rospy.sleep(rospy.Duration(1.0))
face_cropped_cv_im = face_recognition.detect_face(cv_im)
if face_cropped_cv_im is None:
continue
cv2.imwrite(os.path.join(DATASET_PATH, f"{name}_{i+1}.png"), face_cropped_cv_im)
rospy.loginfo(f"Took picture {i+1}")
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .deepface import detect
from .deepface import detect, detect_face
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,23 @@
import rospkg
import rospy
import os
import numpy as np

from lasr_vision_msgs.msg import Detection
from lasr_vision_msgs.srv import RecogniseRequest, RecogniseResponse

DATASET_ROOT = os.path.join(rospkg.RosPack().get_path("lasr_face_recognition"), "datasets")

# Images are handled as NumPy arrays throughout this module (the original
# note on this alias suggests uint8 data).
Mat = np.ndarray


def detect_face(cv_im : Mat) -> Mat | None:
    """Crop ``cv_im`` to the first face DeepFace reports in it.

    Args:
        cv_im: image array to search, indexed as [row, column, ...].

    Returns:
        The sub-array of ``cv_im`` covering the face's bounding box, or
        ``None`` when no face was found.
    """
    faces = DeepFace.extract_faces(cv_im, target_size=(224, 224), detector_backend="mtcnn", enforce_detection=False)
    # NOTE(review): with enforce_detection=False DeepFace appears to return a
    # zero-confidence entry spanning the whole image, rather than an empty
    # list, when nothing is detected - confirm against the pinned deepface
    # version. Entries without a "confidence" key are kept.
    faces = [face for face in faces if face.get("confidence", 1) > 0]
    if not faces:
        return None

    facial_area = faces[0]["facial_area"]
    left, top = facial_area["x"], facial_area["y"]
    right, bottom = left + facial_area["w"], top + facial_area["h"]
    # Clamp to the image origin: a negative start index would make the slice
    # count from the far edge instead of cropping at the border.
    return cv_im[max(top, 0):max(bottom, 0), max(left, 0):max(right, 0)]

def detect(request : RecogniseRequest, debug_publisher: rospy.Publisher | None) -> RecogniseResponse:

# Decode the image
Expand All @@ -18,7 +29,7 @@ def detect(request : RecogniseRequest, debug_publisher: rospy.Publisher | None)

# Run inference
rospy.loginfo("Running inference")
result = DeepFace.find(cv_im, os.path.join(DATASET_ROOT, request.dataset), enforce_detection=False)
result = DeepFace.find(cv_im, os.path.join(DATASET_ROOT, request.dataset), enforce_detection=False, silent=True)

response = RecogniseResponse()

Expand All @@ -29,11 +40,12 @@ def detect(request : RecogniseRequest, debug_publisher: rospy.Publisher | None)
detection.name = row["identity"][0].split("/")[-1].split("_")[0]
x, y, w, h = row["source_x"][0], row["source_y"][0], row["source_w"][0], row["source_h"][0]
detection.xywh = [x, y, w, h]
confidence = row["VGG-Face_cosine"]
response.detections.append(detection)

# Draw bounding boxes and labels for debugging
cv2.rectangle(cv_im, (x, y), (x+w, y+h), (0, 0, 255), 2)
cv2.putText(cv_im, detection.name, (x,y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
cv2.putText(cv_im, f"{detection.name} ({confidence})", (x,y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

# publish to debug topic
if debug_publisher is not None:
Expand Down

0 comments on commit 684151d

Please sign in to comment.