diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
index 95f3b1e36..ed1c01d25 100644
--- a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
+++ b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
@@ -13,6 +13,7 @@ import torch
 import actionlib
 import speech_recognition as sr  # type: ignore
 import lasr_speech_recognition_msgs.msg  # type: ignore
+from std_msgs.msg import String  # type: ignore
 from lasr_speech_recognition_whisper import load_model  # type: ignore
 
 # Error handler to remove ALSA error messages taken from:
@@ -83,7 +84,9 @@ class TranscribeSpeechAction(object):
         self._action_name = action_name
         self._model_params = model_params
-
+        self._transcription_server = rospy.Publisher(
+            "/live_speech_transcription", String, queue_size=10
+        )
         with noalsaerr():
             self._model = load_model(
                 self._model_params.model_name,
@@ -91,7 +94,7 @@ class TranscribeSpeechAction(object):
                 self._model_params.warmup,
             )
         # Configure the speech recogniser object and adjust for ambient noise
-        self.recogniser = self._configure_recogniser()
+        self.recogniser = self._configure_recogniser(ambient_adj=True)
         # Setup the action server and register execution callback
         self._action_server = actionlib.SimpleActionServer(
             self._action_name,
@@ -110,6 +113,7 @@ class TranscribeSpeechAction(object):
         self._action_server.start()
 
     def _timer_cb(self, _) -> None:
         """Adjusts the microphone for ambient noise, unless the action server is listening."""
+        return  # Ambient-noise adjustment is disabled for now; docstring kept first so it remains a docstring.
         if self._listening:
             return
@@ -223,7 +227,7 @@ class TranscribeSpeechAction(object):
         rospy.loginfo(
             f"Time taken: {transcription_end_time - transcription_start_time:.2f}s"
         )
-
+        self._transcription_server.publish(phrase)
         if self._action_server.is_preempt_requested():
             self._listening = False
             return
@@ -248,13 +252,13 @@ def parse_args() -> dict:
     )
 
     parser.add_argument(
-        "--action-name",
+        "--action_name",
         type=str,
         default="transcribe_speech",
         help="Name of the action server.",
     )
     parser.add_argument(
-        "--model-name",
+        "--model_name",
         type=str,
         default="medium.en",
         help="Name of the speech recognition model.",
@@ -300,8 +304,8 @@ def parse_args() -> dict:
         action="store_true",
         help="Disable warming up the model by running inference on a test file.",
     )
-
-    return vars(parser.parse_args())
+    args, unknown = parser.parse_known_args()
+    return vars(args)
 
 
 def configure_model_params(config: dict) -> speech_model_params:
@@ -346,6 +350,8 @@ def configure_whisper_cache() -> None:
 if __name__ == "__main__":
     configure_whisper_cache()
     config = parse_args()
-    rospy.init_node(config["action_name"])
-    server = TranscribeSpeechAction(rospy.get_name(), configure_model_params(config))
+    rospy.init_node("speech_transcription_node")
+    server = TranscribeSpeechAction(
+        config["action_name"], configure_model_params(config)
+    )
     rospy.spin()
diff --git a/common/vector_databases/lasr_vector_databases_faiss/.gitignore b/common/vector_databases/lasr_vector_databases_faiss/.gitignore
new file mode 100644
index 000000000..0d45877c1
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/.gitignore
@@ -0,0 +1,2 @@
+data/**
+!data/.gitkeep
\ No newline at end of file
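Note: with the new /live_speech_transcription publisher, callers can observe speech passively instead of driving the action interface. A minimal client sketch of both paths follows; the action, goal, and result type names (TranscribeSpeechAction/TranscribeSpeechGoal with a string `sequence` result field) are assumptions based on how the server fills its result, and are not shown in this diff.

#!/usr/bin/env python3
import rospy
import actionlib
from std_msgs.msg import String
from lasr_speech_recognition_msgs.msg import (  # assumed names, not shown in this diff
    TranscribeSpeechAction,
    TranscribeSpeechGoal,
)

rospy.init_node("transcription_client_example")

# Passive path: log every phrase the server hears.
rospy.Subscriber(
    "/live_speech_transcription",
    String,
    lambda msg: rospy.loginfo(f"Heard: {msg.data}"),
)

# Active path: request a single transcription through the action server.
client = actionlib.SimpleActionClient("transcribe_speech", TranscribeSpeechAction)
client.wait_for_server()
client.send_goal(TranscribeSpeechGoal())
client.wait_for_result()
rospy.loginfo(f"Transcription: {client.get_result().sequence}")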
diff --git a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt
new file mode 100644
index 000000000..1d550fbad
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt
@@ -0,0 +1,216 @@
+cmake_minimum_required(VERSION 3.0.2)
+project(lasr_vector_databases_faiss)
+
+## Compile as C++11, supported in ROS Kinetic and newer
+# add_compile_options(-std=c++11)
+
+## Find catkin macros and libraries
+## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
+## is used, also find other catkin packages
+find_package(catkin REQUIRED catkin_virtualenv)
+
+## System dependencies are found with CMake's conventions
+# find_package(Boost REQUIRED COMPONENTS system)
+
+
+## Uncomment this if the package has a setup.py. This macro ensures
+## modules and global scripts declared therein get installed
+## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html
+catkin_python_setup()
+catkin_generate_virtualenv(
+    INPUT_REQUIREMENTS requirements.in
+    PYTHON_INTERPRETER python3.9
+)
+################################################
+## Declare ROS messages, services and actions ##
+################################################
+
+## To declare and build messages, services or actions from within this
+## package, follow these steps:
+## * Let MSG_DEP_SET be the set of packages whose message types you use in
+##   your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...).
+## * In the file package.xml:
+##   * add a build_depend tag for "message_generation"
+##   * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET
+##   * If MSG_DEP_SET isn't empty the following dependency has been pulled in
+##     but can be declared for certainty nonetheless:
+##     * add a exec_depend tag for "message_runtime"
+## * In this file (CMakeLists.txt):
+##   * add "message_generation" and every package in MSG_DEP_SET to
+##     find_package(catkin REQUIRED COMPONENTS ...)
+##   * add "message_runtime" and every package in MSG_DEP_SET to
+##     catkin_package(CATKIN_DEPENDS ...)
+##   * uncomment the add_*_files sections below as needed
+##     and list every .msg/.srv/.action file to be processed
+##   * uncomment the generate_messages entry below
+##   * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...)
+
+## Generate messages in the 'msg' folder
+# add_message_files(
+#   FILES
+#   Message1.msg
+#   Message2.msg
+# )
+
+## Generate services in the 'srv' folder
+# add_service_files(
+#   FILES
+#   Service1.srv
+#   Service2.srv
+# )
+
+# Generate actions in the 'action' folder
+# add_action_files(
+#   DIRECTORY action
+#   FILES WaitGreet.action Identify.action Greet.action GetName.action LearnFace.action GetCommand.action Guide.action DetectPeople.action FindPerson.action ReceiveObject.action HandoverObject.action
+# )
+
+# Generate added messages and services with any dependencies listed here
+# generate_messages(
+#   DEPENDENCIES
+#   actionlib_msgs
+#   geometry_msgs
+# )
+
+################################################
+## Declare ROS dynamic reconfigure parameters ##
+################################################
+
+## To declare and build dynamic reconfigure parameters within this
+## package, follow these steps:
+## * In the file package.xml:
+##   * add a build_depend and a exec_depend tag for "dynamic_reconfigure"
+## * In this file (CMakeLists.txt):
+##   * add "dynamic_reconfigure" to
+##     find_package(catkin REQUIRED COMPONENTS ...)
+##   * uncomment the "generate_dynamic_reconfigure_options" section below
+##     and list every .cfg file to be processed
+
+## Generate dynamic reconfigure parameters in the 'cfg' folder
+# generate_dynamic_reconfigure_options(
+#   cfg/DynReconf1.cfg
+#   cfg/DynReconf2.cfg
+# )
+
+###################################
+## catkin specific configuration ##
+###################################
+## The catkin_package macro generates cmake config files for your package
+## Declare things to be passed to dependent projects
+## INCLUDE_DIRS: uncomment this if your package contains header files
+## LIBRARIES: libraries you create in this project that dependent projects also need
+## CATKIN_DEPENDS: catkin_packages dependent projects also need
+## DEPENDS: system dependencies of this project that dependent projects also need
+catkin_package(
+#  INCLUDE_DIRS include
+#  LIBRARIES qualification
+#  DEPENDS system_lib
+)
+
+###########
+## Build ##
+###########
+
+## Specify additional locations of header files
+## Your package locations should be listed before other locations
+include_directories(
+# include
+  ${catkin_INCLUDE_DIRS}
+)
+
+## Declare a C++ library
+# add_library(${PROJECT_NAME}
+#   src/${PROJECT_NAME}/qualification.cpp
+# )
+
+## Add cmake target dependencies of the library
+## as an example, code may need to be generated before libraries
+## either from message generation or dynamic reconfigure
+# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Declare a C++ executable
+## With catkin_make all packages are built within a single CMake context
+## The recommended prefix ensures that target names across packages don't collide
+# add_executable(${PROJECT_NAME}_node src/qualification_node.cpp)
+
+## Rename C++ executable without prefix
+## The above recommended prefix causes long target names, the following renames the
+## target back to the shorter version for ease of user use
+## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node"
+# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "")
+
+## Add cmake target dependencies of the executable
+## same as for the library above
+# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Specify libraries to link a library or executable target against
+# target_link_libraries(${PROJECT_NAME}_node
+#   ${catkin_LIBRARIES}
+# )
+
+#############
+## Install ##
+#############
+
+# all install targets should use catkin DESTINATION variables
+# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html
+
+## Mark executable scripts (Python etc.) for installation
+## in contrast to setup.py, you can choose the destination
+# catkin_install_python(PROGRAMS
+#   nodes/qualification
+#   nodes/actions/wait_greet
+#   nodes/actions/identify
+#   nodes/actions/greet
+#   nodes/actions/get_name
+#   nodes/actions/learn_face
+#   nodes/actions/get_command
+#   nodes/actions/guide
+#   nodes/actions/find_person
+#   nodes/actions/detect_people
+#   nodes/actions/receive_object
+#   nodes/actions/handover_object
+#   nodes/better_qualification
+#   DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark executables for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html
+# install(TARGETS ${PROJECT_NAME}_node
+#   RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark libraries for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html
+# install(TARGETS ${PROJECT_NAME}
+#   ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+#   LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+#   RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION}
+# )
+
+## Mark cpp header files for installation
+# install(DIRECTORY include/${PROJECT_NAME}/
+#   DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}
+#   FILES_MATCHING PATTERN "*.h"
+#   PATTERN ".svn" EXCLUDE
+# )
+
+## Mark other files for installation (e.g. launch and bag files, etc.)
+# install(FILES
+#   # myfile1
+#   # myfile2
+#   DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+# )
+
+#############
+## Testing ##
+#############
+
+## Add gtest based cpp test target and link libraries
+# catkin_add_gtest(${PROJECT_NAME}-test test/test_qualification.cpp)
+# if(TARGET ${PROJECT_NAME}-test)
+#   target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME})
+# endif()
+
+## Add folders to be run by python nosetests
+# catkin_add_nosetests(test)
diff --git a/common/vector_databases/lasr_vector_databases_faiss/data/.gitkeep b/common/vector_databases/lasr_vector_databases_faiss/data/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/vector_databases/lasr_vector_databases_faiss/package.xml b/common/vector_databases/lasr_vector_databases_faiss/package.xml
new file mode 100644
index 000000000..f8128ea56
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/package.xml
@@ -0,0 +1,16 @@
+<?xml version="1.0"?>
+<package format="2">
+  <name>lasr_vector_databases_faiss</name>
+  <version>0.0.0</version>
+  <description>The faiss package</description>
+
+  <maintainer>Matt Barker</maintainer>
+  <license>MIT</license>
+
+  <buildtool_depend>catkin</buildtool_depend>
+  <build_depend>catkin_virtualenv</build_depend>
+
+  <export>
+    <pip_requirements>requirements.txt</pip_requirements>
+  </export>
+</package>
diff --git a/common/vector_databases/lasr_vector_databases_faiss/requirements.in b/common/vector_databases/lasr_vector_databases_faiss/requirements.in
new file mode 100644
index 000000000..14955d38d
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/requirements.in
@@ -0,0 +1 @@
+faiss-cpu
\ No newline at end of file
diff --git a/common/vector_databases/lasr_vector_databases_faiss/requirements.txt b/common/vector_databases/lasr_vector_databases_faiss/requirements.txt
new file mode 100644
index 000000000..4557b8f0c
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/requirements.txt
@@ -0,0 +1,53 @@
+--extra-index-url https://pypi.ngc.nvidia.com
+--trusted-host pypi.ngc.nvidia.com
+
+certifi==2024.2.2  # via requests
+charset-normalizer==3.3.2  # via requests
+click==8.1.7  # via nltk
+faiss-cpu==1.7.4  # via -r requirements.in
+filelock==3.13.1  # via huggingface-hub, torch, transformers, triton
+fsspec==2024.2.0  # via huggingface-hub, torch
+ftfy==6.1.3  # via -r requirements.in, clip
+huggingface-hub==0.20.3  # via sentence-transformers, tokenizers, transformers
+idna==3.6  # via requests
+jinja2==3.1.3  # via torch
+joblib==1.3.2  # via nltk, scikit-learn
+markupsafe==2.1.5  # via jinja2
+mpmath==1.3.0  # via sympy
+networkx==3.2.1  # via torch
+nltk==3.8.1  # via sentence-transformers
+numpy==1.26.3  # via opencv-python, scikit-learn, scipy, sentence-transformers, torchvision, transformers
+nvidia-cublas-cu12==12.1.3.1  # via nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch
+nvidia-cuda-cupti-cu12==12.1.105  # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105  # via torch
+nvidia-cuda-runtime-cu12==12.1.105  # via torch
+nvidia-cudnn-cu12==8.9.2.26  # via torch
+nvidia-cufft-cu12==11.0.2.54  # via torch
+nvidia-curand-cu12==10.3.2.106  # via torch
+nvidia-cusolver-cu12==11.4.5.107  # via torch
+nvidia-cusparse-cu12==12.1.0.106  # via nvidia-cusolver-cu12, torch
+nvidia-nccl-cu12==2.19.3  # via torch
+nvidia-nvjitlink-cu12==12.3.101  # via nvidia-cusolver-cu12, nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105  # via torch
+opencv-python==4.9.0.80  # via -r requirements.in
+packaging==23.2  # via huggingface-hub, transformers
+pillow==10.2.0  # via sentence-transformers, torchvision
+pyyaml==6.0.1  # via huggingface-hub, transformers
+regex==2023.12.25  # via -r requirements.in, clip, nltk, transformers
+requests==2.31.0  # via huggingface-hub, torchvision, transformers
+safetensors==0.4.2  # via transformers
+scikit-learn==1.4.0  # via sentence-transformers
+scipy==1.12.0  # via scikit-learn, sentence-transformers
+sentence-transformers==2.3.1  # via -r requirements.in
+sentencepiece==0.1.99  # via sentence-transformers
+sympy==1.12  # via torch
+threadpoolctl==3.2.0  # via scikit-learn
+tokenizers==0.15.1  # via transformers
+torch==2.2.0  # via clip, sentence-transformers, torchvision
+torchvision==0.17.0  # via clip
+tqdm==4.66.1  # via -r requirements.in, clip, huggingface-hub, nltk, sentence-transformers, transformers
+transformers==4.37.2  # via sentence-transformers
+triton==2.2.0  # via torch
+typing-extensions==4.9.0  # via huggingface-hub, torch
+urllib3==2.2.0  # via requests
+wcwidth==0.2.13  # via ftfy
diff --git a/common/vector_databases/lasr_vector_databases_faiss/setup.py b/common/vector_databases/lasr_vector_databases_faiss/setup.py
new file mode 100644
index 000000000..443aec4a4
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/setup.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+
+from distutils.core import setup
+from catkin_pkg.python_setup import generate_distutils_setup
+
+setup_args = generate_distutils_setup(
+    packages=["lasr_vector_databases_faiss"], package_dir={"": "src"}
+)
+
+setup(**setup_args)
diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py
new file mode 100755
index 000000000..f7dbcbe24
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python3
+import os
+import torch
+import numpy as np
+
+# import rospy
+import faiss  # type: ignore
+from sentence_transformers import SentenceTransformer  # type: ignore
+from typing import Optional
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def load_commands(command_path: str) -> list[str]:
+    """Loads the commands stored in the given txt file
+    into a list of string commands
+    Args:
+        command_path (str): path to the txt file containing
+        the commands -- assumes one command per line.
+    Returns:
+        list[str]: list of string commands where each entry in the
+        list is a command
+    """
+    command_list = []
+    with open(command_path, "r", encoding="utf8") as src:
+        for command in src:
+            # Strip the trailing newline (robust to a missing final newline).
+            command_list.append(command.strip())
+    return command_list
+
+
+def get_sentence_embeddings(
+    command_list: list[str], model: SentenceTransformer
+) -> np.ndarray:
+    """Converts the list of command strings into an array of sentence
+    embeddings (where each command is a sentence and each sentence
+    is converted to a vector)
+    Args:
+        command_list (list[str]): list of string commands, where each
+        entry in the list is assumed to be a separate command
+        model (SentenceTransformer): model used to perform the embedding.
+        Assumes a method called encode that takes a list of strings
+        as input.
+    Returns:
+        np.ndarray: array of shape (n_commands, embedding_dim)
+    """
+
+    return model.encode(
+        command_list,
+        convert_to_numpy=True,
+        show_progress_bar=True,
+        batch_size=256,
+        device=DEVICE,
+    )
+
+
+def create_vector_database(vectors: np.ndarray) -> faiss.IndexFlatIP:
+    """Creates a vector database from an array of vectors of the same dimensionality
+    Args:
+        vectors (np.ndarray): shape (n_vectors, vector_dim)
+
+    Returns:
+        faiss.IndexFlatIP: Flat index containing the vectors
+    """
+    print("Creating vector database")
+    index_flat = faiss.IndexFlatIP(vectors.shape[1])
+    # L2-normalise in place so inner-product search equals cosine similarity.
+    faiss.normalize_L2(vectors)
+    index_flat.add(vectors)
+    print("Finished creating vector database")
+    return index_flat
+
+
+def get_command_database(
+    index_path: str, command_path: Optional[str] = None
+) -> faiss.IndexFlatIP:
+    """Gets a vector database containing a list of embedded commands. Creates the database
+    if the path does not exist, else, loads it into memory.
+
+    Args:
+        index_path (str): Path to an existing faiss Index, or where to save a new one.
+        command_path (str, optional): Path of text file containing commands.
+        Only required if creating a new database. Defaults to None.
+
+    Returns:
+        faiss.IndexFlatIP: faiss Index object containing the embedded commands.
+    """
+
+    if not os.path.exists(f"{index_path}.index"):
+        # rospy.loginfo("Creating new command vector database")
+        assert command_path is not None
+        command_list = load_commands(command_path)
+        model = SentenceTransformer("all-MiniLM-L6-v2")
+        command_embeddings = get_sentence_embeddings(command_list, model)
+        print(command_embeddings.shape)
+        command_database = create_vector_database(command_embeddings)
+        faiss.write_index(command_database, f"{index_path}.index")
+        # rospy.loginfo("Finished creating vector database")
+    else:
+        command_database = faiss.read_index(f"{index_path}.index")
+
+    return command_database
+
+
+def get_similar_commands(
+    command: str,
+    index_path: str,
+    command_path: str,
+    n_similar_commands: int = 100,
+    return_embeddings: bool = False,
+) -> tuple[list[str], list[float]]:
+    """Gets the most similar commands to the given command string
+    Args:
+        command (str): command to compare against the database
+        index_path (str): path to the location to create or retrieve
+        the faiss index containing the embedded commands.
+        command_path (str): path to the txt file containing the commands
+        n_similar_commands (int, optional): number of similar commands to
+        return. Defaults to 100.
+        return_embeddings (bool, optional): whether to additionally return
+        the embeddings of the nearest commands and of the query. Defaults to False.
+    Returns:
+        tuple(list[str], list[float]): the most similar commands and their
+        similarity scores (plus the two embedding arrays if return_embeddings is True)
+    """
+    command_database = get_command_database(index_path, command_path)
+    command_list = load_commands(command_path)
+    model = SentenceTransformer("all-MiniLM-L6-v2")
+    command_embedding = get_sentence_embeddings([command], model)
+    faiss.normalize_L2(command_embedding)
+    command_distances, command_indices = command_database.search(
+        command_embedding, n_similar_commands
+    )
+    nearest_commands = [command_list[i] for i in command_indices[0]]
+
+    if return_embeddings:
+        all_command_embeddings = get_sentence_embeddings(command_list, model)
+        # filter for only nearest commands
+        all_command_embeddings = all_command_embeddings[command_indices[0]]
+        return (
+            nearest_commands,
+            list(command_distances[0]),
+            all_command_embeddings,
+            command_embedding,
+        )
+
+    return nearest_commands, list(command_distances[0])
+
+
+if __name__ == "__main__":
+    """Example usage of using this to find similar commands"""
+    command = "find Jared and asks if he needs help"
+    result, distances, command_embeddings, query_embedding = get_similar_commands(
+        command,
+        "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/qualification/data/command_index",
+        "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/qualification/data/command_list.txt",
+        n_similar_commands=1000,
+        return_embeddings=True,
+    )
+    print(result)
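Note: the __main__ block above uses absolute paths from the author's machine. The same lookup with package-relative paths is sketched below; the file names under data/ are hypothetical, since the package only ships an empty, git-ignored data/ directory.

#!/usr/bin/env python3
# Sketch: resolving a spoken phrase to its closest canonical commands.
import os
import rospkg
from lasr_vector_databases_faiss.command_similarity import get_similar_commands

data_dir = os.path.join(
    rospkg.RosPack().get_path("lasr_vector_databases_faiss"), "data"
)
commands, scores = get_similar_commands(
    "bring me a coke from the kitchen",
    index_path=os.path.join(data_dir, "command_index"),  # hypothetical file name
    command_path=os.path.join(data_dir, "command_list.txt"),  # hypothetical file name
    n_similar_commands=3,
)
# Scores are inner products of L2-normalised vectors, i.e. cosine
# similarities in [-1, 1]; higher means more similar.
for cmd, score in zip(commands, scores):
    print(f"{score:.3f}  {cmd}")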
diff --git a/common/vision/lasr_vision_clip/.gitignore b/common/vision/lasr_vision_clip/.gitignore
new file mode 100644
index 000000000..0d45877c1
--- /dev/null
+++ b/common/vision/lasr_vision_clip/.gitignore
@@ -0,0 +1,2 @@
+data/**
+!data/.gitkeep
\ No newline at end of file
diff --git a/common/vision/lasr_vision_clip/CMakeLists.txt b/common/vision/lasr_vision_clip/CMakeLists.txt
new file mode 100644
index 000000000..a13eb6f2a
--- /dev/null
+++ b/common/vision/lasr_vision_clip/CMakeLists.txt
@@ -0,0 +1,216 @@
+cmake_minimum_required(VERSION 3.0.2)
+project(lasr_vision_clip)
+
+## Compile as C++11, supported in ROS Kinetic and newer
+# add_compile_options(-std=c++11)
+
+## Find catkin macros and libraries
+## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
+## is used, also find other catkin packages
+find_package(catkin REQUIRED catkin_virtualenv)
+
+## System dependencies are found with CMake's conventions
+# find_package(Boost REQUIRED COMPONENTS system)
+
+
+## Uncomment this if the package has a setup.py. This macro ensures
+## modules and global scripts declared therein get installed
+## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html
+catkin_python_setup()
+catkin_generate_virtualenv(
+    INPUT_REQUIREMENTS requirements.in
+    PYTHON_INTERPRETER python3.9
+)
+################################################
+## Declare ROS messages, services and actions ##
+################################################
+
+## To declare and build messages, services or actions from within this
+## package, follow these steps:
+## * Let MSG_DEP_SET be the set of packages whose message types you use in
+##   your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...).
+## * In the file package.xml:
+##   * add a build_depend tag for "message_generation"
+##   * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET
+##   * If MSG_DEP_SET isn't empty the following dependency has been pulled in
+##     but can be declared for certainty nonetheless:
+##     * add a exec_depend tag for "message_runtime"
+## * In this file (CMakeLists.txt):
+##   * add "message_generation" and every package in MSG_DEP_SET to
+##     find_package(catkin REQUIRED COMPONENTS ...)
+##   * add "message_runtime" and every package in MSG_DEP_SET to
+##     catkin_package(CATKIN_DEPENDS ...)
+##   * uncomment the add_*_files sections below as needed
+##     and list every .msg/.srv/.action file to be processed
+##   * uncomment the generate_messages entry below
+##   * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...)
+
+## Generate messages in the 'msg' folder
+# add_message_files(
+#   FILES
+#   Message1.msg
+#   Message2.msg
+# )
+
+## Generate services in the 'srv' folder
+# add_service_files(
+#   FILES
+#   Service1.srv
+#   Service2.srv
+# )
+
+# Generate actions in the 'action' folder
+# add_action_files(
+#   DIRECTORY action
+#   FILES WaitGreet.action Identify.action Greet.action GetName.action LearnFace.action GetCommand.action Guide.action DetectPeople.action FindPerson.action ReceiveObject.action HandoverObject.action
+# )
+
+# Generate added messages and services with any dependencies listed here
+# generate_messages(
+#   DEPENDENCIES
+#   actionlib_msgs
+#   geometry_msgs
+# )
+
+################################################
+## Declare ROS dynamic reconfigure parameters ##
+################################################
+
+## To declare and build dynamic reconfigure parameters within this
+## package, follow these steps:
+## * In the file package.xml:
+##   * add a build_depend and a exec_depend tag for "dynamic_reconfigure"
+## * In this file (CMakeLists.txt):
+##   * add "dynamic_reconfigure" to
+##     find_package(catkin REQUIRED COMPONENTS ...)
+##   * uncomment the "generate_dynamic_reconfigure_options" section below
+##     and list every .cfg file to be processed
+
+## Generate dynamic reconfigure parameters in the 'cfg' folder
+# generate_dynamic_reconfigure_options(
+#   cfg/DynReconf1.cfg
+#   cfg/DynReconf2.cfg
+# )
+
+###################################
+## catkin specific configuration ##
+###################################
+## The catkin_package macro generates cmake config files for your package
+## Declare things to be passed to dependent projects
+## INCLUDE_DIRS: uncomment this if your package contains header files
+## LIBRARIES: libraries you create in this project that dependent projects also need
+## CATKIN_DEPENDS: catkin_packages dependent projects also need
+## DEPENDS: system dependencies of this project that dependent projects also need
+catkin_package(
+#  INCLUDE_DIRS include
+#  LIBRARIES qualification
+#  DEPENDS system_lib
+)
+
+###########
+## Build ##
+###########
+
+## Specify additional locations of header files
+## Your package locations should be listed before other locations
+include_directories(
+# include
+  ${catkin_INCLUDE_DIRS}
+)
+
+## Declare a C++ library
+# add_library(${PROJECT_NAME}
+#   src/${PROJECT_NAME}/qualification.cpp
+# )
+
+## Add cmake target dependencies of the library
+## as an example, code may need to be generated before libraries
+## either from message generation or dynamic reconfigure
+# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Declare a C++ executable
+## With catkin_make all packages are built within a single CMake context
+## The recommended prefix ensures that target names across packages don't collide
+# add_executable(${PROJECT_NAME}_node src/qualification_node.cpp)
+
+## Rename C++ executable without prefix
+## The above recommended prefix causes long target names, the following renames the
+## target back to the shorter version for ease of user use
+## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node"
+# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "")
+
+## Add cmake target dependencies of the executable
+## same as for the library above
+# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Specify libraries to link a library or executable target against
+# target_link_libraries(${PROJECT_NAME}_node
+#   ${catkin_LIBRARIES}
+# )
+
+#############
+## Install ##
+#############
+
+# all install targets should use catkin DESTINATION variables
+# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html
+
+## Mark executable scripts (Python etc.) for installation
+## in contrast to setup.py, you can choose the destination
+# catkin_install_python(PROGRAMS
+#   nodes/qualification
+#   nodes/actions/wait_greet
+#   nodes/actions/identify
+#   nodes/actions/greet
+#   nodes/actions/get_name
+#   nodes/actions/learn_face
+#   nodes/actions/get_command
+#   nodes/actions/guide
+#   nodes/actions/find_person
+#   nodes/actions/detect_people
+#   nodes/actions/receive_object
+#   nodes/actions/handover_object
+#   nodes/better_qualification
+#   DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark executables for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html
+# install(TARGETS ${PROJECT_NAME}_node
+#   RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark libraries for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html
+# install(TARGETS ${PROJECT_NAME}
+#   ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+#   LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+#   RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION}
+# )
+
+## Mark cpp header files for installation
+# install(DIRECTORY include/${PROJECT_NAME}/
+#   DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}
+#   FILES_MATCHING PATTERN "*.h"
+#   PATTERN ".svn" EXCLUDE
+# )
+
+## Mark other files for installation (e.g. launch and bag files, etc.)
+# install(FILES
+#   # myfile1
+#   # myfile2
+#   DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+# )
+
+#############
+## Testing ##
+#############
+
+## Add gtest based cpp test target and link libraries
+# catkin_add_gtest(${PROJECT_NAME}-test test/test_qualification.cpp)
+# if(TARGET ${PROJECT_NAME}-test)
+#   target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME})
+# endif()
+
+## Add folders to be run by python nosetests
+# catkin_add_nosetests(test)
diff --git a/common/vision/lasr_vision_clip/data/.gitkeep b/common/vision/lasr_vision_clip/data/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/vision/lasr_vision_clip/package.xml b/common/vision/lasr_vision_clip/package.xml
new file mode 100644
index 000000000..e64ef13ef
--- /dev/null
+++ b/common/vision/lasr_vision_clip/package.xml
@@ -0,0 +1,16 @@
+<?xml version="1.0"?>
+<package format="2">
+  <name>lasr_vision_clip</name>
+  <version>0.0.0</version>
+  <description>The lasr_vision_clip package</description>
+
+  <maintainer>Matt Barker</maintainer>
+  <license>MIT</license>
+
+  <buildtool_depend>catkin</buildtool_depend>
+  <build_depend>catkin_virtualenv</build_depend>
+
+  <export>
+    <pip_requirements>requirements.txt</pip_requirements>
+  </export>
+</package>
diff --git a/common/vision/lasr_vision_clip/requirements.in b/common/vision/lasr_vision_clip/requirements.in
new file mode 100644
index 000000000..b03f1e5cc
--- /dev/null
+++ b/common/vision/lasr_vision_clip/requirements.in
@@ -0,0 +1,2 @@
+sentence-transformers
+opencv-python
\ No newline at end of file
diff --git a/common/vision/lasr_vision_clip/requirements.txt b/common/vision/lasr_vision_clip/requirements.txt
new file mode 100644
index 000000000..6d7e59d07
--- /dev/null
+++ b/common/vision/lasr_vision_clip/requirements.txt
@@ -0,0 +1,53 @@
+--extra-index-url https://pypi.ngc.nvidia.com
+--trusted-host pypi.ngc.nvidia.com
+
+certifi==2024.2.2  # via requests
+charset-normalizer==3.3.2  # via requests
+click==8.1.7  # via nltk
+clip @ git+https://github.com/openai/CLIP.git  # via -r requirements.in
+filelock==3.13.1  # via huggingface-hub, torch, transformers, triton
+fsspec==2024.2.0  # via huggingface-hub, torch
+ftfy==6.1.3  # via -r requirements.in, clip
+huggingface-hub==0.20.3  # via sentence-transformers, tokenizers, transformers
+idna==3.6  # via requests
+jinja2==3.1.3  # via torch
+joblib==1.3.2  # via nltk, scikit-learn
+markupsafe==2.1.5  # via jinja2
+mpmath==1.3.0  # via sympy
+networkx==3.2.1  # via torch
+nltk==3.8.1  # via sentence-transformers
+numpy==1.26.3  # via opencv-python, scikit-learn, scipy, sentence-transformers, torchvision, transformers
+nvidia-cublas-cu12==12.1.3.1  # via nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch
+nvidia-cuda-cupti-cu12==12.1.105  # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105  # via torch
+nvidia-cuda-runtime-cu12==12.1.105  # via torch
+nvidia-cudnn-cu12==8.9.2.26  # via torch
+nvidia-cufft-cu12==11.0.2.54  # via torch
+nvidia-curand-cu12==10.3.2.106  # via torch
+nvidia-cusolver-cu12==11.4.5.107  # via torch
+nvidia-cusparse-cu12==12.1.0.106  # via nvidia-cusolver-cu12, torch
+nvidia-nccl-cu12==2.19.3  # via torch
+nvidia-nvjitlink-cu12==12.3.101  # via nvidia-cusolver-cu12, nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105  # via torch
+opencv-python==4.9.0.80  # via -r requirements.in
+packaging==23.2  # via huggingface-hub, transformers
+pillow==10.2.0  # via sentence-transformers, torchvision
+pyyaml==6.0.1  # via huggingface-hub, transformers
+regex==2023.12.25  # via -r requirements.in, clip, nltk, transformers
+requests==2.31.0  # via huggingface-hub, torchvision, transformers
+safetensors==0.4.2  # via transformers
+scikit-learn==1.4.0  # via sentence-transformers
+scipy==1.12.0  # via scikit-learn, sentence-transformers
+sentence-transformers==2.3.1  # via -r requirements.in
+sentencepiece==0.1.99  # via sentence-transformers
+sympy==1.12  # via torch
+threadpoolctl==3.2.0  # via scikit-learn
+tokenizers==0.15.1  # via transformers
+torch==2.2.0  # via clip, sentence-transformers, torchvision
+torchvision==0.17.0  # via clip
+tqdm==4.66.1  # via -r requirements.in, clip, huggingface-hub, nltk, sentence-transformers, transformers
+transformers==4.37.2  # via sentence-transformers
+triton==2.2.0  # via torch
+typing-extensions==4.9.0  # via huggingface-hub, torch
+urllib3==2.2.0  # via requests
+wcwidth==0.2.13  # via ftfy
diff --git a/common/vision/lasr_vision_clip/setup.py b/common/vision/lasr_vision_clip/setup.py
new file mode 100644
index 000000000..d02820e3c
--- /dev/null
+++ b/common/vision/lasr_vision_clip/setup.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+
+from distutils.core import setup
+from catkin_pkg.python_setup import generate_distutils_setup
+
+setup_args = generate_distutils_setup(
+    packages=["lasr_vision_clip"], package_dir={"": "src"}
+)
+
+setup(**setup_args)
diff --git a/common/vision/lasr_vision_clip/src/lasr_vision_clip/__init__.py b/common/vision/lasr_vision_clip/src/lasr_vision_clip/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/vision/lasr_vision_clip/src/lasr_vision_clip/clip_utils.py b/common/vision/lasr_vision_clip/src/lasr_vision_clip/clip_utils.py
new file mode 100644
index 000000000..81d6dd25a
--- /dev/null
+++ b/common/vision/lasr_vision_clip/src/lasr_vision_clip/clip_utils.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+import torch
+import rospy
+import cv2
+import cv2_img
+import numpy as np
+from copy import deepcopy
+from sentence_transformers import SentenceTransformer, util
+
+from sensor_msgs.msg import Image
+
+
+def load_model(device: str = "cuda"):
+    """Load the CLIP model.
+
+    Args:
+        device (str, optional): the device to use. Defaults to "cuda".
+
+    Returns:
+        SentenceTransformer: the loaded CLIP model
+    """
+    model = SentenceTransformer("clip-ViT-B-32", device=device)
+    return model
+
+
+def run_clip(
+    model: SentenceTransformer, labels: list[str], img: np.ndarray
+) -> torch.Tensor:
+    """Run the CLIP model.
+
+    Args:
+        model (SentenceTransformer): clip model loaded into memory
+        labels (list[str]): list of string labels to query image similarity to.
+        img (np.ndarray): the image to query
+
+    Returns:
+        torch.Tensor: the cosine similarity scores between the image and label embeddings.
+    """
+    txt_emb = model.encode(labels)
+    img_emb = model.encode(img)
+    with torch.no_grad():
+        cos_scores = util.cos_sim(img_emb, txt_emb)
+    return cos_scores
+
+
+def query_image_stream(
+    model: SentenceTransformer,
+    answers: list[str],
+    annotate: bool = False,
+) -> tuple[str, torch.Tensor, Image]:
+    """Queries the CLIP model with the latest image from the robot's camera
+    and a set of possible image captions and returns the most likely caption.
+
+    Args:
+        model (SentenceTransformer): clip model to run inference on, loaded into memory
+        answers (list[str]): list of possible answers
+        annotate (bool, optional): whether to annotate the image with the most likely,
+        and second most likely, caption. Defaults to False.
+    Returns:
+        tuple(str, torch.Tensor, Image): the most likely answer, the scores, and the annotated image msg
+    """
+    img_msg = rospy.wait_for_message("/xtion/rgb/image_raw", Image)
+    img_pil = cv2_img.msg_to_pillow_img(img_msg)
+
+    cos_scores = run_clip(model, answers, img_pil)
+    max_score = cos_scores.argmax()
+    # get second highest score in tensor
+    max_val = deepcopy(cos_scores[0, max_score])
+    cos_scores[0, max_score] = 0
+    second_max_score = cos_scores.argmax()
+    cos_scores[0, max_score] = max_val
+    # Annotate the image
+
+    cv2_im = cv2_img.msg_to_cv2_img(img_msg)
+    if annotate:
+        cv2.putText(
+            cv2_im,
+            f"Most likely caption: {answers[max_score]}",
+            (10, 30),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            (0, 255, 0),
+            2,
+            cv2.LINE_AA,
+        )
+        # add second score below
+        cv2.putText(
+            cv2_im,
+            f"Second most likely caption: {answers[second_max_score]}",
+            (10, 60),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            (0, 255, 0),
+            2,
+            cv2.LINE_AA,
+        )
+
+    img = cv2_img.cv2_img_to_msg(cv2_im)
+    return answers[max_score], cos_scores, img
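Note: a sketch of how these helpers compose in a node (assumes a CUDA-capable machine and the /xtion/rgb/image_raw camera topic used above; the node name and debug topic are illustrative, not part of this package):

#!/usr/bin/env python3
# Sketch: zero-shot caption selection on the latest camera frame.
import rospy
from sensor_msgs.msg import Image
from lasr_vision_clip.clip_utils import load_model, query_image_stream

rospy.init_node("clip_query_example")  # illustrative node name
model = load_model(device="cuda")  # use "cpu" on machines without a GPU

captions = [
    "a person waving at the robot",
    "an empty corridor",
    "a table with drinks on it",
]
debug_pub = rospy.Publisher("/clip/debug", Image, queue_size=1)  # illustrative topic

best, scores, annotated = query_image_stream(model, captions, annotate=True)
debug_pub.publish(annotated)
rospy.loginfo(f"Best caption: {best} (scores: {scores.tolist()})")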
diff --git a/common/vision/lasr_vision_deepface/examples/greet b/common/vision/lasr_vision_deepface/examples/greet
index 63778c6af..70a87a038 100644
--- a/common/vision/lasr_vision_deepface/examples/greet
+++ b/common/vision/lasr_vision_deepface/examples/greet
@@ -7,51 +7,63 @@ from copy import deepcopy
 
 from sensor_msgs.msg import Image
 from lasr_vision_msgs.srv import Recognise, RecogniseRequest
+from lasr_voice import Voice
 
 if len(sys.argv) < 3:
-    print('Usage: rosrun lase_recognition greet <listen_topic> <dataset>')
+    print("Usage: rosrun lasr_vision_deepface greet <listen_topic> <dataset>")
     exit()
 
 listen_topic = sys.argv[1]
 dataset = sys.argv[2]
 people_in_frame = []
-last_received_time = None
+
+
+people_in_frame = {}
 
 
 def detect(image):
     rospy.loginfo("Received image message")
     global people_in_frame
-    people_in_frame = []
     try:
-        detect_service = rospy.ServiceProxy('/recognise', Recognise)
+        detect_service = rospy.ServiceProxy("/recognise", Recognise)
         req = RecogniseRequest()
         req.image_raw = image
        req.dataset = dataset
-        req.confidence = 0.5
+        req.confidence = 0.4
         resp = detect_service(req)
         for detection in resp.detections:
-            people_in_frame.append(detection.name)
-        print(resp)
+            people_in_frame[detection.name] = rospy.Time.now()
    except rospy.ServiceException as e:
        rospy.logerr("Service call failed: %s" % e)
 
 
 def greet():
-    print(f"Hello, {' '.join(people_in_frame)}")
+    voice = Voice()
+    voice.speak(f"Hello, {' '.join(people_in_frame)}")
+
 
 def image_callback(image):
-    global last_received_time
-    if last_received_time is None or rospy.Time.now() - last_received_time >= rospy.Duration(5.0):
-        prev_people_in_frame = deepcopy(people_in_frame)
-        detect(image)
-        if people_in_frame != prev_people_in_frame:
-            greet()
-        last_received_time = rospy.Time.now()
+    global people_in_frame
+    prev_people_in_frame = list(people_in_frame.keys())
+    # remove detections from people_in_frame that are older than 10 seconds
+    detect(image)
+    for person in list(people_in_frame.keys()):
+        if rospy.Time.now() - people_in_frame[person] > rospy.Duration(10):
+            del people_in_frame[person]
+    if (
+        list(people_in_frame.keys()) != prev_people_in_frame
+        and len(people_in_frame) > 0
+    ) or (len(prev_people_in_frame) == 0 and len(people_in_frame) > 0):
+        greet()
+
 
 def listener():
-    rospy.init_node('image_listener', anonymous=True)
-    rospy.wait_for_service('/recognise')
-    rospy.Subscriber(listen_topic, Image, image_callback)
+    rospy.init_node("image_listener", anonymous=True)
+    rospy.wait_for_service("/recognise")
+    rospy.Subscriber(listen_topic, Image, image_callback, queue_size=1)
     rospy.spin()
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     listener()
diff --git a/common/vision/lasr_vision_deepface/examples/relay b/common/vision/lasr_vision_deepface/examples/relay
index af1c9bd4d..5c35ee887 100644
--- a/common/vision/lasr_vision_deepface/examples/relay
+++ b/common/vision/lasr_vision_deepface/examples/relay
@@ -28,7 +28,7 @@ def detect(image):
         req = RecogniseRequest()
         req.image_raw = image
         req.dataset = dataset
-        req.confidence = 0.7
+        req.confidence = 0.4
         resp = detect_service(req)
         print(resp)
     except rospy.ServiceException as e:
diff --git a/common/vision/lasr_vision_deepface/launch/example.launch b/common/vision/lasr_vision_deepface/launch/example.launch
new file mode 100644
index 000000000..5d3281867
--- /dev/null
+++ b/common/vision/lasr_vision_deepface/launch/example.launch
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/common/vision/lasr_vision_deepface/nodes/service b/common/vision/lasr_vision_deepface/nodes/service
index 12149844c..d6a71384a 100644
--- a/common/vision/lasr_vision_deepface/nodes/service
+++ b/common/vision/lasr_vision_deepface/nodes/service
@@ -4,31 +4,86 @@ import re
 import rospy
 import lasr_vision_deepface as face_recognition
 from sensor_msgs.msg import Image
-from lasr_vision_msgs.srv import Recognise, RecogniseRequest, RecogniseResponse
+from lasr_vision_msgs.srv import (
+    Recognise,
+    RecogniseRequest,
+    RecogniseResponse,
+    LearnFace,
+    LearnFaceRequest,
+    LearnFaceResponse,
+)
+
 
 rospy.init_node("recognise_service")
 
 # Determine variables
 DEBUG = rospy.get_param("~debug", False)
-debug_publishers = {}
+recognise_debug_publishers = {}
+learn_face_debug_publishers = {}
 if DEBUG:
-    debug_publisher = rospy.Publisher("/recognise/debug", Image, queue_size=1)
+    recognise_debug_publisher = rospy.Publisher("/recognise/debug", Image, queue_size=1)
+    learn_face_debug_publisher = rospy.Publisher(
+        "/learn_face/debug", Image, queue_size=1
+    )
+    cropped_face_publisher = rospy.Publisher(
+        "/learn_face/debug/cropped_query_face", Image, queue_size=1
+    )
 
 
 def detect(request: RecogniseRequest) -> RecogniseResponse:
     debug_publisher = None
+    similar_face_debug_publisher = None
+    cropped_face_publisher = None
     if DEBUG:
-        if request.dataset in debug_publishers:
-            debug_publisher = debug_publishers[request.dataset]
+        if request.dataset in recognise_debug_publishers:
+            debug_publisher = recognise_debug_publishers[request.dataset][0]
+            similar_face_debug_publisher = recognise_debug_publishers[request.dataset][
+                1
+            ]
+            cropped_face_publisher = recognise_debug_publishers[request.dataset][2]
         else:
             topic_name = re.sub(r"[\W_]+", "", request.dataset)
             debug_publisher = rospy.Publisher(
                 f"/recognise/debug/{topic_name}", Image, queue_size=1
             )
-    return face_recognition.detect(request, debug_publisher)
+            similar_face_debug_publisher = rospy.Publisher(
+                f"/recognise/debug/{topic_name}/similar_face", Image, queue_size=1
+            )
+            cropped_face_publisher = rospy.Publisher(
+                "/recognise/debug/cropped_query_face", Image, queue_size=1
+            )
+            recognise_debug_publishers[request.dataset] = (
+                debug_publisher,
+                similar_face_debug_publisher,
+                cropped_face_publisher,
+            )
+    return face_recognition.detect(
+        request, debug_publisher, similar_face_debug_publisher, cropped_face_publisher
+    )
+
+
+def learn_face(request: LearnFaceRequest) -> LearnFaceResponse:
+    debug_publisher = None
+    if DEBUG:
+        if request.dataset in learn_face_debug_publishers:
+            debug_publisher = learn_face_debug_publishers[request.dataset]
+        else:
+            topic_name = re.sub(r"[\W_]+", "", request.dataset)
+            debug_publisher = rospy.Publisher(
+                f"/learn_face/debug/{topic_name}", Image, queue_size=1
+            )
+            learn_face_debug_publishers[request.dataset] = debug_publisher
+    face_recognition.create_dataset(
+        "/xtion/rgb/image_raw",
+        request.dataset,
+        request.name,
+        request.n_images,
+        debug_publisher,
+    )
+    return LearnFaceResponse()
 
 
 rospy.Service("/recognise", Recognise, detect)
-rospy.loginfo("Face Recognition service starter")
+rospy.Service("/learn_face", LearnFace, learn_face)
+rospy.loginfo("Face Recognition service started")
 rospy.spin()
diff --git a/common/vision/lasr_vision_deepface/scripts/create_dataset b/common/vision/lasr_vision_deepface/scripts/create_dataset
index 9796b5665..85f8e6b3c 100644
--- a/common/vision/lasr_vision_deepface/scripts/create_dataset
+++ b/common/vision/lasr_vision_deepface/scripts/create_dataset
@@ -19,27 +19,7 @@ else:
     size = 50
 
 import rospy
-import rospkg
-from sensor_msgs.msg import Image
-import os
-import cv2_img
-import cv2
-
-DATASET_ROOT = os.path.join(
-    rospkg.RosPack().get_path("lasr_vision_deepface"), "datasets"
-)
-DATASET_PATH = os.path.join(DATASET_ROOT, dataset, name)
-if not os.path.exists(DATASET_PATH):
-    os.makedirs(DATASET_PATH)
 
 rospy.init_node("create_dataset")
-rospy.loginfo(f"Taking {size} pictures of {name} and saving to {DATASET_PATH}")
-for i in range(size):
-    img_msg = rospy.wait_for_message(topic, Image)
-    cv_im = cv2_img.msg_to_cv2_img(img_msg)
-    face_cropped_cv_im = face_recognition.detect_face(cv_im)
-    if face_cropped_cv_im is None:
-        continue
-    cv2.imwrite(os.path.join(DATASET_PATH, f"{name}_{i+1}.png"), face_cropped_cv_im)
-    rospy.loginfo(f"Took picture {i+1}")
+face_recognition.create_dataset(topic, dataset, name, size)
diff --git a/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/__init__.py b/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/__init__.py
index 5a5bb56d8..e69de29bb 100644
--- a/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/__init__.py
+++ b/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/__init__.py
@@ -1 +1 @@
-from .deepface import detect, detect_face
+from .deepface import create_dataset, detect, detect_face
diff --git a/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/deepface.py b/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/deepface.py
index af8895dad..ce9c4771b 100644
--- a/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/deepface.py
+++ b/common/vision/lasr_vision_deepface/src/lasr_vision_deepface/deepface.py
@@ -5,14 +5,19 @@ import rospkg
 import rospy
 import os
+import numpy as np
+import pandas as pd
 
 from lasr_vision_msgs.msg import Detection
 from lasr_vision_msgs.srv import RecogniseRequest, RecogniseResponse
+from sensor_msgs.msg import Image
+
 
 DATASET_ROOT = os.path.join(
     rospkg.RosPack().get_path("lasr_vision_deepface"), "datasets"
 )
+
 Mat = int  # np.typing.NDArray[np.uint8]
@@ -28,11 +33,90 @@ def detect_face(cv_im: Mat) -> Mat | None:
         return None
     facial_area = faces[0]["facial_area"]
     x, y, w, h = facial_area["x"], facial_area["y"], facial_area["w"], facial_area["h"]
+
+    # add padding to the face, clamped so negative indices cannot wrap around
+    x = max(x - 10, 0)
+    y = max(y - 10, 0)
+    w += 20
+    h += 20
+
     return cv_im[:][y : y + h, x : x + w]
 
 
+def create_image_collage(images, output_size=(640, 480)):
+    """Tiles a list of images into a single grid image of the given size."""
+
+    # Calculate grid dimensions
+    num_images = len(images)
+    rows = int(np.sqrt(num_images))
+    cols = (num_images + rows - 1) // rows  # Ceiling division
+
+    # Resize images to fit in the grid
+    resized_images = [
+        cv2.resize(img, (output_size[0] // cols, output_size[1] // rows))
+        for img in images
+    ]
+
+    # Create the final image grid
+    grid_image = np.zeros((output_size[1], output_size[0], 3), dtype=np.uint8)
+
+    # Populate the grid with resized images
+    for i in range(rows):
+        for j in range(cols):
+            idx = i * cols + j
+            if idx < num_images:
+                y_start = i * (output_size[1] // rows)
+                y_end = (i + 1) * (output_size[1] // rows)
+                x_start = j * (output_size[0] // cols)
+                x_end = (j + 1) * (output_size[0] // cols)
+
+                grid_image[y_start:y_end, x_start:x_end] = resized_images[idx]
+
+    return grid_image
+
+
+def create_dataset(
+    topic: str,
+    dataset: str,
+    name: str,
+    size: int,
+    debug_publisher: rospy.Publisher | None = None,
+) -> None:
+    dataset_path = os.path.join(DATASET_ROOT, dataset, name)
+    if not os.path.exists(dataset_path):
+        os.makedirs(dataset_path)
+    rospy.loginfo(f"Taking {size} pictures of {name} and saving to {dataset_path}")
+
+    images = []
+    for i in range(size):
+        img_msg = rospy.wait_for_message(topic, Image)
+        cv_im = cv2_img.msg_to_cv2_img(img_msg)
+        face_cropped_cv_im = detect_face(cv_im)
+        if face_cropped_cv_im is None:
+            continue
+        cv2.imwrite(os.path.join(dataset_path, f"{name}_{i+1}.png"), face_cropped_cv_im)  # type: ignore
+        rospy.loginfo(f"Took picture {i+1}")
+        images.append(face_cropped_cv_im)
+        if debug_publisher is not None:
+            debug_publisher.publish(
+                cv2_img.cv2_img_to_msg(create_image_collage(images))
+            )
+
+    # Force retraining
+    DeepFace.find(
+        cv_im,
+        os.path.join(DATASET_ROOT, dataset),
+        enforce_detection=False,
+        silent=True,
+        detector_backend="mtcnn",
+    )
+
+
 def detect(
-    request: RecogniseRequest, debug_publisher: rospy.Publisher | None
+    request: RecogniseRequest,
+    debug_publisher: rospy.Publisher | None,
+    debug_inference_pub: rospy.Publisher | None,
+    cropped_detect_pub: rospy.Publisher | None,
 ) -> RecogniseResponse:
     # Decode the image
     rospy.loginfo("Decoding")
@@ -49,7 +133,6 @@ def detect(
             enforce_detection=True,
             silent=True,
             detector_backend="mtcnn",
-            threshold=request.confidence,
         )
     except ValueError:
         return response
@@ -66,14 +149,19 @@ def detect(
             row["source_h"][0],
         )
         detection.xywh = [x, y, w, h]
-        detection.confidence = 1.0 - row["distance"][0]
+        detection.confidence = row["distance"][0]
         response.detections.append(detection)
 
+        cropped_image = cv_im[:][y : y + h, x : x + w]
+
+        if cropped_detect_pub is not None:
+            cropped_detect_pub.publish(cv2_img.cv2_img_to_msg(cropped_image))
+
         # Draw bounding boxes and labels for debugging
         cv2.rectangle(cv_im, (x, y), (x + w, y + h), (0, 0, 255), 2)
         cv2.putText(
             cv_im,
-            f"{detection.name} ({detection.confidence})",
+            f"{detection.name} Distance: ({detection.confidence:.2f})",
             (x, y - 5),
             cv2.FONT_HERSHEY_SIMPLEX,
             0.5,
@@ -84,5 +172,16 @@ def detect(
     # publish to debug topic
     if debug_publisher is not None:
         debug_publisher.publish(cv2_img.cv2_img_to_msg(cv_im))
+    if debug_inference_pub is not None:
+        result = pd.concat(result)
+        # check for empty result
+        if not result.empty:
+            result_paths = list(result["identity"])
+            if len(result_paths) > 5:
+                result_paths = result_paths[:5]
+            result_images = [cv2.imread(path) for path in result_paths]
+            debug_inference_pub.publish(
+                cv2_img.cv2_img_to_msg(create_image_collage(result_images))
+            )
 
     return response
diff --git a/common/vision/lasr_vision_msgs/CMakeLists.txt b/common/vision/lasr_vision_msgs/CMakeLists.txt
index 0f781a520..f79567dc5 100644
--- a/common/vision/lasr_vision_msgs/CMakeLists.txt
+++ b/common/vision/lasr_vision_msgs/CMakeLists.txt
@@ -60,14 +60,10 @@ add_service_files(
   BodyPixDetection.srv
   TorchFaceFeatureDetection.srv
   Recognise.srv
+  LearnFace.srv
 )
 
-## Generate actions in the 'action' folder
-# add_action_files(
-#   FILES
-#   Action1.action
-#   Action2.action
-# )
+# Generate actions in the 'action' folder
 
 ## Generate added messages and services with any dependencies listed here
 generate_messages(
diff --git a/common/vision/lasr_vision_msgs/srv/LearnFace.srv b/common/vision/lasr_vision_msgs/srv/LearnFace.srv
new file mode 100644
index 000000000..5376e53d5
--- /dev/null
+++ b/common/vision/lasr_vision_msgs/srv/LearnFace.srv
@@ -0,0 +1,10 @@
+# Name to associate
+string name
+
+# Dataset to add face to
+string dataset
+
+# Number of images to take
+int32 n_images
+
+---
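Note: with /learn_face registered alongside /recognise, a caller can enrol a face and then query it on a fresh frame. A minimal sketch follows (the name and dataset values are placeholders; both services must already be running, and the server grabs its own frames from /xtion/rgb/image_raw during enrolment):

#!/usr/bin/env python3
# Sketch: enrol a face via /learn_face, then recognise it.
import rospy
from sensor_msgs.msg import Image
from lasr_vision_msgs.srv import LearnFace, LearnFaceRequest, Recognise, RecogniseRequest

rospy.init_node("face_recognition_example")
rospy.wait_for_service("/learn_face")
rospy.wait_for_service("/recognise")

learn = rospy.ServiceProxy("/learn_face", LearnFace)
recognise = rospy.ServiceProxy("/recognise", Recognise)

# Enrol: the service captures n_images frames and retrains the dataset.
learn(LearnFaceRequest(name="jared", dataset="example_dataset", n_images=10))

# Recognise on the next frame. Note that the confidence field now carries
# a DeepFace distance, so smaller values mean a closer match.
img = rospy.wait_for_message("/xtion/rgb/image_raw", Image)
resp = recognise(
    RecogniseRequest(image_raw=img, dataset="example_dataset", confidence=0.4)
)
for d in resp.detections:
    rospy.loginfo(f"{d.name}: distance {d.confidence:.2f} at {d.xywh}")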