diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server index ed1c01d25..afa55f215 100644 --- a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server +++ b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server @@ -1,6 +1,6 @@ #!/usr/bin/env python3 - import os +import sounddevice # needed to remove ALSA error messages import argparse from typing import Optional from dataclasses import dataclass @@ -16,29 +16,6 @@ import lasr_speech_recognition_msgs.msg # type: ignore from std_msgs.msg import String # type: ignore from lasr_speech_recognition_whisper import load_model # type: ignore -# Error handler to remove ALSA error messages taken from: -# https://stackoverflow.com/questions/7088672/pyaudio-working-but-spits-out-error-messages-each-time/17673011#17673011 - -from ctypes import * -from contextlib import contextmanager - -ERROR_HANDLER_FUNC = CFUNCTYPE(None, c_char_p, c_int, c_char_p, c_int, c_char_p) - - -def py_error_handler(filename, line, function, err, fmt): - pass - - -c_error_handler = ERROR_HANDLER_FUNC(py_error_handler) - - -@contextmanager -def noalsaerr(): - asound = cdll.LoadLibrary("libasound.so") - asound.snd_lib_error_set_handler(c_error_handler) - yield - asound.snd_lib_error_set_handler(None) - @dataclass class speech_model_params: @@ -87,40 +64,25 @@ class TranscribeSpeechAction(object): self._transcription_server = rospy.Publisher( "/live_speech_transcription", String, queue_size=10 ) - with noalsaerr(): - self._model = load_model( - self._model_params.model_name, - self._model_params.device, - self._model_params.warmup, - ) - # Configure the speech recogniser object and adjust for ambient noise - self.recogniser = self._configure_recogniser(ambient_adj=True) - # Setup the action server and register execution callback - self._action_server = actionlib.SimpleActionServer( - self._action_name, - lasr_speech_recognition_msgs.msg.TranscribeSpeechAction, - execute_cb=self.execute_cb, - auto_start=False, - ) - self._action_server.register_preempt_callback(self.prempt_cb) - # Setup the timer for adjusting the microphone for ambient noise every x seconds - self._timer_duration = self._model_params.timer_duration - self._timer = rospy.Timer( - rospy.Duration(self._timer_duration), self._timer_cb - ) - self._listening = False - self._action_server.start() + self._model = load_model( + self._model_params.model_name, + self._model_params.device, + self._model_params.warmup, + ) + # Configure the speech recogniser object and adjust for ambient noise + self.recogniser = self._configure_recogniser(ambient_adj=True) + # Setup the action server and register execution callback + self._action_server = actionlib.SimpleActionServer( + self._action_name, + lasr_speech_recognition_msgs.msg.TranscribeSpeechAction, + execute_cb=self.execute_cb, + auto_start=False, + ) + self._action_server.register_preempt_callback(self.prempt_cb) + self._listening = False - def _timer_cb(self, _) -> None: - return - """Adjusts the microphone for ambient noise, unless the action server is listening.""" - if self._listening: - return - rospy.loginfo("Adjusting microphone for ambient noise...") - with noalsaerr(): - with self._configure_microphone() as source: - self.recogniser.adjust_for_ambient_noise(source) + self._action_server.start() def _reset_timer(self) -> None: """Resets the timer for adjusting the microphone for ambient noise.""" @@ -194,17 +156,13 @@ class TranscribeSpeechAction(object): rospy.loginfo("Request Received") if self._action_server.is_preempt_requested(): return - # Since we are about to listen, reset the timer for adjusting the microphone for ambient noise - # as this assumes self_timer_duration seconds of silence before adjusting - self._reset_timer() - with noalsaerr(): - with self._configure_microphone() as src: - self._listening = True - wav_data = self.recogniser.listen( - src, - timeout=self._model_params.start_timeout, - phrase_time_limit=self._model_params.end_timeout, - ).get_wav_data() + with self._configure_microphone() as src: + self._listening = True + wav_data = self.recogniser.listen( + src, + timeout=self._model_params.start_timeout, + phrase_time_limit=self._model_params.end_timeout, + ).get_wav_data() # Magic number 32768.0 is the maximum value of a 16-bit signed integer float_data = ( np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") @@ -293,12 +251,6 @@ def parse_args() -> dict: default=None, help="Microphone device index or name", ) - parser.add_argument( - "--timer_duration", - type=int, - default=20, - help="Number of seconds of silence before the ambient noise adjustment is called.", - ) parser.add_argument( "--no_warmup", action="store_true", @@ -331,8 +283,6 @@ def configure_model_params(config: dict) -> speech_model_params: model_params.sample_rate = config["sample_rate"] if config["mic_device"]: model_params.mic_device = config["mic_device"] - if config["timer_duration"]: - model_params.timer_duration = config["timer_duration"] if config["no_warmup"]: model_params.warmup = False diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.in b/common/speech/lasr_speech_recognition_whisper/requirements.in index 8209d34fc..da48c5086 100644 --- a/common/speech/lasr_speech_recognition_whisper/requirements.in +++ b/common/speech/lasr_speech_recognition_whisper/requirements.in @@ -1,5 +1,6 @@ SpeechRecognition==3.10.0 -openai-whisper==20230314 +sounddevice==0.4.6 +openai-whisper==20231117 PyAudio==0.2.13 PyYaml==6.0.1 rospkg==1.5.0 diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.txt b/common/speech/lasr_speech_recognition_whisper/requirements.txt index a6d9bdebe..1cace21e8 100644 --- a/common/speech/lasr_speech_recognition_whisper/requirements.txt +++ b/common/speech/lasr_speech_recognition_whisper/requirements.txt @@ -1,51 +1,54 @@ -catkin-pkg==0.5.2 # via rospkg -certifi==2023.7.22 # via requests -charset-normalizer==3.2.0 # via requests -cmake==3.27.2 # via triton -distro==1.8.0 # via rospkg +--extra-index-url https://pypi.ngc.nvidia.com +--trusted-host pypi.ngc.nvidia.com + +catkin-pkg==1.0.0 # via rospkg +certifi==2024.2.2 # via requests +cffi==1.16.0 # via sounddevice +charset-normalizer==3.3.2 # via requests +distro==1.9.0 # via rospkg docutils==0.20.1 # via catkin-pkg -ffmpeg-python==0.2.0 # via openai-whisper -filelock==3.12.2 # via torch, triton -future==0.18.3 # via ffmpeg-python -idna==3.4 # via requests -jinja2==3.1.2 # via torch -lit==16.0.6 # via triton -llvmlite==0.40.1 # via numba -markupsafe==2.1.3 # via jinja2 -more-itertools==10.1.0 # via openai-whisper +filelock==3.13.1 # via torch, triton +fsspec==2024.2.0 # via torch +idna==3.6 # via requests +jinja2==3.1.3 # via torch +llvmlite==0.42.0 # via numba +markupsafe==2.1.5 # via jinja2 +more-itertools==10.2.0 # via openai-whisper mpmath==1.3.0 # via sympy -networkx==3.1 # via torch -numba==0.57.1 # via openai-whisper -numpy==1.24.4 # via numba, openai-whisper -nvidia-cublas-cu11==11.10.3.66 # via nvidia-cudnn-cu11, nvidia-cusolver-cu11, torch -nvidia-cuda-cupti-cu11==11.7.101 # via torch -nvidia-cuda-nvrtc-cu11==11.7.99 # via torch -nvidia-cuda-runtime-cu11==11.7.99 # via torch -nvidia-cudnn-cu11==8.5.0.96 # via torch -nvidia-cufft-cu11==10.9.0.58 # via torch -nvidia-curand-cu11==10.2.10.91 # via torch -nvidia-cusolver-cu11==11.4.0.1 # via torch -nvidia-cusparse-cu11==11.7.4.91 # via torch -nvidia-nccl-cu11==2.14.3 # via torch -nvidia-nvtx-cu11==11.7.91 # via torch -openai-whisper==20230314 # via -r requirements.in +networkx==3.2.1 # via torch +numba==0.59.0 # via openai-whisper +numpy==1.26.4 # via numba, openai-whisper +nvidia-cublas-cu12==12.1.3.1 # via nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch +nvidia-cuda-cupti-cu12==12.1.105 # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 # via torch +nvidia-cuda-runtime-cu12==12.1.105 # via torch +nvidia-cudnn-cu12==8.9.2.26 # via torch +nvidia-cufft-cu12==11.0.2.54 # via torch +nvidia-curand-cu12==10.3.2.106 # via torch +nvidia-cusolver-cu12==11.4.5.107 # via torch +nvidia-cusparse-cu12==12.1.0.106 # via nvidia-cusolver-cu12, torch +nvidia-nccl-cu12==2.19.3 # via torch +nvidia-nvjitlink-cu12==12.4.99 # via nvidia-cusolver-cu12, nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 # via torch +openai-whisper==20231117 # via -r requirements.in pyaudio==0.2.13 # via -r requirements.in -pyparsing==3.1.1 # via catkin-pkg -python-dateutil==2.8.2 # via catkin-pkg +pycparser==2.21 # via cffi +pyparsing==3.1.2 # via catkin-pkg +python-dateutil==2.9.0.post0 # via catkin-pkg pyyaml==6.0.1 # via -r requirements.in, rospkg -regex==2023.8.8 # via tiktoken +regex==2023.12.25 # via tiktoken requests==2.31.0 # via speechrecognition, tiktoken rospkg==1.5.0 # via -r requirements.in six==1.16.0 # via python-dateutil +sounddevice==0.4.6 # via -r requirements.in speechrecognition==3.10.0 # via -r requirements.in sympy==1.12 # via torch -tiktoken==0.3.1 # via openai-whisper -torch==2.0.1 # via openai-whisper, triton -tqdm==4.66.1 # via openai-whisper -triton==2.0.0 # via openai-whisper, torch -typing-extensions==4.7.1 # via torch -urllib3==2.0.4 # via requests -wheel==0.41.1 # via nvidia-cublas-cu11, nvidia-cuda-cupti-cu11, nvidia-cuda-runtime-cu11, nvidia-curand-cu11, nvidia-cusparse-cu11, nvidia-nvtx-cu11 +tiktoken==0.6.0 # via openai-whisper +torch==2.2.1 # via openai-whisper +tqdm==4.66.2 # via openai-whisper +triton==2.2.0 # via openai-whisper, torch +typing-extensions==4.10.0 # via torch +urllib3==2.2.1 # via requests # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt index 1d550fbad..48659fa8c 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt +++ b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt @@ -7,7 +7,10 @@ project(lasr_vector_databases_faiss) ## Find catkin macros and libraries ## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) ## is used, also find other catkin packages -find_package(catkin REQUIRED catkin_virtualenv) +find_package(catkin REQUIRED catkin_virtualenv COMPONENTS +rospy +lasr_vector_databases_msgs +) ## System dependencies are found with CMake's conventions # find_package(Boost REQUIRED COMPONENTS system) @@ -55,8 +58,6 @@ catkin_generate_virtualenv( ## Generate services in the 'srv' folder # add_service_files( # FILES -# Service1.srv -# Service2.srv # ) # Generate actions in the 'action' folder @@ -68,8 +69,7 @@ catkin_generate_virtualenv( # Generate added messages and services with any dependencies listed here # generate_messages( # DEPENDENCIES -# actionlib_msgs -# geometry_msgs +# std_msgs # ) ################################################ @@ -157,22 +157,13 @@ include_directories( ## Mark executable scripts (Python etc.) for installation ## in contrast to setup.py, you can choose the destination -# catkin_install_python(PROGRAMS -# nodes/qualification -# nodes/actions/wait_greet -# nodes/actions/identify -# nodes/actions/greet -# nodes/actions/get_name -# nodes/actions/learn_face -# nodes/actions/get_command -# nodes/actions/guide -# nodes/actions/find_person -# nodes/actions/detect_people -# nodes/actions/receive_object -# nodes/actions/handover_object -# nodes/better_qualification -# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} -# ) +catkin_install_python(PROGRAMS + nodes/txt_index_service + nodes/txt_query_service + scripts/test_index_service.py + scripts/test_query_service.py + DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) ## Mark executables for installation ## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html @@ -196,11 +187,10 @@ include_directories( # ) ## Mark other files for installation (e.g. launch and bag files, etc.) -# install(FILES -# # myfile1 -# # myfile2 -# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} -# ) +install(FILES + requirements.txt + DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +) ############# ## Testing ## diff --git a/common/vector_databases/lasr_vector_databases_faiss/doc/TECHNICAL.md b/common/vector_databases/lasr_vector_databases_faiss/doc/TECHNICAL.md new file mode 100644 index 000000000..e69de29bb diff --git a/common/vector_databases/lasr_vector_databases_faiss/doc/USAGE.md b/common/vector_databases/lasr_vector_databases_faiss/doc/USAGE.md new file mode 100644 index 000000000..d17914026 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/doc/USAGE.md @@ -0,0 +1,33 @@ +This package currently contains two services `txt_index_service` and `txt_query_service`. These services are used to create and search (respectively) a vector database of natural language sentence embeddings. + +# Index Service +The Index service is used to create a [FAISS](https://github.com/facebookresearch/faiss) index object containing a set of sentence embeddings, where each sentence is assumed to be a line in a given `.txt` file. This Index object is saved to disk at a specified location, and can be thought of as a Vector Database. + +## Request +The request takes two string parameters: `txt_path` which is the path to the `.txt` file we wish to create sentence embeddings for, where each line in this file is treated as a sentence; and `index_path` which is the path to a `.index` file that will be created by the Service. + +## Response +No response is given from this service. + +## Example Usage +Please see the `scripts/test_index_service.py` script for a simple example of sending a request to the service. + +# Query Service +The query service is used to search the `.index` file created by the Index Service to find the most similar sentences given an input query sentence. + +## Request +The request requires four fields: + +1. `txt_path` -- this is a `string` that is the path to the txt file that contains the original sentences that the `.index` file was populated with. +2. `index_path` -- this is a `string` that is the path to the `.index` file that was created with the Index Service, on the same txt file as the `txt_path`. +3. `query_sentence` -- this is a `string` that is the sentence that you wish to query the index with and find the most similar sentence. +4. `k` -- this is a `uint8` that is the number of closest sentences you wish to return. + +## Response +The response contains two fields: + +1. `closest_sentences` -- this is an ordered list of `string`s that contain the closest sentences to the given query sentence. +2. `cosine_similaities` -- this is an ordered list of `float32`s that contain the cosine similarity scores of the closest sentences. + +## Example Usage +Please see the `scripts/test_query_service.py` script for a simple example of sending a request to the service. \ No newline at end of file diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service new file mode 100644 index 000000000..a9c5c85b6 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +import rospy +import numpy as np +from lasr_vector_databases_msgs.srv import TxtIndexRequest, TxtIndexResponse, TxtIndex +from lasr_vector_databases_faiss import ( + load_model, + parse_txt_file, + get_sentence_embeddings, + create_vector_database, +) + + +class TxtIndexService: + def __init__(self): + rospy.init_node("txt_index_service") + rospy.Service("lasr_faiss/txt_index", TxtIndex, self.execute_cb) + self._sentence_embedding_model = load_model() + rospy.loginfo("Text index service started") + + def execute_cb(self, req: TxtIndexRequest): + txt_fp: str = req.txt_path + sentences_to_embed: list[str] = parse_txt_file(txt_fp) + sentence_embeddings: np.ndarray = get_sentence_embeddings( + sentences_to_embed, self._sentence_embedding_model + ) + index_path: str = req.index_path + create_vector_database(sentence_embeddings, index_path) + return TxtIndexResponse() + + +if __name__ == "__main__": + TxtIndexService() + rospy.spin() diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service new file mode 100644 index 000000000..e45610fd3 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +import rospy +import numpy as np + +from lasr_vector_databases_msgs.srv import ( + TxtQueryRequest, + TxtQueryResponse, + TxtQuery, +) +from lasr_vector_databases_faiss import ( + load_model, + parse_txt_file, + get_sentence_embeddings, + load_vector_database, + query_database, +) + + +class TxtQueryService: + def __init__(self): + rospy.init_node("txt_query_service") + self._sentence_embedding_model = load_model() + rospy.Service("lasr_faiss/txt_query", TxtQuery, self.execute_cb) + rospy.loginfo("Text Query service started") + + def execute_cb(self, req: TxtQueryRequest) -> TxtQueryResponse: + txt_fp: str = req.txt_path + index_path: str = req.index_path + query_sentence: str = req.query_sentence + possible_matches: list[str] = parse_txt_file(txt_fp) + query_embedding: np.ndarray = get_sentence_embeddings( + [query_sentence], self._sentence_embedding_model # requires list of strings + ) + distances, indices = query_database(index_path, query_embedding, k=req.k) + nearest_matches = [possible_matches[i] for i in indices[0]] + + return TxtQueryResponse( + closest_sentences=nearest_matches, + cosine_similarities=distances[0].tolist(), + ) + + +if __name__ == "__main__": + TxtQueryService() + rospy.spin() diff --git a/common/vector_databases/lasr_vector_databases_faiss/package.xml b/common/vector_databases/lasr_vector_databases_faiss/package.xml index f8128ea56..1546f0a38 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/package.xml +++ b/common/vector_databases/lasr_vector_databases_faiss/package.xml @@ -50,6 +50,7 @@ catkin catkin_virtualenv + lasr_vector_databases_msgs diff --git a/common/vector_databases/lasr_vector_databases_faiss/requirements.in b/common/vector_databases/lasr_vector_databases_faiss/requirements.in index 14955d38d..3259a62c9 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/requirements.in +++ b/common/vector_databases/lasr_vector_databases_faiss/requirements.in @@ -1 +1,3 @@ -faiss-cpu \ No newline at end of file +faiss-cpu +sentence-transformers +torch \ No newline at end of file diff --git a/common/vector_databases/lasr_vector_databases_faiss/scripts/test_index_service.py b/common/vector_databases/lasr_vector_databases_faiss/scripts/test_index_service.py new file mode 100644 index 000000000..cc7f12f3c --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/scripts/test_index_service.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +import rospy +from lasr_vector_databases_faiss.srv import TxtIndex, TxtIndexRequest + +request = TxtIndexRequest() + +request.txt_path = ( + "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.txt" +) + +request.index_path = ( + "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.index" +) +rospy.ServiceProxy("lasr_faiss/txt_index", TxtIndex)(request) diff --git a/common/vector_databases/lasr_vector_databases_faiss/scripts/test_query_service.py b/common/vector_databases/lasr_vector_databases_faiss/scripts/test_query_service.py new file mode 100644 index 000000000..4ae89e530 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/scripts/test_query_service.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +import rospy +from lasr_vector_databases_faiss.srv import TxtQuery, TxtQueryRequest + +request = TxtQueryRequest() + +request.txt_path = ( + "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.txt" +) + +request.index_path = ( + "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.index" +) + +request.query_sentence = "Do French like snails?" + +request.k = 3 + +response = rospy.ServiceProxy("lasr_faiss/txt_query", TxtQuery)(request) + +print(response.closest_sentences) +print(response.cosine_similarities) diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py index e69de29bb..698927d9c 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py +++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py @@ -0,0 +1,2 @@ +from .database_utils import create_vector_database, load_vector_database, query_database +from .get_sentence_embeddings import get_sentence_embeddings, load_model, parse_txt_file diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py deleted file mode 100755 index f7dbcbe24..000000000 --- a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 -import os -import torch -import numpy as np - -# import rospy -import faiss # type: ignore -from sentence_transformers import SentenceTransformer # type: ignore -from typing import Optional - -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" - - -def load_commands(command_path: str) -> list[str]: - """Loads the commands stored in the given txt file - into a list of string commands - Args: - command_path (str): path to the txt file containing - the commands -- assumes one command per line. - Returns: - list[str]: list of string commands where each entry in the - list is a command - """ - command_list = [] - with open(command_path, "r", encoding="utf8") as src: - for command in src: - # Strip newline char. - command_list.append(command[:-1]) - return command_list - - -def get_sentence_embeddings( - command_list: list[str], model: SentenceTransformer -) -> np.ndarray: - """Converts the list of command strings into an array of sentence - embeddings (where each command is a sentence and each sentence - is converted to a vector) - Args: - command_list (list[str]): list of string commands, where each - entry in the list is assumed to be a separate command - model (SentenceTransformer): model used to perform the embedding. - Assumes a method called encode that takes a list of strings - as input. - Returns: - np.ndarray: array of shape (n_commands, embedding_dim) - """ - - return model.encode( - command_list, - convert_to_numpy=True, - show_progress_bar=True, - batch_size=256, - device=DEVICE, - ) - - -def create_vector_database(vectors: np.ndarray) -> faiss.IndexFlatIP: - """Creates a vector database from an array of vectors of the same dimensionality - Args: - vectors (np.ndarray): shape (n_vectors, vector_dim) - - Returns: - faiss.IndexFlatIP: Flat index containing the vectors - """ - print("Creating vector database") - index_flat = faiss.IndexFlatIP(vectors.shape[1]) - faiss.normalize_L2(vectors) - index_flat.add(vectors) - print("Finished creating vector database") - return index_flat - - -def get_command_database( - index_path: str, command_path: Optional[str] = None -) -> faiss.IndexFlatL2: - """Gets a vector database containing a list of embedded commands. Creates the database - if the path does not exist, else, loads it into memory. - - Args: - index_path (str): Path to an existing faiss Index, or where to save a new one. - command_path (str, optional): Path of text file containing commands. - Only required if creating a new database. Defaults to None. - - Returns: - faiss.IndexFlatL2: faiss Index object containing the embedded commands. - """ - - if not os.path.exists(f"{index_path}.index"): - # rospy.loginfo("Creating new command vector database") - assert command_path is not None - command_list = load_commands(command_path) - model = SentenceTransformer("all-MiniLM-L6-v2") - command_embeddings = get_sentence_embeddings(command_list, model) - print(command_embeddings.shape) - command_database = create_vector_database(command_embeddings) - faiss.write_index(command_database, f"{index_path}.index") - # rospy.loginfo("Finished creating vector database") - else: - command_database = faiss.read_index(f"{index_path}.index") - - return command_database - - -def get_similar_commands( - command: str, - index_path: str, - command_path: str, - n_similar_commands: int = 100, - return_embeddings: bool = False, -) -> tuple[list[str], list[float]]: - """Gets the most similar commands to the given command string - Args: - command (str): command to compare against the database - index_path (str): path to the location to create or retrieve - the faiss index containing the embedded commands. - command_path (str): path to the txt file containing the commands - n_similar_commands (int, optional): number of similar commands to - return. Defaults to 100. - Returns: - list[str]: list of string commands, where each entry in the - list is a similar command - """ - command_database = get_command_database(index_path, command_path) - command_list = load_commands(command_path) - model = SentenceTransformer("all-MiniLM-L6-v2") - command_embedding = get_sentence_embeddings([command], model) - faiss.normalize_L2(command_embedding) - command_distances, command_indices = command_database.search( - command_embedding, n_similar_commands - ) - nearest_commands = [command_list[i] for i in command_indices[0]] - - if return_embeddings: - all_command_embeddings = get_sentence_embeddings(command_list, model) - # filter for only nererst commands - all_command_embeddings = all_command_embeddings[command_indices[0]] - return ( - nearest_commands, - list(command_distances[0]), - all_command_embeddings, - command_embedding, - ) - - return nearest_commands, list(command_distances[0]) - - -if __name__ == "__main__": - """Example usage of using this to find similar commands""" - command = "find Jared and asks if he needs help" - result, distances, command_embeddings, query_embedding = get_similar_commands( - command, - "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/qualification/data/command_index", - "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/qualification/data/command_list.txt", - n_similar_commands=1000, - return_embeddings=True, - ) - print(result) diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py new file mode 100644 index 000000000..6d0139072 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +import os +import numpy as np +import faiss + + +def create_vector_database( + vectors: np.ndarray, + index_path: str, + overwrite: bool = False, + index_type: str = "Flat", + normalise_vecs: bool = True, +) -> None: + """Creates a FAISS Index using the factory constructor and the given + index type, and adds the given vector to the index, and then saves + it to disk using the given path. + + Args: + vectors (np.ndarray): vector of shape (n_vectors, vector_dim) + index_path (str): path to save the index + overwrite (bool, optional): Whether to replace an existing index + at the same filepath if it exists. Defaults to False. + index_type (str, optional): FAISS Index Factory string. Defaults to "IndexFlatIP". + normalise_vecs (bool, optional): Whether to normalise the vectors before + adding them to the Index. This converts the IP metric to Cosine Similarity. + Defaults to True. + """ + + if os.path.exists(index_path) and not overwrite: + raise FileExistsError( + f"Index already exists at {index_path}. Set overwrite=True to replace it." + ) + + index = faiss.index_factory( + vectors.shape[1], index_type, faiss.METRIC_INNER_PRODUCT + ) + if normalise_vecs: + faiss.normalize_L2(vectors) + index.add(vectors) + faiss.write_index(index, index_path) + + +def load_vector_database(index_path: str, use_gpu: bool = False) -> faiss.Index: + """Loads a FAISS Index from the given filepath + + Args: + index_path (str): path to the index file + use_gpu (bool, optional): Whether to load the index onto the GPU. + Defaults to False. + + Returns: + faiss.Index: FAISS Index object + """ + print("Loading index from", index_path) + index = faiss.read_index(index_path) + print("Loaded index with ntotal:", index.ntotal) + if use_gpu: + index = faiss.index_cpu_to_all_gpus(index) + return index + + +def query_database( + index_path: str, + query_vectors: np.ndarray, + normalise: bool = True, + k: int = 1, +) -> tuple[np.ndarray, np.ndarray]: + """Queries the given index with the given query vectors + + Args: + index_path (str): path to the index file + query_vectors (np.ndarray): query vectors of shape (n_queries, vector_dim) + normalise (bool, optional): Whether to normalise the query vectors. + Defaults to True. + k (int, optional): Number of nearest neighbours to return. Defaults to 1. + + Returns: + tuple[np.ndarray, np.ndarray]: (distances, indices) of the nearest neighbours + each of shape (n_queries, n_neighbours) + """ + index = load_vector_database(index_path) + if normalise: + faiss.normalize_L2(query_vectors) + distances, indices = index.search(query_vectors, k) + return distances, indices diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py new file mode 100644 index 000000000..1ba43b48c --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +import torch +import numpy as np +from sentence_transformers import SentenceTransformer + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +def load_model(model_name: str = "all-MiniLM-L6-v2") -> SentenceTransformer: + """Loads the sentence transformer model + Args: + model_name (str): name of the model to load + Returns: + sentence_transformers.SentenceTransformer: the loaded model + """ + return SentenceTransformer(model_name, device=DEVICE) + + +def parse_txt_file(fp: str) -> list[str]: + """Parses a txt file into a list of strings, + where each element is a line in the txt file with the + newline char stripped. + Args: + fp (str): path to the txt file to load + Returns: + list[str]: list of strings where each element is a line in the txt file + """ + sentences = [] + with open(fp, "r", encoding="utf8") as src: + for line in src: + # Strip newline char. + sentences.append(line[:-1]) + return sentences + + +def get_sentence_embeddings( + sentence_list: list[str], model: SentenceTransformer +) -> np.ndarray: + """Converts the list of string sentences into an array of sentence + embeddings + Args: + sentence_list (list[str]): list of string sentences, where each + entry in the list is assumed to be a separate sentence + model (SentenceTransformer): model used to perform the embedding. + Assumes a method called encode that takes a list of strings + as input. + Returns: + np.ndarray: array of shape (n_commands, embedding_dim) + """ + return model.encode( + sentence_list, + convert_to_numpy=True, + show_progress_bar=True, + batch_size=256, + device=DEVICE, + ) diff --git a/common/vector_databases/lasr_vector_databases_msgs/CMakeLists.txt b/common/vector_databases/lasr_vector_databases_msgs/CMakeLists.txt new file mode 100644 index 000000000..63fd162dc --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_msgs/CMakeLists.txt @@ -0,0 +1,194 @@ +cmake_minimum_required(VERSION 3.0.2) +project(lasr_vector_databases_msgs) + +## Compile as C++11, supported in ROS Kinetic and newer +# add_compile_options(-std=c++11) + +## Find catkin macros and libraries +## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) +## is used, also find other catkin packages +find_package(catkin REQUIRED COMPONENTS message_generation message_runtime) + +## System dependencies are found with CMake's conventions +# find_package(Boost REQUIRED COMPONENTS system) + + +## Uncomment this if the package has a setup.py. This macro ensures +## modules and global scripts declared therein get installed +## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html +# catkin_python_setup() + +################################################ +## Declare ROS messages, services and actions ## +################################################ + +## To declare and build messages, services or actions from within this +## package, follow these steps: +## * Let MSG_DEP_SET be the set of packages whose message types you use in +## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...). +## * In the file package.xml: +## * add a build_depend tag for "message_generation" +## * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET +## * If MSG_DEP_SET isn't empty the following dependency has been pulled in +## but can be declared for certainty nonetheless: +## * add a exec_depend tag for "message_runtime" +## * In this file (CMakeLists.txt): +## * add "message_generation" and every package in MSG_DEP_SET to +## find_package(catkin REQUIRED COMPONENTS ...) +## * add "message_runtime" and every package in MSG_DEP_SET to +## catkin_package(CATKIN_DEPENDS ...) +## * uncomment the add_*_files sections below as needed +## and list every .msg/.srv/.action file to be processed +## * uncomment the generate_messages entry below +## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...) + +## Generate messages in the 'msg' folder +# add_message_files( +# FILES +# ) + +## Generate services in the 'srv' folder +add_service_files( + FILES + TxtIndex.srv + TxtQuery.srv +) + +# Generate actions in the 'action' folder + +## Generate added messages and services with any dependencies listed here +generate_messages( + DEPENDENCIES +) + +################################################ +## Declare ROS dynamic reconfigure parameters ## +################################################ + +## To declare and build dynamic reconfigure parameters within this +## package, follow these steps: +## * In the file package.xml: +## * add a build_depend and a exec_depend tag for "dynamic_reconfigure" +## * In this file (CMakeLists.txt): +## * add "dynamic_reconfigure" to +## find_package(catkin REQUIRED COMPONENTS ...) +## * uncomment the "generate_dynamic_reconfigure_options" section below +## and list every .cfg file to be processed + +## Generate dynamic reconfigure parameters in the 'cfg' folder +# generate_dynamic_reconfigure_options( +# cfg/DynReconf1.cfg +# cfg/DynReconf2.cfg +# ) + +################################### +## catkin specific configuration ## +################################### +## The catkin_package macro generates cmake config files for your package +## Declare things to be passed to dependent projects +## INCLUDE_DIRS: uncomment this if your package contains header files +## LIBRARIES: libraries you create in this project that dependent projects also need +## CATKIN_DEPENDS: catkin_packages dependent projects also need +## DEPENDS: system dependencies of this project that dependent projects also need +catkin_package( +# INCLUDE_DIRS include +# LIBRARIES lasr_vision_msgs +# CATKIN_DEPENDS other_catkin_pkg +# DEPENDS system_lib +) + +########### +## Build ## +########### + +## Specify additional locations of header files +## Your package locations should be listed before other locations +include_directories( +# include +# ${catkin_INCLUDE_DIRS} +) + +## Declare a C++ library +# add_library(${PROJECT_NAME} +# src/${PROJECT_NAME}/lasr_vision_msgs.cpp +# ) + +## Add cmake target dependencies of the library +## as an example, code may need to be generated before libraries +## either from message generation or dynamic reconfigure +# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Declare a C++ executable +## With catkin_make all packages are built within a single CMake context +## The recommended prefix ensures that target names across packages don't collide +# add_executable(${PROJECT_NAME}_node src/lasr_vision_msgs_node.cpp) + +## Rename C++ executable without prefix +## The above recommended prefix causes long target names, the following renames the +## target back to the shorter version for ease of user use +## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node" +# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "") + +## Add cmake target dependencies of the executable +## same as for the library above +# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Specify libraries to link a library or executable target against +# target_link_libraries(${PROJECT_NAME}_node +# ${catkin_LIBRARIES} +# ) + +############# +## Install ## +############# + +# all install targets should use catkin DESTINATION variables +# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html + +## Mark executable scripts (Python etc.) for installation +## in contrast to setup.py, you can choose the destination +# catkin_install_python(PROGRAMS +# scripts/my_python_script +# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark executables for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html +# install(TARGETS ${PROJECT_NAME}_node +# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark libraries for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html +# install(TARGETS ${PROJECT_NAME} +# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION} +# ) + +## Mark cpp header files for installation +# install(DIRECTORY include/${PROJECT_NAME}/ +# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} +# FILES_MATCHING PATTERN "*.h" +# PATTERN ".svn" EXCLUDE +# ) + +## Mark other files for installation (e.g. launch and bag files, etc.) +# install(FILES +# # myfile1 +# # myfile2 +# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +# ) + +############# +## Testing ## +############# + +## Add gtest based cpp test target and link libraries +# catkin_add_gtest(${PROJECT_NAME}-test test/test_lasr_vision_msgs.cpp) +# if(TARGET ${PROJECT_NAME}-test) +# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME}) +# endif() + +## Add folders to be run by python nosetests +# catkin_add_nosetests(test) diff --git a/common/vector_databases/lasr_vector_databases_msgs/package.xml b/common/vector_databases/lasr_vector_databases_msgs/package.xml new file mode 100644 index 000000000..5f4e45e9f --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_msgs/package.xml @@ -0,0 +1,61 @@ + + + lasr_vector_databases_msgs + 0.0.0 + Messages required for vector databases + + + + + Matt Barker + + + + + + MIT + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + catkin + message_generation + message_runtime + + + + + + + + diff --git a/common/vector_databases/lasr_vector_databases_msgs/srv/TxtIndex.srv b/common/vector_databases/lasr_vector_databases_msgs/srv/TxtIndex.srv new file mode 100644 index 000000000..79ac01654 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_msgs/srv/TxtIndex.srv @@ -0,0 +1,7 @@ +# Path to input text file +string txt_path + +# Output path to save index +string index_path + +--- diff --git a/common/vector_databases/lasr_vector_databases_msgs/srv/TxtQuery.srv b/common/vector_databases/lasr_vector_databases_msgs/srv/TxtQuery.srv new file mode 100644 index 000000000..bbcb04613 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_msgs/srv/TxtQuery.srv @@ -0,0 +1,19 @@ +# Path to input text file +string txt_path + +# Path to index file to load +string index_path + +# Sentence to query index with +string query_sentence + +# Number of nearest sentences to return +uint8 k + +--- +# Nearest sentence +string[] closest_sentences + +# Cosine similarity of distances +float32[] cosine_similarities + diff --git a/common/vision/lasr_vision_clip/CMakeLists.txt b/common/vision/lasr_vision_clip/CMakeLists.txt index a13eb6f2a..c2ce23209 100644 --- a/common/vision/lasr_vision_clip/CMakeLists.txt +++ b/common/vision/lasr_vision_clip/CMakeLists.txt @@ -196,11 +196,10 @@ include_directories( # ) ## Mark other files for installation (e.g. launch and bag files, etc.) -# install(FILES -# # myfile1 -# # myfile2 -# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} -# ) +install(FILES + requirements.txt + DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +) ############# ## Testing ## diff --git a/skills/CMakeLists.txt b/skills/CMakeLists.txt index cc256ca3d..6a5978ed6 100644 --- a/skills/CMakeLists.txt +++ b/skills/CMakeLists.txt @@ -10,6 +10,7 @@ project(lasr_skills) find_package(catkin REQUIRED COMPONENTS rospy lasr_vision_msgs + lasr_vector_databases_msgs ) ## System dependencies are found with CMake's conventions diff --git a/skills/package.xml b/skills/package.xml index 44a4f5827..053a2c85b 100644 --- a/skills/package.xml +++ b/skills/package.xml @@ -53,6 +53,7 @@ rospy rospy lasr_vision_msgs + lasr_vector_databases_msgs diff --git a/skills/src/lasr_skills/__init__.py b/skills/src/lasr_skills/__init__.py index f1a0a65b6..4edfb33aa 100755 --- a/skills/src/lasr_skills/__init__.py +++ b/skills/src/lasr_skills/__init__.py @@ -5,6 +5,7 @@ from .wait_for_person_in_area import WaitForPersonInArea from .describe_people import DescribePeople from .look_to_point import LookToPoint -from .look_to_point import LookToPoint from .go_to_location import GoToLocation from .go_to_semantic_location import GoToSemanticLocation +from .listen import Listen +from .say import Say \ No newline at end of file diff --git a/skills/src/lasr_skills/ask_and_listen.py b/skills/src/lasr_skills/ask_and_listen.py new file mode 100644 index 000000000..6f110c40a --- /dev/null +++ b/skills/src/lasr_skills/ask_and_listen.py @@ -0,0 +1,34 @@ +import smach +from lasr_skills import Listen +from lasr_skills import Say + + +class AskAndListen(smach.StateMachine): + def __init__(self): + smach.StateMachine.__init__( + self, + outcomes=["succeeded", "failed"], + output_keys=["transcribed_speech"], + input_keys=["tts_phrase"], + ) + with self: + smach.StateMachine.add( + "SAY", + Say(), + transitions={ + "succeeded": "LISTEN", + "aborted": "failed", + "preempted": "failed", + }, + remapping={"text": "tts_phrase"}, + ) + smach.StateMachine.add( + "LISTEN", + Listen(), + transitions={ + "succeeded": "succeeded", + "aborted": "failed", + "preempted": "failed", + }, + remapping={"sequence": "transcribed_speech"}, + ) diff --git a/skills/src/lasr_skills/listen.py b/skills/src/lasr_skills/listen.py new file mode 100644 index 000000000..272f24b47 --- /dev/null +++ b/skills/src/lasr_skills/listen.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 +import smach_ros +from lasr_speech_recognition_msgs.msg import ( + TranscribeSpeechAction, +) + + +class Listen(smach_ros.SimpleActionState): + def __init__(self): + smach_ros.SimpleActionState.__init__( + self, + "transcribe_speech", + TranscribeSpeechAction, + result_slots=["sequence"], + ) diff --git a/skills/src/lasr_skills/say.py b/skills/src/lasr_skills/say.py new file mode 100644 index 000000000..5e0bddb8d --- /dev/null +++ b/skills/src/lasr_skills/say.py @@ -0,0 +1,45 @@ +import smach_ros + +from pal_interaction_msgs.msg import TtsGoal, TtsAction, TtsText + +from typing import Union + + +class Say(smach_ros.SimpleActionState): + def __init__( + self, text: Union[str, None] = None, format_str: Union[str, None] = None + ): + if text is not None: + super(Say, self).__init__( + "tts", + TtsAction, + goal=TtsGoal(rawtext=TtsText(text=text, lang_id="en_GB")), + ) + elif format_str is not None: + super(Say, self).__init__( + "tts", + TtsAction, + goal_cb=lambda ud, _: ( + TtsGoal( + rawtext=TtsText( + text=format_str.format(*ud.placeholders), lang_id="en_GB" + ) + ) + if isinstance(ud.placeholders, (list, tuple)) + else TtsGoal( + rawtext=TtsText( + text=format_str.format(ud.placeholders), lang_id="en_GB" + ) + ) + ), + input_keys=["placeholders"], + ) + else: + super(Say, self).__init__( + "tts", + TtsAction, + goal_cb=lambda ud, _: TtsGoal( + rawtext=TtsText(text=ud.text, lang_id="en_GB") + ), + input_keys=["text"], + ) diff --git a/skills/src/lasr_skills/xml_question_answer.py b/skills/src/lasr_skills/xml_question_answer.py new file mode 100644 index 000000000..abede91b5 --- /dev/null +++ b/skills/src/lasr_skills/xml_question_answer.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +import rospy +import smach +import xml.etree.ElementTree as ET +from lasr_voice import Voice + +from lasr_vector_databases_msgs.srv import TxtQuery, TxtQueryRequest + + +def parse_question_xml(xml_file_path: str) -> dict: + """Parses the GPSR Q/A xml file and returns a dictionary + consisting of two lists, one for questions and one for answers, + where the index of each question corresponds to the index of its + corresponding answer. + + Args: + xml_file_path (str): full path to xml file to parse + + Returns: + dict: dictionary with keys "questions" and "answers" + each of which is a list of strings. + """ + tree = ET.parse(xml_file_path) + root = tree.getroot() + parsed_questions = [] + parsed_answers = [] + for q_a in root: + question = q_a.find("q").text + answer = q_a.find("a").text + parsed_questions.append(question) + parsed_answers.append(answer) + + return {"questions": parsed_questions, "answers": parsed_answers} + + +class XmlQuestionAnswer(smach.State): + + def __init__(self, index_path: str, txt_path: str, xml_path: str, k: int = 1): + smach.State.__init__( + self, + outcomes=["succeeded", "failed"], + input_keys=["query_sentence", "k"], + output_keys=["closest_answers"], + ) + self.index_path = index_path + self.txt_path = txt_path + self.xml_path = xml_path + self.k = k + self.txt_query = rospy.ServiceProxy("/lasr_faiss/txt_query", TxtQuery) + + def execute(self, userdata): + rospy.wait_for_service("/lasr_faiss/txt_query") + q_a_dict: dict = parse_question_xml(self.xml_path) + try: + request = TxtQueryRequest( + self.txt_path, + self.index_path, + userdata.query_sentence, + self.k, + ) + result = self.txt_query(request) + answers = [ + q_a_dict["answers"][q_a_dict["questions"].index(q)] + for q in result.closest_sentences + ] + userdata.closest_answers = answers + return "succeeded" + except rospy.ServiceException as e: + return "failed" diff --git a/tasks/gpsr/CMakeLists.txt b/tasks/gpsr/CMakeLists.txt new file mode 100644 index 000000000..b8c8a175f --- /dev/null +++ b/tasks/gpsr/CMakeLists.txt @@ -0,0 +1,200 @@ +cmake_minimum_required(VERSION 3.0.2) +project(gpsr) + +## Compile as C++11, supported in ROS Kinetic and newer +# add_compile_options(-std=c++11) + +## Find catkin macros and libraries +## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) +## is used, also find other catkin packages +find_package(catkin REQUIRED) + +## System dependencies are found with CMake's conventions +# find_package(Boost REQUIRED COMPONENTS system) + + +## Uncomment this if the package has a setup.py. This macro ensures +## modules and global scripts declared therein get installed +## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html +catkin_python_setup() + +################################################ +## Declare ROS messages, services and actions ## +################################################ + +## To declare and build messages, services or actions from within this +## package, follow these steps: +## * Let MSG_DEP_SET be the set of packages whose message types you use in +## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...). +## * In the file package.xml: +## * add a build_depend tag for "message_generation" +## * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET +## * If MSG_DEP_SET isn't empty the following dependency has been pulled in +## but can be declared for certainty nonetheless: +## * add a exec_depend tag for "message_runtime" +## * In this file (CMakeLists.txt): +## * add "message_generation" and every package in MSG_DEP_SET to +## find_package(catkin REQUIRED COMPONENTS ...) +## * add "message_runtime" and every package in MSG_DEP_SET to +## catkin_package(CATKIN_DEPENDS ...) +## * uncomment the add_*_files sections below as needed +## and list every .msg/.srv/.action file to be processed +## * uncomment the generate_messages entry below +## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...) + +## Generate messages in the 'msg' folder +# add_message_files( +# FILES +# Message1.msg +# Message2.msg +# ) + +## Generate services in the 'srv' folder +# add_service_files( +# FILES +# ) + +## Generate actions in the 'action' folder +# add_action_files( +# FILES +# Action1.action +# Action2.action +# ) + +## Generate added messages and services with any dependencies listed here +# generate_messages( +# DEPENDENCIES +# std_msgs +# ) + +################################################ +## Declare ROS dynamic reconfigure parameters ## +################################################ + +## To declare and build dynamic reconfigure parameters within this +## package, follow these steps: +## * In the file package.xml: +## * add a build_depend and a exec_depend tag for "dynamic_reconfigure" +## * In this file (CMakeLists.txt): +## * add "dynamic_reconfigure" to +## find_package(catkin REQUIRED COMPONENTS ...) +## * uncomment the "generate_dynamic_reconfigure_options" section below +## and list every .cfg file to be processed + +## Generate dynamic reconfigure parameters in the 'cfg' folder +# generate_dynamic_reconfigure_options( +# cfg/DynReconf1.cfg +# cfg/DynReconf2.cfg +# ) + +################################### +## catkin specific configuration ## +################################### +## The catkin_package macro generates cmake config files for your package +## Declare things to be passed to dependent projects +## INCLUDE_DIRS: uncomment this if your package contains header files +## LIBRARIES: libraries you create in this project that dependent projects also need +## CATKIN_DEPENDS: catkin_packages dependent projects also need +## DEPENDS: system dependencies of this project that dependent projects also need +catkin_package( +# INCLUDE_DIRS include +# LIBRARIES coffee_shop +# CATKIN_DEPENDS other_catkin_pkg +# DEPENDS system_lib +) + +########### +## Build ## +########### + +## Specify additional locations of header files +## Your package locations should be listed before other locations +include_directories( +# include +# ${catkin_INCLUDE_DIRS} +) + +## Declare a C++ library +# add_library(${PROJECT_NAME} +# src/${PROJECT_NAME}/coffee_shop.cpp +# ) + +## Add cmake target dependencies of the library +## as an example, code may need to be generated before libraries +## either from message generation or dynamic reconfigure +# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Declare a C++ executable +## With catkin_make all packages are built within a single CMake context +## The recommended prefix ensures that target names across packages don't collide +# add_executable(${PROJECT_NAME}_node src/coffee_shop_node.cpp) + +## Rename C++ executable without prefix +## The above recommended prefix causes long target names, the following renames the +## target back to the shorter version for ease of user use +## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node" +# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "") + +## Add cmake target dependencies of the executable +## same as for the library above +# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Specify libraries to link a library or executable target against +# target_link_libraries(${PROJECT_NAME}_node +# ${catkin_LIBRARIES} +# ) + +############# +## Install ## +############# + +# all install targets should use catkin DESTINATION variables +# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html + +## Mark executable scripts (Python etc.) for installation +## in contrast to setup.py, you can choose the destination +catkin_install_python(PROGRAMS + scripts/parse_gpsr_xmls.py + nodes/commands/question_answer + DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + +## Mark executables for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html +# install(TARGETS ${PROJECT_NAME}_node +# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark libraries for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html +# install(TARGETS ${PROJECT_NAME} +# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION} +# ) + +## Mark cpp header files for installation +# install(DIRECTORY include/${PROJECT_NAME}/ +# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} +# FILES_MATCHING PATTERN "*.h" +# PATTERN ".svn" EXCLUDE +# ) + +## Mark other files for installation (e.g. launch and bag files, etc.) +# install(FILES +# requirements.txt +# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +# ) + +############# +## Testing ## +############# + +## Add gtest based cpp test target and link libraries +# catkin_add_gtest(${PROJECT_NAME}-test test/test_coffee_shop.cpp) +# if(TARGET ${PROJECT_NAME}-test) +# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME}) +# endif() + +## Add folders to be run by python nosetests +# catkin_add_nosetests(test) diff --git a/tasks/gpsr/launch/commands/question_answer.launch b/tasks/gpsr/launch/commands/question_answer.launch new file mode 100644 index 000000000..2f7855173 --- /dev/null +++ b/tasks/gpsr/launch/commands/question_answer.launch @@ -0,0 +1,30 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/tasks/gpsr/nodes/commands/question_answer b/tasks/gpsr/nodes/commands/question_answer new file mode 100644 index 000000000..9c2301c6c --- /dev/null +++ b/tasks/gpsr/nodes/commands/question_answer @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +import rospy +import argparse +import smach +from lasr_skills.xml_question_answer import XmlQuestionAnswer +from lasr_skills.ask_and_listen import AskAndListen +from lasr_skills.say import Say + + +class QuestionAnswerStateMachine(smach.StateMachine): + def __init__(self, input_data: dict): + smach.StateMachine.__init__( + self, + outcomes=["succeeded", "failed"], + output_keys=["closest_answers"], + ) + self.userdata.tts_phrase = "I hear you have a question for me; ask away!" + with self: + smach.StateMachine.add( + "GET_QUESTION", + AskAndListen(), + transitions={"succeeded": "XML_QUESTION_ANSWER", "failed": "failed"}, + remapping={ + "tts_phrase:": "tts_phrase", + "transcribed_speech": "transcribed_speech", + }, + ) + smach.StateMachine.add( + "XML_QUESTION_ANSWER", + XmlQuestionAnswer( + input_data["index_path"], + input_data["txt_path"], + input_data["xml_path"], + input_data["k"], + ), + transitions={"succeeded": "SAY_ANSWER", "failed": "failed"}, + remapping={ + "query_sentence": "transcribed_speech", + "closest_answers": "closest_answers", + }, + ) + smach.StateMachine.add( + "SAY_ANSWER", + Say(format_str="The answer to your question is: {}"), + transitions={ + "succeeded": "succeeded", + "aborted": "failed", + "preempted": "failed", + }, + remapping={"placeholders": "closest_answers"}, + ) + + +def parse_args() -> dict: + parser = argparse.ArgumentParser(description="GPSR Question Answer") + parser.add_argument( + "--k", + type=int, + help="The number of closest answers to return", + required=True, + ) + parser.add_argument( + "--index_path", + type=str, + help="The path to the index file that is populated with the sentences embeddings of the questions", + required=True, + ) + parser.add_argument( + "--txt_path", + type=str, + help="The path to the txt file containing a list of questions.", + required=True, + ) + parser.add_argument( + "--xml_path", + type=str, + help="The path to the xml file containing question/answer pairs", + required=True, + ) + args, _ = parser.parse_known_args() + args.k = int(args.k) + return vars(args) + + +if __name__ == "__main__": + rospy.init_node("gpsr_question_answer") + args: dict = parse_args() + while not rospy.is_shutdown(): + q_a_sm = QuestionAnswerStateMachine(args) + outcome = q_a_sm.execute() + if outcome == "succeeded": + rospy.loginfo("Question Answer State Machine succeeded") + else: + rospy.logerr("Question Answer State Machine failed") + + rospy.spin() diff --git a/tasks/gpsr/package.xml b/tasks/gpsr/package.xml new file mode 100644 index 000000000..e94bc707c --- /dev/null +++ b/tasks/gpsr/package.xml @@ -0,0 +1,64 @@ + + + gpsr + 0.0.0 + The gpsr task package + + + + + Matt Barker + Siyao Li + Jared Swift + Nicole Lehchevska + + + + + + MIT + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + catkin + lasr_vector_databases_faiss + lasr_vector_databases_msgs + lasr_speech_recognition_whisper + + + + + + + diff --git a/tasks/gpsr/scripts/parse_gpsr_xmls.py b/tasks/gpsr/scripts/parse_gpsr_xmls.py new file mode 100644 index 000000000..e85fdf706 --- /dev/null +++ b/tasks/gpsr/scripts/parse_gpsr_xmls.py @@ -0,0 +1,27 @@ +import xml.etree.ElementTree as ET + + +def parse_question_xml(xml_file_path: str) -> dict: + """Parses the GPSR Q/A xml file and returns a dictionary + consisting of two lists, one for questions and one for answers, + where the index of each question corresponds to the index of its + corresponding answer. + + Args: + xml_file_path (str): full path to xml file to parse + + Returns: + dict: dictionary with keys "questions" and "answers" + each of which is a list of strings. + """ + tree = ET.parse(xml_file_path) + root = tree.getroot() + parsed_questions = [] + parsed_answers = [] + for q_a in root: + question = q_a.find("q").text + answer = q_a.find("a").text + parsed_questions.append(question) + parsed_answers.append(answer) + + return {"questions": parsed_questions, "answers": parsed_answers} diff --git a/tasks/gpsr/setup.py b/tasks/gpsr/setup.py new file mode 100644 index 000000000..386c709e5 --- /dev/null +++ b/tasks/gpsr/setup.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 + +from distutils.core import setup +from catkin_pkg.python_setup import generate_distutils_setup + +setup_args = generate_distutils_setup(packages=["gpsr"], package_dir={"": "src"}) + +setup(**setup_args)