diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
index 95f3b1e36..ed1c01d25 100644
--- a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
+++ b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
@@ -13,6 +13,7 @@ import torch
import actionlib
import speech_recognition as sr # type: ignore
import lasr_speech_recognition_msgs.msg # type: ignore
+from std_msgs.msg import String # type: ignore
from lasr_speech_recognition_whisper import load_model # type: ignore
# Error handler to remove ALSA error messages taken from:
@@ -83,7 +84,9 @@ class TranscribeSpeechAction(object):
self._action_name = action_name
self._model_params = model_params
-
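+        # Also publish each finalised transcription on a topic, so listeners
+        # other than the action client can consume it.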
+ self._transcription_server = rospy.Publisher(
+ "/live_speech_transcription", String, queue_size=10
+ )
with noalsaerr():
self._model = load_model(
self._model_params.model_name,
@@ -91,7 +94,7 @@ class TranscribeSpeechAction(object):
self._model_params.warmup,
)
# Configure the speech recogniser object and adjust for ambient noise
- self.recogniser = self._configure_recogniser()
+ self.recogniser = self._configure_recogniser(ambient_adj=True)
# Setup the action server and register execution callback
self._action_server = actionlib.SimpleActionServer(
self._action_name,
@@ -110,6 +113,7 @@ class TranscribeSpeechAction(object):
self._action_server.start()
def _timer_cb(self, _) -> None:
        """Adjusts the microphone for ambient noise, unless the action server is listening."""
+        # Periodic ambient-noise adjustment is disabled for now; return immediately.
+        return
if self._listening:
return
@@ -223,7 +227,7 @@ class TranscribeSpeechAction(object):
rospy.loginfo(
f"Time taken: {transcription_end_time - transcription_start_time:.2f}s"
)
-
+ self._transcription_server.publish(phrase)
if self._action_server.is_preempt_requested():
self._listening = False
return
@@ -248,13 +252,13 @@ def parse_args() -> dict:
)
parser.add_argument(
- "--action-name",
+ "--action_name",
type=str,
default="transcribe_speech",
help="Name of the action server.",
)
parser.add_argument(
- "--model-name",
+ "--model_name",
type=str,
default="medium.en",
help="Name of the speech recognition model.",
@@ -300,8 +304,8 @@ def parse_args() -> dict:
action="store_true",
help="Disable warming up the model by running inference on a test file.",
)
-
- return vars(parser.parse_args())
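+    # parse_known_args tolerates the extra arguments that roslaunch injects
+    # (e.g. __name:=... remappings) instead of raising on them.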
+ args, unknown = parser.parse_known_args()
+ return vars(args)
def configure_model_params(config: dict) -> speech_model_params:
@@ -346,6 +350,8 @@ def configure_whisper_cache() -> None:
if __name__ == "__main__":
configure_whisper_cache()
config = parse_args()
- rospy.init_node(config["action_name"])
- server = TranscribeSpeechAction(rospy.get_name(), configure_model_params(config))
+ rospy.init_node("speech_transcription_node")
+ server = TranscribeSpeechAction(
+ config["action_name"], configure_model_params(config)
+ )
rospy.spin()
diff --git a/common/vector_databases/lasr_vector_databases_faiss/.gitignore b/common/vector_databases/lasr_vector_databases_faiss/.gitignore
new file mode 100644
index 000000000..0d45877c1
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/.gitignore
@@ -0,0 +1,2 @@
+data/**
+!data/.gitkeep
\ No newline at end of file
diff --git a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt
new file mode 100644
index 000000000..1d550fbad
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt
@@ -0,0 +1,216 @@
+cmake_minimum_required(VERSION 3.0.2)
+project(lasr_vector_databases_faiss)
+
+## Compile as C++11, supported in ROS Kinetic and newer
+# add_compile_options(-std=c++11)
+
+## Find catkin macros and libraries
+## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
+## is used, also find other catkin packages
+find_package(catkin REQUIRED catkin_virtualenv)
+
+## System dependencies are found with CMake's conventions
+# find_package(Boost REQUIRED COMPONENTS system)
+
+
+## Uncomment this if the package has a setup.py. This macro ensures
+## modules and global scripts declared therein get installed
+## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html
+catkin_python_setup()
+catkin_generate_virtualenv(
+ INPUT_REQUIREMENTS requirements.in
+ PYTHON_INTERPRETER python3.9
+)
+################################################
+## Declare ROS messages, services and actions ##
+################################################
+
+## To declare and build messages, services or actions from within this
+## package, follow these steps:
+## * Let MSG_DEP_SET be the set of packages whose message types you use in
+## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...).
+## * In the file package.xml:
+## * add a build_depend tag for "message_generation"
+## * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET
+## * If MSG_DEP_SET isn't empty the following dependency has been pulled in
+## but can be declared for certainty nonetheless:
+## * add a exec_depend tag for "message_runtime"
+## * In this file (CMakeLists.txt):
+## * add "message_generation" and every package in MSG_DEP_SET to
+## find_package(catkin REQUIRED COMPONENTS ...)
+## * add "message_runtime" and every package in MSG_DEP_SET to
+## catkin_package(CATKIN_DEPENDS ...)
+## * uncomment the add_*_files sections below as needed
+## and list every .msg/.srv/.action file to be processed
+## * uncomment the generate_messages entry below
+## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...)
+
+## Generate messages in the 'msg' folder
+# add_message_files(
+# FILES
+# Message1.msg
+# Message2.msg
+# )
+
+## Generate services in the 'srv' folder
+# add_service_files(
+# FILES
+# Service1.srv
+# Service2.srv
+# )
+
+# Generate actions in the 'action' folder
+# add_action_files(
+# DIRECTORY action
+# FILES WaitGreet.action Identify.action Greet.action GetName.action LearnFace.action GetCommand.action Guide.action DetectPeople.action FindPerson.action ReceiveObject.action HandoverObject.action
+# )
+
+# Generate added messages and services with any dependencies listed here
+# generate_messages(
+# DEPENDENCIES
+# actionlib_msgs
+# geometry_msgs
+# )
+
+################################################
+## Declare ROS dynamic reconfigure parameters ##
+################################################
+
+## To declare and build dynamic reconfigure parameters within this
+## package, follow these steps:
+## * In the file package.xml:
+## * add a build_depend and a exec_depend tag for "dynamic_reconfigure"
+## * In this file (CMakeLists.txt):
+## * add "dynamic_reconfigure" to
+## find_package(catkin REQUIRED COMPONENTS ...)
+## * uncomment the "generate_dynamic_reconfigure_options" section below
+## and list every .cfg file to be processed
+
+## Generate dynamic reconfigure parameters in the 'cfg' folder
+# generate_dynamic_reconfigure_options(
+# cfg/DynReconf1.cfg
+# cfg/DynReconf2.cfg
+# )
+
+###################################
+## catkin specific configuration ##
+###################################
+## The catkin_package macro generates cmake config files for your package
+## Declare things to be passed to dependent projects
+## INCLUDE_DIRS: uncomment this if your package contains header files
+## LIBRARIES: libraries you create in this project that dependent projects also need
+## CATKIN_DEPENDS: catkin_packages dependent projects also need
+## DEPENDS: system dependencies of this project that dependent projects also need
+catkin_package(
+# INCLUDE_DIRS include
+# LIBRARIES qualification
+# DEPENDS system_lib
+)
+
+###########
+## Build ##
+###########
+
+## Specify additional locations of header files
+## Your package locations should be listed before other locations
+include_directories(
+# include
+ ${catkin_INCLUDE_DIRS}
+)
+
+## Declare a C++ library
+# add_library(${PROJECT_NAME}
+# src/${PROJECT_NAME}/qualification.cpp
+# )
+
+## Add cmake target dependencies of the library
+## as an example, code may need to be generated before libraries
+## either from message generation or dynamic reconfigure
+# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Declare a C++ executable
+## With catkin_make all packages are built within a single CMake context
+## The recommended prefix ensures that target names across packages don't collide
+# add_executable(${PROJECT_NAME}_node src/qualification_node.cpp)
+
+## Rename C++ executable without prefix
+## The above recommended prefix causes long target names, the following renames the
+## target back to the shorter version for ease of user use
+## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node"
+# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "")
+
+## Add cmake target dependencies of the executable
+## same as for the library above
+# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Specify libraries to link a library or executable target against
+# target_link_libraries(${PROJECT_NAME}_node
+# ${catkin_LIBRARIES}
+# )
+
+#############
+## Install ##
+#############
+
+# all install targets should use catkin DESTINATION variables
+# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html
+
+## Mark executable scripts (Python etc.) for installation
+## in contrast to setup.py, you can choose the destination
+# catkin_install_python(PROGRAMS
+# nodes/qualification
+# nodes/actions/wait_greet
+# nodes/actions/identify
+# nodes/actions/greet
+# nodes/actions/get_name
+# nodes/actions/learn_face
+# nodes/actions/get_command
+# nodes/actions/guide
+# nodes/actions/find_person
+# nodes/actions/detect_people
+# nodes/actions/receive_object
+# nodes/actions/handover_object
+# nodes/better_qualification
+# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark executables for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html
+# install(TARGETS ${PROJECT_NAME}_node
+# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark libraries for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html
+# install(TARGETS ${PROJECT_NAME}
+# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION}
+# )
+
+## Mark cpp header files for installation
+# install(DIRECTORY include/${PROJECT_NAME}/
+# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}
+# FILES_MATCHING PATTERN "*.h"
+# PATTERN ".svn" EXCLUDE
+# )
+
+## Mark other files for installation (e.g. launch and bag files, etc.)
+# install(FILES
+# # myfile1
+# # myfile2
+# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+# )
+
+#############
+## Testing ##
+#############
+
+## Add gtest based cpp test target and link libraries
+# catkin_add_gtest(${PROJECT_NAME}-test test/test_qualification.cpp)
+# if(TARGET ${PROJECT_NAME}-test)
+# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME})
+# endif()
+
+## Add folders to be run by python nosetests
+# catkin_add_nosetests(test)
diff --git a/common/vector_databases/lasr_vector_databases_faiss/data/.gitkeep b/common/vector_databases/lasr_vector_databases_faiss/data/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/vector_databases/lasr_vector_databases_faiss/package.xml b/common/vector_databases/lasr_vector_databases_faiss/package.xml
new file mode 100644
index 000000000..f8128ea56
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/package.xml
@@ -0,0 +1,17 @@
+<?xml version="1.0"?>
+<package format="2">
+  <name>lasr_vector_databases_faiss</name>
+  <version>0.0.0</version>
+  <description>The faiss package</description>
+
+  <maintainer email="">Matt Barker</maintainer>
+
+  <license>MIT</license>
+
+  <buildtool_depend>catkin</buildtool_depend>
+  <build_depend>catkin_virtualenv</build_depend>
+
+  <export>
+    <pip_requirements>requirements.txt</pip_requirements>
+  </export>
+</package>
diff --git a/common/vector_databases/lasr_vector_databases_faiss/requirements.in b/common/vector_databases/lasr_vector_databases_faiss/requirements.in
new file mode 100644
index 000000000..14955d38d
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/requirements.in
@@ -0,0 +1 @@
+faiss-cpu
\ No newline at end of file
diff --git a/common/vector_databases/lasr_vector_databases_faiss/requirements.txt b/common/vector_databases/lasr_vector_databases_faiss/requirements.txt
new file mode 100644
index 000000000..4557b8f0c
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/requirements.txt
@@ -0,0 +1,53 @@
+--extra-index-url https://pypi.ngc.nvidia.com
+--trusted-host pypi.ngc.nvidia.com
+
+certifi==2024.2.2 # via requests
+charset-normalizer==3.3.2 # via requests
+click==8.1.7 # via nltk
+faiss-cpu==1.7.4 # via -r requirements.in
+filelock==3.13.1 # via huggingface-hub, torch, transformers, triton
+fsspec==2024.2.0 # via huggingface-hub, torch
+ftfy==6.1.3 # via -r requirements.in, clip
+huggingface-hub==0.20.3 # via sentence-transformers, tokenizers, transformers
+idna==3.6 # via requests
+jinja2==3.1.3 # via torch
+joblib==1.3.2 # via nltk, scikit-learn
+markupsafe==2.1.5 # via jinja2
+mpmath==1.3.0 # via sympy
+networkx==3.2.1 # via torch
+nltk==3.8.1 # via sentence-transformers
+numpy==1.26.3 # via opencv-python, scikit-learn, scipy, sentence-transformers, torchvision, transformers
+nvidia-cublas-cu12==12.1.3.1 # via nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch
+nvidia-cuda-cupti-cu12==12.1.105 # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105 # via torch
+nvidia-cuda-runtime-cu12==12.1.105 # via torch
+nvidia-cudnn-cu12==8.9.2.26 # via torch
+nvidia-cufft-cu12==11.0.2.54 # via torch
+nvidia-curand-cu12==10.3.2.106 # via torch
+nvidia-cusolver-cu12==11.4.5.107 # via torch
+nvidia-cusparse-cu12==12.1.0.106 # via nvidia-cusolver-cu12, torch
+nvidia-nccl-cu12==2.19.3 # via torch
+nvidia-nvjitlink-cu12==12.3.101 # via nvidia-cusolver-cu12, nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105 # via torch
+opencv-python==4.9.0.80 # via -r requirements.in
+packaging==23.2 # via huggingface-hub, transformers
+pillow==10.2.0 # via sentence-transformers, torchvision
+pyyaml==6.0.1 # via huggingface-hub, transformers
+regex==2023.12.25 # via -r requirements.in, clip, nltk, transformers
+requests==2.31.0 # via huggingface-hub, torchvision, transformers
+safetensors==0.4.2 # via transformers
+scikit-learn==1.4.0 # via sentence-transformers
+scipy==1.12.0 # via scikit-learn, sentence-transformers
+sentence-transformers==2.3.1 # via -r requirements.in
+sentencepiece==0.1.99 # via sentence-transformers
+sympy==1.12 # via torch
+threadpoolctl==3.2.0 # via scikit-learn
+tokenizers==0.15.1 # via transformers
+torch==2.2.0 # via clip, sentence-transformers, torchvision
+torchvision==0.17.0 # via clip
+tqdm==4.66.1 # via -r requirements.in, clip, huggingface-hub, nltk, sentence-transformers, transformers
+transformers==4.37.2 # via sentence-transformers
+triton==2.2.0 # via torch
+typing-extensions==4.9.0 # via huggingface-hub, torch
+urllib3==2.2.0 # via requests
+wcwidth==0.2.13 # via ftfy
diff --git a/common/vector_databases/lasr_vector_databases_faiss/setup.py b/common/vector_databases/lasr_vector_databases_faiss/setup.py
new file mode 100644
index 000000000..443aec4a4
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/setup.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+
+from distutils.core import setup
+from catkin_pkg.python_setup import generate_distutils_setup
+
+setup_args = generate_distutils_setup(
+ packages=["lasr_vector_databases_faiss"], package_dir={"": "src"}
+)
+
+setup(**setup_args)
diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py
new file mode 100755
index 000000000..f7dbcbe24
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python3
+import os
+import torch
+import numpy as np
+
+# import rospy
+import faiss # type: ignore
+from sentence_transformers import SentenceTransformer # type: ignore
+from typing import Optional
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def load_commands(command_path: str) -> list[str]:
+ """Loads the commands stored in the given txt file
+ into a list of string commands
+ Args:
+ command_path (str): path to the txt file containing
+ the commands -- assumes one command per line.
+ Returns:
+ list[str]: list of string commands where each entry in the
+ list is a command
+ """
+ command_list = []
+ with open(command_path, "r", encoding="utf8") as src:
+ for command in src:
+            # Strip the trailing newline (robust to a missing one on the last line).
+            command_list.append(command.rstrip("\n"))
+ return command_list
+
+
+def get_sentence_embeddings(
+ command_list: list[str], model: SentenceTransformer
+) -> np.ndarray:
+ """Converts the list of command strings into an array of sentence
+ embeddings (where each command is a sentence and each sentence
+ is converted to a vector)
+ Args:
+ command_list (list[str]): list of string commands, where each
+ entry in the list is assumed to be a separate command
+ model (SentenceTransformer): model used to perform the embedding.
+ Assumes a method called encode that takes a list of strings
+ as input.
+ Returns:
+ np.ndarray: array of shape (n_commands, embedding_dim)
+ """
+
+ return model.encode(
+ command_list,
+ convert_to_numpy=True,
+ show_progress_bar=True,
+ batch_size=256,
+ device=DEVICE,
+ )
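+
+# e.g. get_sentence_embeddings(["go to the kitchen"], model) returns an array
+# of shape (1, 384) for the all-MiniLM-L6-v2 model used below.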
+
+
+def create_vector_database(vectors: np.ndarray) -> faiss.IndexFlatIP:
+ """Creates a vector database from an array of vectors of the same dimensionality
+ Args:
+ vectors (np.ndarray): shape (n_vectors, vector_dim)
+
+ Returns:
+ faiss.IndexFlatIP: Flat index containing the vectors
+ """
+ print("Creating vector database")
+ index_flat = faiss.IndexFlatIP(vectors.shape[1])
+ faiss.normalize_L2(vectors)
+ index_flat.add(vectors)
+ print("Finished creating vector database")
+ return index_flat
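+
+# Usage sketch (hypothetical data): because the vectors are L2-normalised before
+# being added, inner-product search against this index yields cosine similarity.
+# Queries must be normalised the same way, e.g.:
+#   query = np.random.rand(1, 384).astype("float32")
+#   faiss.normalize_L2(query)
+#   scores, ids = index.search(query, 3)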
+
+
+def get_command_database(
+ index_path: str, command_path: Optional[str] = None
+) -> faiss.IndexFlatIP:
+ """Gets a vector database containing a list of embedded commands. Creates the database
+ if the path does not exist, else, loads it into memory.
+
+ Args:
+ index_path (str): Path to an existing faiss Index, or where to save a new one.
+ command_path (str, optional): Path of text file containing commands.
+ Only required if creating a new database. Defaults to None.
+
+ Returns:
+        faiss.IndexFlatIP: faiss Index object containing the embedded commands.
+ """
+
+ if not os.path.exists(f"{index_path}.index"):
+ # rospy.loginfo("Creating new command vector database")
+ assert command_path is not None
+ command_list = load_commands(command_path)
+ model = SentenceTransformer("all-MiniLM-L6-v2")
+ command_embeddings = get_sentence_embeddings(command_list, model)
+ print(command_embeddings.shape)
+ command_database = create_vector_database(command_embeddings)
+ faiss.write_index(command_database, f"{index_path}.index")
+ # rospy.loginfo("Finished creating vector database")
+ else:
+ command_database = faiss.read_index(f"{index_path}.index")
+
+ return command_database
+
+
+def get_similar_commands(
+ command: str,
+ index_path: str,
+ command_path: str,
+ n_similar_commands: int = 100,
+ return_embeddings: bool = False,
+) -> tuple[list[str], list[float]]:
+ """Gets the most similar commands to the given command string
+ Args:
+ command (str): command to compare against the database
+ index_path (str): path to the location to create or retrieve
+ the faiss index containing the embedded commands.
+ command_path (str): path to the txt file containing the commands
+        n_similar_commands (int, optional): number of similar commands to
+        return. Defaults to 100.
+        return_embeddings (bool, optional): whether to also return the
+        embeddings of the nearest commands and of the query. Defaults to False.
+    Returns:
+        tuple[list[str], list[float]]: the most similar commands and their
+        similarity scores; if return_embeddings is True, the nearest-command
+        embeddings and the query embedding are appended to the tuple.
+    """
+ command_database = get_command_database(index_path, command_path)
+ command_list = load_commands(command_path)
+ model = SentenceTransformer("all-MiniLM-L6-v2")
+ command_embedding = get_sentence_embeddings([command], model)
+ faiss.normalize_L2(command_embedding)
+ command_distances, command_indices = command_database.search(
+ command_embedding, n_similar_commands
+ )
+ nearest_commands = [command_list[i] for i in command_indices[0]]
+
+ if return_embeddings:
+ all_command_embeddings = get_sentence_embeddings(command_list, model)
+        # filter for only the nearest commands
+ all_command_embeddings = all_command_embeddings[command_indices[0]]
+ return (
+ nearest_commands,
+ list(command_distances[0]),
+ all_command_embeddings,
+ command_embedding,
+ )
+
+ return nearest_commands, list(command_distances[0])
+
+
+if __name__ == "__main__":
+    """Example of using this module to find similar commands"""
+ command = "find Jared and asks if he needs help"
+ result, distances, command_embeddings, query_embedding = get_similar_commands(
+ command,
+ "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/qualification/data/command_index",
+ "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/qualification/data/command_list.txt",
+ n_similar_commands=1000,
+ return_embeddings=True,
+ )
+ print(result)
diff --git a/common/vision/lasr_vision_clip/.gitignore b/common/vision/lasr_vision_clip/.gitignore
new file mode 100644
index 000000000..0d45877c1
--- /dev/null
+++ b/common/vision/lasr_vision_clip/.gitignore
@@ -0,0 +1,2 @@
+data/**
+!data/.gitkeep
\ No newline at end of file
diff --git a/common/vision/lasr_vision_clip/CMakeLists.txt b/common/vision/lasr_vision_clip/CMakeLists.txt
new file mode 100644
index 000000000..a13eb6f2a
--- /dev/null
+++ b/common/vision/lasr_vision_clip/CMakeLists.txt
@@ -0,0 +1,216 @@
+cmake_minimum_required(VERSION 3.0.2)
+project(lasr_vision_clip)
+
+## Compile as C++11, supported in ROS Kinetic and newer
+# add_compile_options(-std=c++11)
+
+## Find catkin macros and libraries
+## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
+## is used, also find other catkin packages
+find_package(catkin REQUIRED catkin_virtualenv)
+
+## System dependencies are found with CMake's conventions
+# find_package(Boost REQUIRED COMPONENTS system)
+
+
+## Uncomment this if the package has a setup.py. This macro ensures
+## modules and global scripts declared therein get installed
+## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html
+catkin_python_setup()
+catkin_generate_virtualenv(
+ INPUT_REQUIREMENTS requirements.in
+ PYTHON_INTERPRETER python3.9
+)
+################################################
+## Declare ROS messages, services and actions ##
+################################################
+
+## To declare and build messages, services or actions from within this
+## package, follow these steps:
+## * Let MSG_DEP_SET be the set of packages whose message types you use in
+## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...).
+## * In the file package.xml:
+## * add a build_depend tag for "message_generation"
+## * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET
+## * If MSG_DEP_SET isn't empty the following dependency has been pulled in
+## but can be declared for certainty nonetheless:
+## * add a exec_depend tag for "message_runtime"
+## * In this file (CMakeLists.txt):
+## * add "message_generation" and every package in MSG_DEP_SET to
+## find_package(catkin REQUIRED COMPONENTS ...)
+## * add "message_runtime" and every package in MSG_DEP_SET to
+## catkin_package(CATKIN_DEPENDS ...)
+## * uncomment the add_*_files sections below as needed
+## and list every .msg/.srv/.action file to be processed
+## * uncomment the generate_messages entry below
+## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...)
+
+## Generate messages in the 'msg' folder
+# add_message_files(
+# FILES
+# Message1.msg
+# Message2.msg
+# )
+
+## Generate services in the 'srv' folder
+# add_service_files(
+# FILES
+# Service1.srv
+# Service2.srv
+# )
+
+# Generate actions in the 'action' folder
+# add_action_files(
+# DIRECTORY action
+# FILES WaitGreet.action Identify.action Greet.action GetName.action LearnFace.action GetCommand.action Guide.action DetectPeople.action FindPerson.action ReceiveObject.action HandoverObject.action
+# )
+
+# Generate added messages and services with any dependencies listed here
+# generate_messages(
+# DEPENDENCIES
+# actionlib_msgs
+# geometry_msgs
+# )
+
+################################################
+## Declare ROS dynamic reconfigure parameters ##
+################################################
+
+## To declare and build dynamic reconfigure parameters within this
+## package, follow these steps:
+## * In the file package.xml:
+## * add a build_depend and a exec_depend tag for "dynamic_reconfigure"
+## * In this file (CMakeLists.txt):
+## * add "dynamic_reconfigure" to
+## find_package(catkin REQUIRED COMPONENTS ...)
+## * uncomment the "generate_dynamic_reconfigure_options" section below
+## and list every .cfg file to be processed
+
+## Generate dynamic reconfigure parameters in the 'cfg' folder
+# generate_dynamic_reconfigure_options(
+# cfg/DynReconf1.cfg
+# cfg/DynReconf2.cfg
+# )
+
+###################################
+## catkin specific configuration ##
+###################################
+## The catkin_package macro generates cmake config files for your package
+## Declare things to be passed to dependent projects
+## INCLUDE_DIRS: uncomment this if your package contains header files
+## LIBRARIES: libraries you create in this project that dependent projects also need
+## CATKIN_DEPENDS: catkin_packages dependent projects also need
+## DEPENDS: system dependencies of this project that dependent projects also need
+catkin_package(
+# INCLUDE_DIRS include
+# LIBRARIES qualification
+# DEPENDS system_lib
+)
+
+###########
+## Build ##
+###########
+
+## Specify additional locations of header files
+## Your package locations should be listed before other locations
+include_directories(
+# include
+ ${catkin_INCLUDE_DIRS}
+)
+
+## Declare a C++ library
+# add_library(${PROJECT_NAME}
+# src/${PROJECT_NAME}/qualification.cpp
+# )
+
+## Add cmake target dependencies of the library
+## as an example, code may need to be generated before libraries
+## either from message generation or dynamic reconfigure
+# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Declare a C++ executable
+## With catkin_make all packages are built within a single CMake context
+## The recommended prefix ensures that target names across packages don't collide
+# add_executable(${PROJECT_NAME}_node src/qualification_node.cpp)
+
+## Rename C++ executable without prefix
+## The above recommended prefix causes long target names, the following renames the
+## target back to the shorter version for ease of user use
+## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node"
+# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "")
+
+## Add cmake target dependencies of the executable
+## same as for the library above
+# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Specify libraries to link a library or executable target against
+# target_link_libraries(${PROJECT_NAME}_node
+# ${catkin_LIBRARIES}
+# )
+
+#############
+## Install ##
+#############
+
+# all install targets should use catkin DESTINATION variables
+# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html
+
+## Mark executable scripts (Python etc.) for installation
+## in contrast to setup.py, you can choose the destination
+# catkin_install_python(PROGRAMS
+# nodes/qualification
+# nodes/actions/wait_greet
+# nodes/actions/identify
+# nodes/actions/greet
+# nodes/actions/get_name
+# nodes/actions/learn_face
+# nodes/actions/get_command
+# nodes/actions/guide
+# nodes/actions/find_person
+# nodes/actions/detect_people
+# nodes/actions/receive_object
+# nodes/actions/handover_object
+# nodes/better_qualification
+# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark executables for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html
+# install(TARGETS ${PROJECT_NAME}_node
+# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark libraries for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html
+# install(TARGETS ${PROJECT_NAME}
+# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION}
+# )
+
+## Mark cpp header files for installation
+# install(DIRECTORY include/${PROJECT_NAME}/
+# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}
+# FILES_MATCHING PATTERN "*.h"
+# PATTERN ".svn" EXCLUDE
+# )
+
+## Mark other files for installation (e.g. launch and bag files, etc.)
+# install(FILES
+# # myfile1
+# # myfile2
+# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+# )
+
+#############
+## Testing ##
+#############
+
+## Add gtest based cpp test target and link libraries
+# catkin_add_gtest(${PROJECT_NAME}-test test/test_qualification.cpp)
+# if(TARGET ${PROJECT_NAME}-test)
+# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME})
+# endif()
+
+## Add folders to be run by python nosetests
+# catkin_add_nosetests(test)
diff --git a/common/vision/lasr_vision_clip/data/.gitkeep b/common/vision/lasr_vision_clip/data/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/vision/lasr_vision_clip/package.xml b/common/vision/lasr_vision_clip/package.xml
new file mode 100644
index 000000000..e64ef13ef
--- /dev/null
+++ b/common/vision/lasr_vision_clip/package.xml
@@ -0,0 +1,17 @@
+<?xml version="1.0"?>
+<package format="2">
+  <name>lasr_vision_clip</name>
+  <version>0.0.0</version>
+  <description>The lasr_vision_clip package</description>
+
+  <maintainer email="">Matt Barker</maintainer>
+
+  <license>MIT</license>
+
+  <buildtool_depend>catkin</buildtool_depend>
+  <build_depend>catkin_virtualenv</build_depend>
+
+  <export>
+    <pip_requirements>requirements.txt</pip_requirements>
+  </export>
+</package>
diff --git a/common/vision/lasr_vision_clip/requirements.in b/common/vision/lasr_vision_clip/requirements.in
new file mode 100644
index 000000000..b03f1e5cc
--- /dev/null
+++ b/common/vision/lasr_vision_clip/requirements.in
@@ -0,0 +1,2 @@
+sentence-transformers
+opencv-python
\ No newline at end of file
diff --git a/common/vision/lasr_vision_clip/requirements.txt b/common/vision/lasr_vision_clip/requirements.txt
new file mode 100644
index 000000000..6d7e59d07
--- /dev/null
+++ b/common/vision/lasr_vision_clip/requirements.txt
@@ -0,0 +1,53 @@
+--extra-index-url https://pypi.ngc.nvidia.com
+--trusted-host pypi.ngc.nvidia.com
+
+certifi==2024.2.2 # via requests
+charset-normalizer==3.3.2 # via requests
+click==8.1.7 # via nltk
+clip @ git+https://github.com/openai/CLIP.git # via -r requirements.in
+filelock==3.13.1 # via huggingface-hub, torch, transformers, triton
+fsspec==2024.2.0 # via huggingface-hub, torch
+ftfy==6.1.3 # via -r requirements.in, clip
+huggingface-hub==0.20.3 # via sentence-transformers, tokenizers, transformers
+idna==3.6 # via requests
+jinja2==3.1.3 # via torch
+joblib==1.3.2 # via nltk, scikit-learn
+markupsafe==2.1.5 # via jinja2
+mpmath==1.3.0 # via sympy
+networkx==3.2.1 # via torch
+nltk==3.8.1 # via sentence-transformers
+numpy==1.26.3 # via opencv-python, scikit-learn, scipy, sentence-transformers, torchvision, transformers
+nvidia-cublas-cu12==12.1.3.1 # via nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch
+nvidia-cuda-cupti-cu12==12.1.105 # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105 # via torch
+nvidia-cuda-runtime-cu12==12.1.105 # via torch
+nvidia-cudnn-cu12==8.9.2.26 # via torch
+nvidia-cufft-cu12==11.0.2.54 # via torch
+nvidia-curand-cu12==10.3.2.106 # via torch
+nvidia-cusolver-cu12==11.4.5.107 # via torch
+nvidia-cusparse-cu12==12.1.0.106 # via nvidia-cusolver-cu12, torch
+nvidia-nccl-cu12==2.19.3 # via torch
+nvidia-nvjitlink-cu12==12.3.101 # via nvidia-cusolver-cu12, nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105 # via torch
+opencv-python==4.9.0.80 # via -r requirements.in
+packaging==23.2 # via huggingface-hub, transformers
+pillow==10.2.0 # via sentence-transformers, torchvision
+pyyaml==6.0.1 # via huggingface-hub, transformers
+regex==2023.12.25 # via -r requirements.in, clip, nltk, transformers
+requests==2.31.0 # via huggingface-hub, torchvision, transformers
+safetensors==0.4.2 # via transformers
+scikit-learn==1.4.0 # via sentence-transformers
+scipy==1.12.0 # via scikit-learn, sentence-transformers
+sentence-transformers==2.3.1 # via -r requirements.in
+sentencepiece==0.1.99 # via sentence-transformers
+sympy==1.12 # via torch
+threadpoolctl==3.2.0 # via scikit-learn
+tokenizers==0.15.1 # via transformers
+torch==2.2.0 # via clip, sentence-transformers, torchvision
+torchvision==0.17.0 # via clip
+tqdm==4.66.1 # via -r requirements.in, clip, huggingface-hub, nltk, sentence-transformers, transformers
+transformers==4.37.2 # via sentence-transformers
+triton==2.2.0 # via torch
+typing-extensions==4.9.0 # via huggingface-hub, torch
+urllib3==2.2.0 # via requests
+wcwidth==0.2.13 # via ftfy
diff --git a/common/vision/lasr_vision_clip/setup.py b/common/vision/lasr_vision_clip/setup.py
new file mode 100644
index 000000000..d02820e3c
--- /dev/null
+++ b/common/vision/lasr_vision_clip/setup.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+
+from distutils.core import setup
+from catkin_pkg.python_setup import generate_distutils_setup
+
+setup_args = generate_distutils_setup(
+ packages=["lasr_vision_clip"], package_dir={"": "src"}
+)
+
+setup(**setup_args)
diff --git a/common/vision/lasr_vision_clip/src/lasr_vision_clip/__init__.py b/common/vision/lasr_vision_clip/src/lasr_vision_clip/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/vision/lasr_vision_clip/src/lasr_vision_clip/clip_utils.py b/common/vision/lasr_vision_clip/src/lasr_vision_clip/clip_utils.py
new file mode 100644
index 000000000..81d6dd25a
--- /dev/null
+++ b/common/vision/lasr_vision_clip/src/lasr_vision_clip/clip_utils.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+import torch
+import rospy
+import cv2
+import cv2_img
+import numpy as np
+from copy import deepcopy
+from sentence_transformers import SentenceTransformer, util
+
+from sensor_msgs.msg import Image
+
+
+def load_model(device: str = "cuda"):
+ """Load the CLIP model.
+
+ Args:
+        device (str, optional): the device to use. Defaults to "cuda".
+
+    Returns:
+        SentenceTransformer: the loaded CLIP model
+ """
+ model = SentenceTransformer("clip-ViT-B-32", device=device)
+ return model
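+
+# Usage sketch: the returned SentenceTransformer wraps CLIP ViT-B/32 and embeds
+# both PIL images and strings into the same 512-dim space, e.g.:
+#   model = load_model("cpu")
+#   text_emb = model.encode(["a photo of a cat"])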
+
+
+def run_clip(
+ model: SentenceTransformer, labels: list[str], img: np.ndarray
+) -> torch.Tensor:
+ """Run the CLIP model.
+
+ Args:
+        model (SentenceTransformer): clip model loaded into memory
+        labels (list[str]): list of string labels to query image similarity to.
+        img: the image to query (a PIL image, as produced by cv2_img.msg_to_pillow_img)
+
+    Returns:
+        torch.Tensor: the cosine similarity scores between the image and label embeddings.
+ """
+ txt = model.encode(labels)
+ img = model.encode(img)
+ with torch.no_grad():
+ cos_scores = util.cos_sim(img, txt)
+ return cos_scores
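+
+# e.g. run_clip(model, ["a person", "a dog"], img) yields a (1, 2) tensor of
+# cosine similarities between the image embedding and each label embedding.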
+
+
+def query_image_stream(
+ model: SentenceTransformer,
+ answers: list[str],
+ annotate: bool = False,
+) -> tuple[str, torch.Tensor, Image]:
+ """Queries the CLIP model with the latest image from the robot's camera
+ and a set of possible image captions and returns the most likely caption.
+
+ Args:
+ model (SentenceTransformer): clip model to run inference on, loaded into memory
+ answers(list[str]): list of possible answers
+ annotate(bool, optional): whether to annotate the image with the most likely, and
+ second most likely, caption. Defaults to False.
+ returns:
+ tuple(str, torch.Tensor, Image): the most likely answer, the scores, and the annotated image msg
+ """
+ img_msg = rospy.wait_for_message("/xtion/rgb/image_raw", Image)
+ img_pil = cv2_img.msg_to_pillow_img(img_msg)
+
+ cos_scores = run_clip(model, answers, img_pil)
+ max_score = cos_scores.argmax()
+ # get second highest score in tensor
+ max_val = deepcopy(cos_scores[0, max_score])
+ cos_scores[0, max_score] = 0
+ second_max_score = cos_scores.argmax()
+ cos_scores[0, max_score] = max_val
+ # Annotate the image
+
+ cv2_im = cv2_img.msg_to_cv2_img(img_msg)
+ if annotate:
+ cv2.putText(
+ cv2_im,
+ f"Most likely caption: {answers[max_score]}",
+ (10, 30),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.5,
+ (0, 255, 0),
+ 2,
+ cv2.LINE_AA,
+ )
+ # add second score below
+ cv2.putText(
+ cv2_im,
+ f"Second most likely caption: {answers[second_max_score]}",
+ (10, 60),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.5,
+ (0, 255, 0),
+ 2,
+ cv2.LINE_AA,
+ )
+
+ img = cv2_img.cv2_img_to_msg(cv2_im)
+ return answers[max_score], cos_scores, img
diff --git a/common/vision/lasr_vision_deepface/examples/greet b/common/vision/lasr_vision_deepface/examples/greet
index 63778c6af..70a87a038 100644
--- a/common/vision/lasr_vision_deepface/examples/greet
+++ b/common/vision/lasr_vision_deepface/examples/greet
@@ -7,51 +7,63 @@ from copy import deepcopy
from sensor_msgs.msg import Image
from lasr_vision_msgs.srv import Recognise, RecogniseRequest
+from lasr_voice import Voice
if len(sys.argv) < 3:
-    print('Usage: rosrun lase_recognition greet <listen_topic> <dataset>')
+    print("Usage: rosrun lase_recognition greet <listen_topic> <dataset>")
exit()
listen_topic = sys.argv[1]
dataset = sys.argv[2]
people_in_frame = []
-last_received_time = None
+
+
+people_in_frame = {}
def detect(image):
rospy.loginfo("Received image message")
global people_in_frame
- people_in_frame = []
try:
- detect_service = rospy.ServiceProxy('/recognise', Recognise)
+ detect_service = rospy.ServiceProxy("/recognise", Recognise)
req = RecogniseRequest()
req.image_raw = image
req.dataset = dataset
- req.confidence = 0.5
+ req.confidence = 0.4
resp = detect_service(req)
for detection in resp.detections:
- people_in_frame.append(detection.name)
- print(resp)
+ people_in_frame[detection.name] = rospy.Time.now()
except rospy.ServiceException as e:
rospy.logerr("Service call failed: %s" % e)
+
def greet():
- print(f"Hello, {' '.join(people_in_frame)}")
+ voice = Voice()
+ voice.speak(f"Hello, {' '.join(people_in_frame)}")
+
+
def image_callback(image):
- global last_received_time
- if last_received_time is None or rospy.Time.now() - last_received_time >= rospy.Duration(5.0):
- prev_people_in_frame = deepcopy(people_in_frame)
- detect(image)
- if people_in_frame != prev_people_in_frame:
- greet()
- last_received_time = rospy.Time.now()
+ global people_in_frame
+ prev_people_in_frame = list(people_in_frame.keys())
+    # remove detections from people_in_frame that are older than 10 seconds
+ detect(image)
+ for person in list(people_in_frame.keys()):
+ if rospy.Time.now() - people_in_frame[person] > rospy.Duration(10):
+ del people_in_frame[person]
+ if (
+ list(people_in_frame.keys()) != prev_people_in_frame
+ and len(people_in_frame) > 0
+ ) or (len(prev_people_in_frame) == 0 and len(people_in_frame) > 0):
+ greet()
+
def listener():
- rospy.init_node('image_listener', anonymous=True)
- rospy.wait_for_service('/recognise')
- rospy.Subscriber(listen_topic, Image, image_callback)
+ rospy.init_node("image_listener", anonymous=True)
+ rospy.wait_for_service("/recognise")
+ rospy.Subscriber(listen_topic, Image, image_callback, queue_size=1)
rospy.spin()
-if __name__ == '__main__':
+
+if __name__ == "__main__":
listener()
diff --git a/common/vision/lasr_vision_deepface/examples/relay b/common/vision/lasr_vision_deepface/examples/relay
index af1c9bd4d..5c35ee887 100644
--- a/common/vision/lasr_vision_deepface/examples/relay
+++ b/common/vision/lasr_vision_deepface/examples/relay
@@ -28,7 +28,7 @@ def detect(image):
req = RecogniseRequest()
req.image_raw = image
req.dataset = dataset
- req.confidence = 0.7
+ req.confidence = 0.4
resp = detect_service(req)
print(resp)
except rospy.ServiceException as e:
diff --git a/common/vision/lasr_vision_deepface/launch/example.launch b/common/vision/lasr_vision_deepface/launch/example.launch
new file mode 100644
index 000000000..5d3281867
--- /dev/null
+++ b/common/vision/lasr_vision_deepface/launch/example.launch
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/common/vision/lasr_vision_deepface/nodes/service b/common/vision/lasr_vision_deepface/nodes/service
index 12149844c..d6a71384a 100644
--- a/common/vision/lasr_vision_deepface/nodes/service
+++ b/common/vision/lasr_vision_deepface/nodes/service
@@ -4,31 +4,86 @@ import re
import rospy
import lasr_vision_deepface as face_recognition
from sensor_msgs.msg import Image
-from lasr_vision_msgs.srv import Recognise, RecogniseRequest, RecogniseResponse
+from lasr_vision_msgs.srv import (
+ Recognise,
+ RecogniseRequest,
+ RecogniseResponse,
+ LearnFace,
+ LearnFaceRequest,
+ LearnFaceResponse,
+)
+
rospy.init_node("recognise_service")
# Determine variables
DEBUG = rospy.get_param("~debug", False)
-debug_publishers = {}
+recognise_debug_publishers = {}
+learn_face_debug_publishers = {}
if DEBUG:
- debug_publisher = rospy.Publisher("/recognise/debug", Image, queue_size=1)
+ recognise_debug_publisher = rospy.Publisher("/recognise/debug", Image, queue_size=1)
+ learn_face_debug_publisher = rospy.Publisher(
+ "/learn_face/debug", Image, queue_size=1
+ )
+ cropped_face_publisher = rospy.Publisher(
+ "/learn_face/debug/cropped_query_face", Image, queue_size=1
+ )
def detect(request: RecogniseRequest) -> RecogniseResponse:
debug_publisher = None
+ similar_face_debug_publisher = None
+ cropped_face_publisher = None
if DEBUG:
- if request.dataset in debug_publishers:
- debug_publisher = debug_publishers[request.dataset]
+ if request.dataset in recognise_debug_publishers:
+ debug_publisher = recognise_debug_publishers[request.dataset][0]
+ similar_face_debug_publisher = recognise_debug_publishers[request.dataset][
+ 1
+ ]
+            cropped_face_publisher = recognise_debug_publishers[request.dataset][2]
else:
topic_name = re.sub(r"[\W_]+", "", request.dataset)
debug_publisher = rospy.Publisher(
f"/recognise/debug/{topic_name}", Image, queue_size=1
)
- return face_recognition.detect(request, debug_publisher)
+ similar_face_debug_publisher = rospy.Publisher(
+ f"/recognise/debug/{topic_name}/similar_face", Image, queue_size=1
+ )
+ cropped_face_publisher = rospy.Publisher(
+ "/recognise/debug/cropped_query_face", Image, queue_size=1
+ )
+ recognise_debug_publishers[request.dataset] = (
+ debug_publisher,
+ similar_face_debug_publisher,
+ cropped_face_publisher,
+ )
+ return face_recognition.detect(
+ request, debug_publisher, similar_face_debug_publisher, cropped_face_publisher
+ )
+
+
+def learn_face(request: LearnFaceRequest) -> LearnFaceResponse:
+ debug_publisher = None
+ if DEBUG:
+ if request.dataset in learn_face_debug_publishers:
+ debug_publisher = learn_face_debug_publishers[request.dataset]
+ else:
+ topic_name = re.sub(r"[\W_]+", "", request.dataset)
+ debug_publisher = rospy.Publisher(
+ f"/learn_face/debug/{topic_name}", Image, queue_size=1
+ )
+ face_recognition.create_dataset(
+ "/xtion/rgb/image_raw",
+ request.dataset,
+ request.name,
+ request.n_images,
+ debug_publisher,
+ )
+ return LearnFaceResponse()
rospy.Service("/recognise", Recognise, detect)
-rospy.loginfo("Face Recognition service starter")
+rospy.Service("/learn_face", LearnFace, learn_face)
+rospy.loginfo("Face Recognition service started")
rospy.spin()
diff --git a/common/vision/lasr_vision_deepface/scripts/create_dataset b/common/vision/lasr_vision_deepface/scripts/create_dataset
index 9796b5665..85f8e6b3c 100644
--- a/common/vision/lasr_vision_deepface/scripts/create_dataset
+++ b/common/vision/lasr_vision_deepface/scripts/create_dataset
@@ -19,27 +19,7 @@ else:
size = 50
import rospy
-import rospkg
-from sensor_msgs.msg import Image
-import os
-import cv2_img
-import cv2
-
-DATASET_ROOT = os.path.join(
- rospkg.RosPack().get_path("lasr_vision_deepface"), "datasets"
-)
-DATASET_PATH = os.path.join(DATASET_ROOT, dataset, name)
-if not os.path.exists(DATASET_PATH):
- os.makedirs(DATASET_PATH)
rospy.init_node("create_dataset")
-rospy.loginfo(f"Taking {size} pictures of {name} and saving to {DATASET_PATH}")
-for i in range(size):
- img_msg = rospy.wait_for_message(topic, Image)
- cv_im = cv2_img.msg_to_cv2_img(img_msg)
- face_cropped_cv_im = face_recognition.detect_face(cv_im)
- if face_cropped_cv_im is None:
- continue
- cv2.imwrite(os.path.join(DATASET_PATH, f"{name}_{i+1}.png"), face_cropped_cv_im)
- rospy.loginfo(f"Took picture {i+1}")
+face_recognition.create_dataset(topic, dataset, name, size)
diff --git a/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/__init__.py b/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/__init__.py
index 5a5bb56d8..e69de29bb 100644
--- a/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/__init__.py
+++ b/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/__init__.py
@@ -1 +1 @@
-from .deepface import detect, detect_face
+from .deepface import detect, detect_face, create_dataset
diff --git a/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/deepface.py b/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/deepface.py
index af8895dad..ce9c4771b 100644
--- a/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/deepface.py
+++ b/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/deepface.py
@@ -5,14 +5,19 @@
import rospkg
import rospy
import os
+import numpy as np
+import pandas as pd
from lasr_vision_msgs.msg import Detection
from lasr_vision_msgs.srv import RecogniseRequest, RecogniseResponse
+from sensor_msgs.msg import Image
+
DATASET_ROOT = os.path.join(
rospkg.RosPack().get_path("lasr_vision_deepface"), "datasets"
)
+
Mat = int # np.typing.NDArray[np.uint8]
@@ -28,11 +33,90 @@ def detect_face(cv_im: Mat) -> Mat | None:
return None
facial_area = faces[0]["facial_area"]
x, y, w, h = facial_area["x"], facial_area["y"], facial_area["w"], facial_area["h"]
+
+    # add padding to the face; clamp x/y so the crop stays inside the image
+    # (numpy slicing already clamps the upper bounds)
+    x = max(x - 10, 0)
+    y = max(y - 10, 0)
+ w += 20
+ h += 20
+
return cv_im[:][y : y + h, x : x + w]
+def create_image_collage(images, output_size=(640, 480)):
+
+ # Calculate grid dimensions
+ num_images = len(images)
+ rows = int(np.sqrt(num_images))
+ print(num_images, rows)
+ cols = (num_images + rows - 1) // rows # Ceiling division
+
+ # Resize images to fit in the grid
+ resized_images = [
+ cv2.resize(img, (output_size[0] // cols, output_size[1] // rows))
+ for img in images
+ ]
+
+ # Create the final image grid
+ grid_image = np.zeros((output_size[1], output_size[0], 3), dtype=np.uint8)
+
+ # Populate the grid with resized images
+ for i in range(rows):
+ for j in range(cols):
+ idx = i * cols + j
+ if idx < num_images:
+ y_start = i * (output_size[1] // rows)
+ y_end = (i + 1) * (output_size[1] // rows)
+ x_start = j * (output_size[0] // cols)
+ x_end = (j + 1) * (output_size[0] // cols)
+
+ grid_image[y_start:y_end, x_start:x_end] = resized_images[idx]
+
+ return grid_image
+
+
+def create_dataset(
+ topic: str,
+ dataset: str,
+ name: str,
+ size: int,
+    debug_publisher: rospy.Publisher | None = None,
+) -> None:
+ dataset_path = os.path.join(DATASET_ROOT, dataset, name)
+ if not os.path.exists(dataset_path):
+ os.makedirs(dataset_path)
+ rospy.loginfo(f"Taking {size} pictures of {name} and saving to {dataset_path}")
+
+ images = []
+ for i in range(size):
+ img_msg = rospy.wait_for_message(topic, Image)
+ cv_im = cv2_img.msg_to_cv2_img(img_msg)
+ face_cropped_cv_im = detect_face(cv_im)
+ if face_cropped_cv_im is None:
+ continue
+ cv2.imwrite(os.path.join(dataset_path, f"{name}_{i+1}.png"), face_cropped_cv_im) # type: ignore
+ rospy.loginfo(f"Took picture {i+1}")
+ images.append(face_cropped_cv_im)
+ if debug_publisher is not None:
+ debug_publisher.publish(
+ cv2_img.cv2_img_to_msg(create_image_collage(images))
+ )
+
+ # Force retraining
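+    # (DeepFace caches dataset embeddings in a representations pickle inside the
+    # dataset folder; this throwaway find() call rebuilds that cache so the newly
+    # saved images are indexed.)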
+ DeepFace.find(
+ cv_im,
+ os.path.join(DATASET_ROOT, dataset),
+ enforce_detection=False,
+ silent=True,
+ detector_backend="mtcnn",
+ )
+
+
def detect(
- request: RecogniseRequest, debug_publisher: rospy.Publisher | None
+ request: RecogniseRequest,
+ debug_publisher: rospy.Publisher | None,
+ debug_inference_pub: rospy.Publisher | None,
+ cropped_detect_pub: rospy.Publisher | None,
) -> RecogniseResponse:
# Decode the image
rospy.loginfo("Decoding")
@@ -49,7 +133,6 @@ def detect(
enforce_detection=True,
silent=True,
detector_backend="mtcnn",
- threshold=request.confidence,
)
except ValueError:
return response
@@ -66,14 +149,19 @@ def detect(
row["source_h"][0],
)
detection.xywh = [x, y, w, h]
- detection.confidence = 1.0 - row["distance"][0]
+ detection.confidence = row["distance"][0]
response.detections.append(detection)
+ cropped_image = cv_im[:][y : y + h, x : x + w]
+
+ if cropped_detect_pub is not None:
+ cropped_detect_pub.publish(cv2_img.cv2_img_to_msg(cropped_image))
+
# Draw bounding boxes and labels for debugging
cv2.rectangle(cv_im, (x, y), (x + w, y + h), (0, 0, 255), 2)
cv2.putText(
cv_im,
- f"{detection.name} ({detection.confidence})",
+ f"{detection.name} Distance: ({detection.confidence:.2f})",
(x, y - 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
@@ -84,5 +172,16 @@ def detect(
# publish to debug topic
if debug_publisher is not None:
debug_publisher.publish(cv2_img.cv2_img_to_msg(cv_im))
+ if debug_inference_pub is not None:
+ result = pd.concat(result)
+ # check for empty result
+ if not result.empty:
+ result_paths = list(result["identity"])
+ if len(result_paths) > 5:
+ result_paths = result_paths[:5]
+ result_images = [cv2.imread(path) for path in result_paths]
+ debug_inference_pub.publish(
+ cv2_img.cv2_img_to_msg(create_image_collage(result_images))
+ )
return response
diff --git a/common/vision/lasr_vision_msgs/CMakeLists.txt b/common/vision/lasr_vision_msgs/CMakeLists.txt
index 0f781a520..f79567dc5 100644
--- a/common/vision/lasr_vision_msgs/CMakeLists.txt
+++ b/common/vision/lasr_vision_msgs/CMakeLists.txt
@@ -60,14 +60,10 @@ add_service_files(
BodyPixDetection.srv
TorchFaceFeatureDetection.srv
Recognise.srv
+ LearnFace.srv
)
-## Generate actions in the 'action' folder
-# add_action_files(
-# FILES
-# Action1.action
-# Action2.action
-# )
+# Generate actions in the 'action' folder
## Generate added messages and services with any dependencies listed here
generate_messages(
diff --git a/common/vision/lasr_vision_msgs/srv/LearnFace.srv b/common/vision/lasr_vision_msgs/srv/LearnFace.srv
new file mode 100644
index 000000000..5376e53d5
--- /dev/null
+++ b/common/vision/lasr_vision_msgs/srv/LearnFace.srv
@@ -0,0 +1,10 @@
+# Name to associate
+string name
+
+# Dataset to add face to
+string dataset
+
+# Number of images to take
+int32 n_images
+
+---