diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
index ed1c01d25..afa55f215 100644
--- a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
+++ b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
-
import os
+import sounddevice # needed to remove ALSA error messages
import argparse
from typing import Optional
from dataclasses import dataclass
@@ -16,29 +16,6 @@ import lasr_speech_recognition_msgs.msg # type: ignore
from std_msgs.msg import String # type: ignore
from lasr_speech_recognition_whisper import load_model # type: ignore
-# Error handler to remove ALSA error messages taken from:
-# https://stackoverflow.com/questions/7088672/pyaudio-working-but-spits-out-error-messages-each-time/17673011#17673011
-
-from ctypes import *
-from contextlib import contextmanager
-
-ERROR_HANDLER_FUNC = CFUNCTYPE(None, c_char_p, c_int, c_char_p, c_int, c_char_p)
-
-
-def py_error_handler(filename, line, function, err, fmt):
- pass
-
-
-c_error_handler = ERROR_HANDLER_FUNC(py_error_handler)
-
-
-@contextmanager
-def noalsaerr():
- asound = cdll.LoadLibrary("libasound.so")
- asound.snd_lib_error_set_handler(c_error_handler)
- yield
- asound.snd_lib_error_set_handler(None)
-
@dataclass
class speech_model_params:
@@ -87,40 +64,25 @@ class TranscribeSpeechAction(object):
self._transcription_server = rospy.Publisher(
"/live_speech_transcription", String, queue_size=10
)
- with noalsaerr():
- self._model = load_model(
- self._model_params.model_name,
- self._model_params.device,
- self._model_params.warmup,
- )
- # Configure the speech recogniser object and adjust for ambient noise
- self.recogniser = self._configure_recogniser(ambient_adj=True)
- # Setup the action server and register execution callback
- self._action_server = actionlib.SimpleActionServer(
- self._action_name,
- lasr_speech_recognition_msgs.msg.TranscribeSpeechAction,
- execute_cb=self.execute_cb,
- auto_start=False,
- )
- self._action_server.register_preempt_callback(self.prempt_cb)
- # Setup the timer for adjusting the microphone for ambient noise every x seconds
- self._timer_duration = self._model_params.timer_duration
- self._timer = rospy.Timer(
- rospy.Duration(self._timer_duration), self._timer_cb
- )
- self._listening = False
- self._action_server.start()
+ self._model = load_model(
+ self._model_params.model_name,
+ self._model_params.device,
+ self._model_params.warmup,
+ )
+ # Configure the speech recogniser object and adjust for ambient noise
+ self.recogniser = self._configure_recogniser(ambient_adj=True)
+ # Setup the action server and register execution callback
+ self._action_server = actionlib.SimpleActionServer(
+ self._action_name,
+ lasr_speech_recognition_msgs.msg.TranscribeSpeechAction,
+ execute_cb=self.execute_cb,
+ auto_start=False,
+ )
+ self._action_server.register_preempt_callback(self.prempt_cb)
+ self._listening = False
- def _timer_cb(self, _) -> None:
- return
- """Adjusts the microphone for ambient noise, unless the action server is listening."""
- if self._listening:
- return
- rospy.loginfo("Adjusting microphone for ambient noise...")
- with noalsaerr():
- with self._configure_microphone() as source:
- self.recogniser.adjust_for_ambient_noise(source)
+ self._action_server.start()
def _reset_timer(self) -> None:
"""Resets the timer for adjusting the microphone for ambient noise."""
@@ -194,17 +156,13 @@ class TranscribeSpeechAction(object):
rospy.loginfo("Request Received")
if self._action_server.is_preempt_requested():
return
- # Since we are about to listen, reset the timer for adjusting the microphone for ambient noise
- # as this assumes self_timer_duration seconds of silence before adjusting
- self._reset_timer()
- with noalsaerr():
- with self._configure_microphone() as src:
- self._listening = True
- wav_data = self.recogniser.listen(
- src,
- timeout=self._model_params.start_timeout,
- phrase_time_limit=self._model_params.end_timeout,
- ).get_wav_data()
+ with self._configure_microphone() as src:
+ self._listening = True
+ wav_data = self.recogniser.listen(
+ src,
+ timeout=self._model_params.start_timeout,
+ phrase_time_limit=self._model_params.end_timeout,
+ ).get_wav_data()
# Magic number 32768.0 is the maximum value of a 16-bit signed integer
float_data = (
np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C")
@@ -293,12 +251,6 @@ def parse_args() -> dict:
default=None,
help="Microphone device index or name",
)
- parser.add_argument(
- "--timer_duration",
- type=int,
- default=20,
- help="Number of seconds of silence before the ambient noise adjustment is called.",
- )
parser.add_argument(
"--no_warmup",
action="store_true",
@@ -331,8 +283,6 @@ def configure_model_params(config: dict) -> speech_model_params:
model_params.sample_rate = config["sample_rate"]
if config["mic_device"]:
model_params.mic_device = config["mic_device"]
- if config["timer_duration"]:
- model_params.timer_duration = config["timer_duration"]
if config["no_warmup"]:
model_params.warmup = False
diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.in b/common/speech/lasr_speech_recognition_whisper/requirements.in
index 8209d34fc..da48c5086 100644
--- a/common/speech/lasr_speech_recognition_whisper/requirements.in
+++ b/common/speech/lasr_speech_recognition_whisper/requirements.in
@@ -1,5 +1,6 @@
SpeechRecognition==3.10.0
-openai-whisper==20230314
+sounddevice==0.4.6
+openai-whisper==20231117
PyAudio==0.2.13
PyYaml==6.0.1
rospkg==1.5.0
diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.txt b/common/speech/lasr_speech_recognition_whisper/requirements.txt
index a6d9bdebe..1cace21e8 100644
--- a/common/speech/lasr_speech_recognition_whisper/requirements.txt
+++ b/common/speech/lasr_speech_recognition_whisper/requirements.txt
@@ -1,51 +1,54 @@
-catkin-pkg==0.5.2 # via rospkg
-certifi==2023.7.22 # via requests
-charset-normalizer==3.2.0 # via requests
-cmake==3.27.2 # via triton
-distro==1.8.0 # via rospkg
+--extra-index-url https://pypi.ngc.nvidia.com
+--trusted-host pypi.ngc.nvidia.com
+
+catkin-pkg==1.0.0 # via rospkg
+certifi==2024.2.2 # via requests
+cffi==1.16.0 # via sounddevice
+charset-normalizer==3.3.2 # via requests
+distro==1.9.0 # via rospkg
docutils==0.20.1 # via catkin-pkg
-ffmpeg-python==0.2.0 # via openai-whisper
-filelock==3.12.2 # via torch, triton
-future==0.18.3 # via ffmpeg-python
-idna==3.4 # via requests
-jinja2==3.1.2 # via torch
-lit==16.0.6 # via triton
-llvmlite==0.40.1 # via numba
-markupsafe==2.1.3 # via jinja2
-more-itertools==10.1.0 # via openai-whisper
+filelock==3.13.1 # via torch, triton
+fsspec==2024.2.0 # via torch
+idna==3.6 # via requests
+jinja2==3.1.3 # via torch
+llvmlite==0.42.0 # via numba
+markupsafe==2.1.5 # via jinja2
+more-itertools==10.2.0 # via openai-whisper
mpmath==1.3.0 # via sympy
-networkx==3.1 # via torch
-numba==0.57.1 # via openai-whisper
-numpy==1.24.4 # via numba, openai-whisper
-nvidia-cublas-cu11==11.10.3.66 # via nvidia-cudnn-cu11, nvidia-cusolver-cu11, torch
-nvidia-cuda-cupti-cu11==11.7.101 # via torch
-nvidia-cuda-nvrtc-cu11==11.7.99 # via torch
-nvidia-cuda-runtime-cu11==11.7.99 # via torch
-nvidia-cudnn-cu11==8.5.0.96 # via torch
-nvidia-cufft-cu11==10.9.0.58 # via torch
-nvidia-curand-cu11==10.2.10.91 # via torch
-nvidia-cusolver-cu11==11.4.0.1 # via torch
-nvidia-cusparse-cu11==11.7.4.91 # via torch
-nvidia-nccl-cu11==2.14.3 # via torch
-nvidia-nvtx-cu11==11.7.91 # via torch
-openai-whisper==20230314 # via -r requirements.in
+networkx==3.2.1 # via torch
+numba==0.59.0 # via openai-whisper
+numpy==1.26.4 # via numba, openai-whisper
+nvidia-cublas-cu12==12.1.3.1 # via nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch
+nvidia-cuda-cupti-cu12==12.1.105 # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105 # via torch
+nvidia-cuda-runtime-cu12==12.1.105 # via torch
+nvidia-cudnn-cu12==8.9.2.26 # via torch
+nvidia-cufft-cu12==11.0.2.54 # via torch
+nvidia-curand-cu12==10.3.2.106 # via torch
+nvidia-cusolver-cu12==11.4.5.107 # via torch
+nvidia-cusparse-cu12==12.1.0.106 # via nvidia-cusolver-cu12, torch
+nvidia-nccl-cu12==2.19.3 # via torch
+nvidia-nvjitlink-cu12==12.4.99 # via nvidia-cusolver-cu12, nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105 # via torch
+openai-whisper==20231117 # via -r requirements.in
pyaudio==0.2.13 # via -r requirements.in
-pyparsing==3.1.1 # via catkin-pkg
-python-dateutil==2.8.2 # via catkin-pkg
+pycparser==2.21 # via cffi
+pyparsing==3.1.2 # via catkin-pkg
+python-dateutil==2.9.0.post0 # via catkin-pkg
pyyaml==6.0.1 # via -r requirements.in, rospkg
-regex==2023.8.8 # via tiktoken
+regex==2023.12.25 # via tiktoken
requests==2.31.0 # via speechrecognition, tiktoken
rospkg==1.5.0 # via -r requirements.in
six==1.16.0 # via python-dateutil
+sounddevice==0.4.6 # via -r requirements.in
speechrecognition==3.10.0 # via -r requirements.in
sympy==1.12 # via torch
-tiktoken==0.3.1 # via openai-whisper
-torch==2.0.1 # via openai-whisper, triton
-tqdm==4.66.1 # via openai-whisper
-triton==2.0.0 # via openai-whisper, torch
-typing-extensions==4.7.1 # via torch
-urllib3==2.0.4 # via requests
-wheel==0.41.1 # via nvidia-cublas-cu11, nvidia-cuda-cupti-cu11, nvidia-cuda-runtime-cu11, nvidia-curand-cu11, nvidia-cusparse-cu11, nvidia-nvtx-cu11
+tiktoken==0.6.0 # via openai-whisper
+torch==2.2.1 # via openai-whisper
+tqdm==4.66.2 # via openai-whisper
+triton==2.2.0 # via openai-whisper, torch
+typing-extensions==4.10.0 # via torch
+urllib3==2.2.1 # via requests
# The following packages are considered to be unsafe in a requirements file:
# setuptools
diff --git a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt
index 1d550fbad..48659fa8c 100644
--- a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt
+++ b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt
@@ -7,7 +7,10 @@ project(lasr_vector_databases_faiss)
## Find catkin macros and libraries
## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
## is used, also find other catkin packages
-find_package(catkin REQUIRED catkin_virtualenv)
+find_package(catkin REQUIRED COMPONENTS
+  catkin_virtualenv
+  rospy
+  lasr_vector_databases_msgs
+)
## System dependencies are found with CMake's conventions
# find_package(Boost REQUIRED COMPONENTS system)
@@ -55,8 +58,6 @@ catkin_generate_virtualenv(
## Generate services in the 'srv' folder
# add_service_files(
# FILES
-# Service1.srv
-# Service2.srv
# )
# Generate actions in the 'action' folder
@@ -68,8 +69,7 @@ catkin_generate_virtualenv(
# Generate added messages and services with any dependencies listed here
# generate_messages(
# DEPENDENCIES
-# actionlib_msgs
-# geometry_msgs
+# std_msgs
# )
################################################
@@ -157,22 +157,13 @@ include_directories(
## Mark executable scripts (Python etc.) for installation
## in contrast to setup.py, you can choose the destination
-# catkin_install_python(PROGRAMS
-# nodes/qualification
-# nodes/actions/wait_greet
-# nodes/actions/identify
-# nodes/actions/greet
-# nodes/actions/get_name
-# nodes/actions/learn_face
-# nodes/actions/get_command
-# nodes/actions/guide
-# nodes/actions/find_person
-# nodes/actions/detect_people
-# nodes/actions/receive_object
-# nodes/actions/handover_object
-# nodes/better_qualification
-# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
-# )
+catkin_install_python(PROGRAMS
+ nodes/txt_index_service
+ nodes/txt_query_service
+ scripts/test_index_service.py
+ scripts/test_query_service.py
+ DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+)
## Mark executables for installation
## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html
@@ -196,11 +187,10 @@ include_directories(
# )
## Mark other files for installation (e.g. launch and bag files, etc.)
-# install(FILES
-# # myfile1
-# # myfile2
-# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
-# )
+install(FILES
+ requirements.txt
+ DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+)
#############
## Testing ##
diff --git a/common/vector_databases/lasr_vector_databases_faiss/doc/TECHNICAL.md b/common/vector_databases/lasr_vector_databases_faiss/doc/TECHNICAL.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/common/vector_databases/lasr_vector_databases_faiss/doc/USAGE.md b/common/vector_databases/lasr_vector_databases_faiss/doc/USAGE.md
new file mode 100644
index 000000000..d17914026
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/doc/USAGE.md
@@ -0,0 +1,33 @@
+This package currently contains two services, `txt_index_service` and `txt_query_service`, which are used to create and search (respectively) a vector database of natural language sentence embeddings.
+
+# Index Service
+The Index service is used to create a [FAISS](https://github.com/facebookresearch/faiss) index object containing a set of sentence embeddings, where each sentence is assumed to be a line in a given `.txt` file. This Index object is saved to disk at a specified location, and can be thought of as a Vector Database.
+
+## Request
+The request takes two string parameters: `txt_path`, the path to the `.txt` file we wish to create sentence embeddings for, where each line in this file is treated as a sentence; and `index_path`, the path to the `.index` file that will be created by the Service.
+
+## Response
+No response is given from this service.
+
+## Example Usage
+Please see the `scripts/test_index_service.py` script for a simple example of sending a request to the service.
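+
+For reference, here is a minimal sketch of calling the service from Python (the node name and file paths are placeholders):
+
+```python
+#!/usr/bin/env python3
+import rospy
+from lasr_vector_databases_msgs.srv import TxtIndex, TxtIndexRequest
+
+rospy.init_node("txt_index_client")  # placeholder node name
+rospy.wait_for_service("lasr_faiss/txt_index")
+index_srv = rospy.ServiceProxy("lasr_faiss/txt_index", TxtIndex)
+
+request = TxtIndexRequest()
+request.txt_path = "/path/to/sentences.txt"  # one sentence per line
+request.index_path = "/path/to/sentences.index"  # created by the service
+index_srv(request)
+```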
+
+# Query Service
+The query service is used to search the `.index` file created by the Index Service to find the most similar sentences given an input query sentence.
+
+## Request
+The request requires four fields:
+
+1. `txt_path` -- this is a `string` that is the path to the txt file that contains the original sentences that the `.index` file was populated with.
+2. `index_path` -- this is a `string` that is the path to the `.index` file that was created by the Index Service, from the same txt file given in `txt_path`.
+3. `query_sentence` -- this is a `string` that is the sentence you wish to query the index with, in order to find the most similar sentences.
+4. `k` -- this is a `uint8` that is the number of closest sentences you wish to return.
+
+## Response
+The response contains two fields:
+
+1. `closest_sentences` -- this is an ordered list of `string`s that contain the closest sentences to the given query sentence.
+2. `cosine_similarities` -- this is an ordered list of `float32`s that contain the cosine similarity scores of the closest sentences.
+
+## Example Usage
+Please see the `scripts/test_query_service.py` script for a simple example of sending a request to the service.
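+
+For reference, here is a minimal sketch of querying the index from Python (the node name and file paths are placeholders):
+
+```python
+#!/usr/bin/env python3
+import rospy
+from lasr_vector_databases_msgs.srv import TxtQuery, TxtQueryRequest
+
+rospy.init_node("txt_query_client")  # placeholder node name
+rospy.wait_for_service("lasr_faiss/txt_query")
+query_srv = rospy.ServiceProxy("lasr_faiss/txt_query", TxtQuery)
+
+request = TxtQueryRequest()
+request.txt_path = "/path/to/sentences.txt"  # the file the index was built from
+request.index_path = "/path/to/sentences.index"
+request.query_sentence = "Do French like snails?"
+request.k = 3  # number of closest sentences to return
+
+response = query_srv(request)
+print(response.closest_sentences)
+print(response.cosine_similarities)
+```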
\ No newline at end of file
diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service
new file mode 100644
index 000000000..a9c5c85b6
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+import rospy
+import numpy as np
+from lasr_vector_databases_msgs.srv import TxtIndexRequest, TxtIndexResponse, TxtIndex
+from lasr_vector_databases_faiss import (
+ load_model,
+ parse_txt_file,
+ get_sentence_embeddings,
+ create_vector_database,
+)
+
+
+class TxtIndexService:
+ def __init__(self):
+ rospy.init_node("txt_index_service")
+        self._sentence_embedding_model = load_model()
+        rospy.Service("lasr_faiss/txt_index", TxtIndex, self.execute_cb)
+ rospy.loginfo("Text index service started")
+
+ def execute_cb(self, req: TxtIndexRequest):
+ txt_fp: str = req.txt_path
+ sentences_to_embed: list[str] = parse_txt_file(txt_fp)
+ sentence_embeddings: np.ndarray = get_sentence_embeddings(
+ sentences_to_embed, self._sentence_embedding_model
+ )
+ index_path: str = req.index_path
+ create_vector_database(sentence_embeddings, index_path)
+ return TxtIndexResponse()
+
+
+if __name__ == "__main__":
+ TxtIndexService()
+ rospy.spin()
diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service
new file mode 100644
index 000000000..e45610fd3
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+import rospy
+import numpy as np
+
+from lasr_vector_databases_msgs.srv import (
+ TxtQueryRequest,
+ TxtQueryResponse,
+ TxtQuery,
+)
+from lasr_vector_databases_faiss import (
+ load_model,
+ parse_txt_file,
+ get_sentence_embeddings,
+ load_vector_database,
+ query_database,
+)
+
+
+class TxtQueryService:
+ def __init__(self):
+ rospy.init_node("txt_query_service")
+ self._sentence_embedding_model = load_model()
+ rospy.Service("lasr_faiss/txt_query", TxtQuery, self.execute_cb)
+ rospy.loginfo("Text Query service started")
+
+ def execute_cb(self, req: TxtQueryRequest) -> TxtQueryResponse:
+ txt_fp: str = req.txt_path
+ index_path: str = req.index_path
+ query_sentence: str = req.query_sentence
+ possible_matches: list[str] = parse_txt_file(txt_fp)
+ query_embedding: np.ndarray = get_sentence_embeddings(
+ [query_sentence], self._sentence_embedding_model # requires list of strings
+ )
+ distances, indices = query_database(index_path, query_embedding, k=req.k)
+ nearest_matches = [possible_matches[i] for i in indices[0]]
+
+ return TxtQueryResponse(
+ closest_sentences=nearest_matches,
+ cosine_similarities=distances[0].tolist(),
+ )
+
+
+if __name__ == "__main__":
+ TxtQueryService()
+ rospy.spin()
diff --git a/common/vector_databases/lasr_vector_databases_faiss/package.xml b/common/vector_databases/lasr_vector_databases_faiss/package.xml
index f8128ea56..1546f0a38 100644
--- a/common/vector_databases/lasr_vector_databases_faiss/package.xml
+++ b/common/vector_databases/lasr_vector_databases_faiss/package.xml
@@ -50,6 +50,7 @@
   <buildtool_depend>catkin</buildtool_depend>
   <build_depend>catkin_virtualenv</build_depend>
+  <build_depend>lasr_vector_databases_msgs</build_depend>
diff --git a/common/vector_databases/lasr_vector_databases_faiss/requirements.in b/common/vector_databases/lasr_vector_databases_faiss/requirements.in
index 14955d38d..3259a62c9 100644
--- a/common/vector_databases/lasr_vector_databases_faiss/requirements.in
+++ b/common/vector_databases/lasr_vector_databases_faiss/requirements.in
@@ -1 +1,3 @@
-faiss-cpu
\ No newline at end of file
+faiss-cpu
+sentence-transformers
+torch
\ No newline at end of file
diff --git a/common/vector_databases/lasr_vector_databases_faiss/scripts/test_index_service.py b/common/vector_databases/lasr_vector_databases_faiss/scripts/test_index_service.py
new file mode 100644
index 000000000..cc7f12f3c
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/scripts/test_index_service.py
@@ -0,0 +1,14 @@
+#!/usr/bin/env python3
+import rospy
+from lasr_vector_databases_msgs.srv import TxtIndex, TxtIndexRequest
+
+rospy.init_node("test_index_service")
+
+request = TxtIndexRequest()
+
+request.txt_path = (
+ "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.txt"
+)
+
+request.index_path = (
+ "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.index"
+)
+rospy.wait_for_service("lasr_faiss/txt_index")
+rospy.ServiceProxy("lasr_faiss/txt_index", TxtIndex)(request)
diff --git a/common/vector_databases/lasr_vector_databases_faiss/scripts/test_query_service.py b/common/vector_databases/lasr_vector_databases_faiss/scripts/test_query_service.py
new file mode 100644
index 000000000..4ae89e530
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/scripts/test_query_service.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+import rospy
+from lasr_vector_databases_msgs.srv import TxtQuery, TxtQueryRequest
+
+rospy.init_node("test_query_service")
+
+request = TxtQueryRequest()
+
+request.txt_path = (
+ "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.txt"
+)
+
+request.index_path = (
+ "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.index"
+)
+
+request.query_sentence = "Do French like snails?"
+
+request.k = 3
+
+rospy.wait_for_service("lasr_faiss/txt_query")
+response = rospy.ServiceProxy("lasr_faiss/txt_query", TxtQuery)(request)
+
+print(response.closest_sentences)
+print(response.cosine_similarities)
diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py
index e69de29bb..698927d9c 100644
--- a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py
+++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py
@@ -0,0 +1,2 @@
+from .database_utils import create_vector_database, load_vector_database, query_database
+from .get_sentence_embeddings import get_sentence_embeddings, load_model, parse_txt_file
diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py
deleted file mode 100755
index f7dbcbe24..000000000
--- a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py
+++ /dev/null
@@ -1,157 +0,0 @@
-#!/usr/bin/env python3
-import os
-import torch
-import numpy as np
-
-# import rospy
-import faiss # type: ignore
-from sentence_transformers import SentenceTransformer # type: ignore
-from typing import Optional
-
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-
-def load_commands(command_path: str) -> list[str]:
- """Loads the commands stored in the given txt file
- into a list of string commands
- Args:
- command_path (str): path to the txt file containing
- the commands -- assumes one command per line.
- Returns:
- list[str]: list of string commands where each entry in the
- list is a command
- """
- command_list = []
- with open(command_path, "r", encoding="utf8") as src:
- for command in src:
- # Strip newline char.
- command_list.append(command[:-1])
- return command_list
-
-
-def get_sentence_embeddings(
- command_list: list[str], model: SentenceTransformer
-) -> np.ndarray:
- """Converts the list of command strings into an array of sentence
- embeddings (where each command is a sentence and each sentence
- is converted to a vector)
- Args:
- command_list (list[str]): list of string commands, where each
- entry in the list is assumed to be a separate command
- model (SentenceTransformer): model used to perform the embedding.
- Assumes a method called encode that takes a list of strings
- as input.
- Returns:
- np.ndarray: array of shape (n_commands, embedding_dim)
- """
-
- return model.encode(
- command_list,
- convert_to_numpy=True,
- show_progress_bar=True,
- batch_size=256,
- device=DEVICE,
- )
-
-
-def create_vector_database(vectors: np.ndarray) -> faiss.IndexFlatIP:
- """Creates a vector database from an array of vectors of the same dimensionality
- Args:
- vectors (np.ndarray): shape (n_vectors, vector_dim)
-
- Returns:
- faiss.IndexFlatIP: Flat index containing the vectors
- """
- print("Creating vector database")
- index_flat = faiss.IndexFlatIP(vectors.shape[1])
- faiss.normalize_L2(vectors)
- index_flat.add(vectors)
- print("Finished creating vector database")
- return index_flat
-
-
-def get_command_database(
- index_path: str, command_path: Optional[str] = None
-) -> faiss.IndexFlatL2:
- """Gets a vector database containing a list of embedded commands. Creates the database
- if the path does not exist, else, loads it into memory.
-
- Args:
- index_path (str): Path to an existing faiss Index, or where to save a new one.
- command_path (str, optional): Path of text file containing commands.
- Only required if creating a new database. Defaults to None.
-
- Returns:
- faiss.IndexFlatL2: faiss Index object containing the embedded commands.
- """
-
- if not os.path.exists(f"{index_path}.index"):
- # rospy.loginfo("Creating new command vector database")
- assert command_path is not None
- command_list = load_commands(command_path)
- model = SentenceTransformer("all-MiniLM-L6-v2")
- command_embeddings = get_sentence_embeddings(command_list, model)
- print(command_embeddings.shape)
- command_database = create_vector_database(command_embeddings)
- faiss.write_index(command_database, f"{index_path}.index")
- # rospy.loginfo("Finished creating vector database")
- else:
- command_database = faiss.read_index(f"{index_path}.index")
-
- return command_database
-
-
-def get_similar_commands(
- command: str,
- index_path: str,
- command_path: str,
- n_similar_commands: int = 100,
- return_embeddings: bool = False,
-) -> tuple[list[str], list[float]]:
- """Gets the most similar commands to the given command string
- Args:
- command (str): command to compare against the database
- index_path (str): path to the location to create or retrieve
- the faiss index containing the embedded commands.
- command_path (str): path to the txt file containing the commands
- n_similar_commands (int, optional): number of similar commands to
- return. Defaults to 100.
- Returns:
- list[str]: list of string commands, where each entry in the
- list is a similar command
- """
- command_database = get_command_database(index_path, command_path)
- command_list = load_commands(command_path)
- model = SentenceTransformer("all-MiniLM-L6-v2")
- command_embedding = get_sentence_embeddings([command], model)
- faiss.normalize_L2(command_embedding)
- command_distances, command_indices = command_database.search(
- command_embedding, n_similar_commands
- )
- nearest_commands = [command_list[i] for i in command_indices[0]]
-
- if return_embeddings:
- all_command_embeddings = get_sentence_embeddings(command_list, model)
- # filter for only nererst commands
- all_command_embeddings = all_command_embeddings[command_indices[0]]
- return (
- nearest_commands,
- list(command_distances[0]),
- all_command_embeddings,
- command_embedding,
- )
-
- return nearest_commands, list(command_distances[0])
-
-
-if __name__ == "__main__":
- """Example usage of using this to find similar commands"""
- command = "find Jared and asks if he needs help"
- result, distances, command_embeddings, query_embedding = get_similar_commands(
- command,
- "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/qualification/data/command_index",
- "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/qualification/data/command_list.txt",
- n_similar_commands=1000,
- return_embeddings=True,
- )
- print(result)
diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py
new file mode 100644
index 000000000..6d0139072
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+import os
+import numpy as np
+import faiss
+
+
+def create_vector_database(
+ vectors: np.ndarray,
+ index_path: str,
+ overwrite: bool = False,
+ index_type: str = "Flat",
+ normalise_vecs: bool = True,
+) -> None:
+ """Creates a FAISS Index using the factory constructor and the given
+ index type, and adds the given vector to the index, and then saves
+ it to disk using the given path.
+
+ Args:
+ vectors (np.ndarray): vector of shape (n_vectors, vector_dim)
+ index_path (str): path to save the index
+ overwrite (bool, optional): Whether to replace an existing index
+ at the same filepath if it exists. Defaults to False.
+        index_type (str, optional): FAISS Index Factory string. Defaults to "Flat".
+ normalise_vecs (bool, optional): Whether to normalise the vectors before
+ adding them to the Index. This converts the IP metric to Cosine Similarity.
+ Defaults to True.
+ """
+
+ if os.path.exists(index_path) and not overwrite:
+ raise FileExistsError(
+ f"Index already exists at {index_path}. Set overwrite=True to replace it."
+ )
+
+ index = faiss.index_factory(
+ vectors.shape[1], index_type, faiss.METRIC_INNER_PRODUCT
+ )
+ if normalise_vecs:
+ faiss.normalize_L2(vectors)
+ index.add(vectors)
+ faiss.write_index(index, index_path)
+
+
+def load_vector_database(index_path: str, use_gpu: bool = False) -> faiss.Index:
+ """Loads a FAISS Index from the given filepath
+
+ Args:
+ index_path (str): path to the index file
+ use_gpu (bool, optional): Whether to load the index onto the GPU.
+ Defaults to False.
+
+ Returns:
+ faiss.Index: FAISS Index object
+ """
+ print("Loading index from", index_path)
+ index = faiss.read_index(index_path)
+ print("Loaded index with ntotal:", index.ntotal)
+ if use_gpu:
+ index = faiss.index_cpu_to_all_gpus(index)
+ return index
+
+
+def query_database(
+ index_path: str,
+ query_vectors: np.ndarray,
+ normalise: bool = True,
+ k: int = 1,
+) -> tuple[np.ndarray, np.ndarray]:
+ """Queries the given index with the given query vectors
+
+ Args:
+ index_path (str): path to the index file
+ query_vectors (np.ndarray): query vectors of shape (n_queries, vector_dim)
+ normalise (bool, optional): Whether to normalise the query vectors.
+ Defaults to True.
+ k (int, optional): Number of nearest neighbours to return. Defaults to 1.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: (distances, indices) of the nearest neighbours
+ each of shape (n_queries, n_neighbours)
+ """
+ index = load_vector_database(index_path)
+ if normalise:
+ faiss.normalize_L2(query_vectors)
+ distances, indices = index.search(query_vectors, k)
+ return distances, indices
diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py
new file mode 100644
index 000000000..1ba43b48c
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+import torch
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def load_model(model_name: str = "all-MiniLM-L6-v2") -> SentenceTransformer:
+ """Loads the sentence transformer model
+ Args:
+ model_name (str): name of the model to load
+ Returns:
+ sentence_transformers.SentenceTransformer: the loaded model
+ """
+ return SentenceTransformer(model_name, device=DEVICE)
+
+
+def parse_txt_file(fp: str) -> list[str]:
+ """Parses a txt file into a list of strings,
+ where each element is a line in the txt file with the
+ newline char stripped.
+ Args:
+ fp (str): path to the txt file to load
+ Returns:
+ list[str]: list of strings where each element is a line in the txt file
+ """
+ sentences = []
+ with open(fp, "r", encoding="utf8") as src:
+ for line in src:
+ # Strip newline char.
+ sentences.append(line[:-1])
+ return sentences
+
+
+def get_sentence_embeddings(
+ sentence_list: list[str], model: SentenceTransformer
+) -> np.ndarray:
+ """Converts the list of string sentences into an array of sentence
+ embeddings
+ Args:
+ sentence_list (list[str]): list of string sentences, where each
+ entry in the list is assumed to be a separate sentence
+ model (SentenceTransformer): model used to perform the embedding.
+ Assumes a method called encode that takes a list of strings
+ as input.
+ Returns:
+ np.ndarray: array of shape (n_commands, embedding_dim)
+ """
+ return model.encode(
+ sentence_list,
+ convert_to_numpy=True,
+ show_progress_bar=True,
+ batch_size=256,
+ device=DEVICE,
+ )
diff --git a/common/vector_databases/lasr_vector_databases_msgs/CMakeLists.txt b/common/vector_databases/lasr_vector_databases_msgs/CMakeLists.txt
new file mode 100644
index 000000000..63fd162dc
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_msgs/CMakeLists.txt
@@ -0,0 +1,194 @@
+cmake_minimum_required(VERSION 3.0.2)
+project(lasr_vector_databases_msgs)
+
+## Compile as C++11, supported in ROS Kinetic and newer
+# add_compile_options(-std=c++11)
+
+## Find catkin macros and libraries
+## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
+## is used, also find other catkin packages
+find_package(catkin REQUIRED COMPONENTS message_generation message_runtime)
+
+## System dependencies are found with CMake's conventions
+# find_package(Boost REQUIRED COMPONENTS system)
+
+
+## Uncomment this if the package has a setup.py. This macro ensures
+## modules and global scripts declared therein get installed
+## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html
+# catkin_python_setup()
+
+################################################
+## Declare ROS messages, services and actions ##
+################################################
+
+## To declare and build messages, services or actions from within this
+## package, follow these steps:
+## * Let MSG_DEP_SET be the set of packages whose message types you use in
+## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...).
+## * In the file package.xml:
+## * add a build_depend tag for "message_generation"
+## * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET
+## * If MSG_DEP_SET isn't empty the following dependency has been pulled in
+## but can be declared for certainty nonetheless:
+## * add a exec_depend tag for "message_runtime"
+## * In this file (CMakeLists.txt):
+## * add "message_generation" and every package in MSG_DEP_SET to
+## find_package(catkin REQUIRED COMPONENTS ...)
+## * add "message_runtime" and every package in MSG_DEP_SET to
+## catkin_package(CATKIN_DEPENDS ...)
+## * uncomment the add_*_files sections below as needed
+## and list every .msg/.srv/.action file to be processed
+## * uncomment the generate_messages entry below
+## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...)
+
+## Generate messages in the 'msg' folder
+# add_message_files(
+# FILES
+# )
+
+## Generate services in the 'srv' folder
+add_service_files(
+ FILES
+ TxtIndex.srv
+ TxtQuery.srv
+)
+
+# Generate actions in the 'action' folder
+
+## Generate added messages and services with any dependencies listed here
+generate_messages(
+ DEPENDENCIES
+)
+
+################################################
+## Declare ROS dynamic reconfigure parameters ##
+################################################
+
+## To declare and build dynamic reconfigure parameters within this
+## package, follow these steps:
+## * In the file package.xml:
+## * add a build_depend and a exec_depend tag for "dynamic_reconfigure"
+## * In this file (CMakeLists.txt):
+## * add "dynamic_reconfigure" to
+## find_package(catkin REQUIRED COMPONENTS ...)
+## * uncomment the "generate_dynamic_reconfigure_options" section below
+## and list every .cfg file to be processed
+
+## Generate dynamic reconfigure parameters in the 'cfg' folder
+# generate_dynamic_reconfigure_options(
+# cfg/DynReconf1.cfg
+# cfg/DynReconf2.cfg
+# )
+
+###################################
+## catkin specific configuration ##
+###################################
+## The catkin_package macro generates cmake config files for your package
+## Declare things to be passed to dependent projects
+## INCLUDE_DIRS: uncomment this if your package contains header files
+## LIBRARIES: libraries you create in this project that dependent projects also need
+## CATKIN_DEPENDS: catkin_packages dependent projects also need
+## DEPENDS: system dependencies of this project that dependent projects also need
+catkin_package(
+# INCLUDE_DIRS include
+# LIBRARIES lasr_vision_msgs
+# CATKIN_DEPENDS other_catkin_pkg
+# DEPENDS system_lib
+)
+
+###########
+## Build ##
+###########
+
+## Specify additional locations of header files
+## Your package locations should be listed before other locations
+include_directories(
+# include
+# ${catkin_INCLUDE_DIRS}
+)
+
+## Declare a C++ library
+# add_library(${PROJECT_NAME}
+# src/${PROJECT_NAME}/lasr_vision_msgs.cpp
+# )
+
+## Add cmake target dependencies of the library
+## as an example, code may need to be generated before libraries
+## either from message generation or dynamic reconfigure
+# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Declare a C++ executable
+## With catkin_make all packages are built within a single CMake context
+## The recommended prefix ensures that target names across packages don't collide
+# add_executable(${PROJECT_NAME}_node src/lasr_vision_msgs_node.cpp)
+
+## Rename C++ executable without prefix
+## The above recommended prefix causes long target names, the following renames the
+## target back to the shorter version for ease of user use
+## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node"
+# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "")
+
+## Add cmake target dependencies of the executable
+## same as for the library above
+# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Specify libraries to link a library or executable target against
+# target_link_libraries(${PROJECT_NAME}_node
+# ${catkin_LIBRARIES}
+# )
+
+#############
+## Install ##
+#############
+
+# all install targets should use catkin DESTINATION variables
+# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html
+
+## Mark executable scripts (Python etc.) for installation
+## in contrast to setup.py, you can choose the destination
+# catkin_install_python(PROGRAMS
+# scripts/my_python_script
+# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark executables for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html
+# install(TARGETS ${PROJECT_NAME}_node
+# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark libraries for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html
+# install(TARGETS ${PROJECT_NAME}
+# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION}
+# )
+
+## Mark cpp header files for installation
+# install(DIRECTORY include/${PROJECT_NAME}/
+# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}
+# FILES_MATCHING PATTERN "*.h"
+# PATTERN ".svn" EXCLUDE
+# )
+
+## Mark other files for installation (e.g. launch and bag files, etc.)
+# install(FILES
+# # myfile1
+# # myfile2
+# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+# )
+
+#############
+## Testing ##
+#############
+
+## Add gtest based cpp test target and link libraries
+# catkin_add_gtest(${PROJECT_NAME}-test test/test_lasr_vision_msgs.cpp)
+# if(TARGET ${PROJECT_NAME}-test)
+# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME})
+# endif()
+
+## Add folders to be run by python nosetests
+# catkin_add_nosetests(test)
diff --git a/common/vector_databases/lasr_vector_databases_msgs/package.xml b/common/vector_databases/lasr_vector_databases_msgs/package.xml
new file mode 100644
index 000000000..5f4e45e9f
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_msgs/package.xml
@@ -0,0 +1,61 @@
+<?xml version="1.0"?>
+<package format="2">
+  <name>lasr_vector_databases_msgs</name>
+  <version>0.0.0</version>
+  <description>Messages required for vector databases</description>
+
+  <maintainer email="todo@todo.todo">Matt Barker</maintainer>
+
+  <license>MIT</license>
+
+  <buildtool_depend>catkin</buildtool_depend>
+  <build_depend>message_generation</build_depend>
+  <exec_depend>message_runtime</exec_depend>
+</package>
diff --git a/common/vector_databases/lasr_vector_databases_msgs/srv/TxtIndex.srv b/common/vector_databases/lasr_vector_databases_msgs/srv/TxtIndex.srv
new file mode 100644
index 000000000..79ac01654
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_msgs/srv/TxtIndex.srv
@@ -0,0 +1,7 @@
+# Path to input text file
+string txt_path
+
+# Output path to save index
+string index_path
+
+---
diff --git a/common/vector_databases/lasr_vector_databases_msgs/srv/TxtQuery.srv b/common/vector_databases/lasr_vector_databases_msgs/srv/TxtQuery.srv
new file mode 100644
index 000000000..bbcb04613
--- /dev/null
+++ b/common/vector_databases/lasr_vector_databases_msgs/srv/TxtQuery.srv
@@ -0,0 +1,19 @@
+# Path to input text file
+string txt_path
+
+# Path to index file to load
+string index_path
+
+# Sentence to query index with
+string query_sentence
+
+# Number of nearest sentences to return
+uint8 k
+
+---
+# Nearest sentences
+string[] closest_sentences
+
+# Cosine similarities of the nearest sentences
+float32[] cosine_similarities
+
diff --git a/common/vision/lasr_vision_clip/CMakeLists.txt b/common/vision/lasr_vision_clip/CMakeLists.txt
index a13eb6f2a..c2ce23209 100644
--- a/common/vision/lasr_vision_clip/CMakeLists.txt
+++ b/common/vision/lasr_vision_clip/CMakeLists.txt
@@ -196,11 +196,10 @@ include_directories(
# )
## Mark other files for installation (e.g. launch and bag files, etc.)
-# install(FILES
-# # myfile1
-# # myfile2
-# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
-# )
+install(FILES
+ requirements.txt
+ DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+)
#############
## Testing ##
diff --git a/skills/CMakeLists.txt b/skills/CMakeLists.txt
index cc256ca3d..6a5978ed6 100644
--- a/skills/CMakeLists.txt
+++ b/skills/CMakeLists.txt
@@ -10,6 +10,7 @@ project(lasr_skills)
find_package(catkin REQUIRED COMPONENTS
rospy
lasr_vision_msgs
+ lasr_vector_databases_msgs
)
## System dependencies are found with CMake's conventions
diff --git a/skills/package.xml b/skills/package.xml
index 44a4f5827..053a2c85b 100644
--- a/skills/package.xml
+++ b/skills/package.xml
@@ -53,6 +53,7 @@
   <build_depend>rospy</build_depend>
   <exec_depend>rospy</exec_depend>
   <build_depend>lasr_vision_msgs</build_depend>
+  <build_depend>lasr_vector_databases_msgs</build_depend>
diff --git a/skills/src/lasr_skills/__init__.py b/skills/src/lasr_skills/__init__.py
index f1a0a65b6..4edfb33aa 100755
--- a/skills/src/lasr_skills/__init__.py
+++ b/skills/src/lasr_skills/__init__.py
@@ -5,6 +5,7 @@
from .wait_for_person_in_area import WaitForPersonInArea
from .describe_people import DescribePeople
from .look_to_point import LookToPoint
-from .look_to_point import LookToPoint
from .go_to_location import GoToLocation
from .go_to_semantic_location import GoToSemanticLocation
+from .listen import Listen
+from .say import Say
\ No newline at end of file
diff --git a/skills/src/lasr_skills/ask_and_listen.py b/skills/src/lasr_skills/ask_and_listen.py
new file mode 100644
index 000000000..6f110c40a
--- /dev/null
+++ b/skills/src/lasr_skills/ask_and_listen.py
@@ -0,0 +1,34 @@
+import smach
+from lasr_skills import Listen
+from lasr_skills import Say
+
+
+class AskAndListen(smach.StateMachine):
+ def __init__(self):
+ smach.StateMachine.__init__(
+ self,
+ outcomes=["succeeded", "failed"],
+ output_keys=["transcribed_speech"],
+ input_keys=["tts_phrase"],
+ )
+ with self:
+ smach.StateMachine.add(
+ "SAY",
+ Say(),
+ transitions={
+ "succeeded": "LISTEN",
+ "aborted": "failed",
+ "preempted": "failed",
+ },
+ remapping={"text": "tts_phrase"},
+ )
+ smach.StateMachine.add(
+ "LISTEN",
+ Listen(),
+ transitions={
+ "succeeded": "succeeded",
+ "aborted": "failed",
+ "preempted": "failed",
+ },
+ remapping={"sequence": "transcribed_speech"},
+ )
diff --git a/skills/src/lasr_skills/listen.py b/skills/src/lasr_skills/listen.py
new file mode 100644
index 000000000..272f24b47
--- /dev/null
+++ b/skills/src/lasr_skills/listen.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+import smach_ros
+from lasr_speech_recognition_msgs.msg import (
+ TranscribeSpeechAction,
+)
+
+
+class Listen(smach_ros.SimpleActionState):
+ def __init__(self):
+ smach_ros.SimpleActionState.__init__(
+ self,
+ "transcribe_speech",
+ TranscribeSpeechAction,
+ result_slots=["sequence"],
+ )
diff --git a/skills/src/lasr_skills/say.py b/skills/src/lasr_skills/say.py
new file mode 100644
index 000000000..5e0bddb8d
--- /dev/null
+++ b/skills/src/lasr_skills/say.py
@@ -0,0 +1,45 @@
+import smach_ros
+
+from pal_interaction_msgs.msg import TtsGoal, TtsAction, TtsText
+
+from typing import Union
+
+
+class Say(smach_ros.SimpleActionState):
+ def __init__(
+ self, text: Union[str, None] = None, format_str: Union[str, None] = None
+ ):
+ if text is not None:
+ super(Say, self).__init__(
+ "tts",
+ TtsAction,
+ goal=TtsGoal(rawtext=TtsText(text=text, lang_id="en_GB")),
+ )
+ elif format_str is not None:
+ super(Say, self).__init__(
+ "tts",
+ TtsAction,
+ goal_cb=lambda ud, _: (
+ TtsGoal(
+ rawtext=TtsText(
+ text=format_str.format(*ud.placeholders), lang_id="en_GB"
+ )
+ )
+ if isinstance(ud.placeholders, (list, tuple))
+ else TtsGoal(
+ rawtext=TtsText(
+ text=format_str.format(ud.placeholders), lang_id="en_GB"
+ )
+ )
+ ),
+ input_keys=["placeholders"],
+ )
+ else:
+ super(Say, self).__init__(
+ "tts",
+ TtsAction,
+ goal_cb=lambda ud, _: TtsGoal(
+ rawtext=TtsText(text=ud.text, lang_id="en_GB")
+ ),
+ input_keys=["text"],
+ )
diff --git a/skills/src/lasr_skills/xml_question_answer.py b/skills/src/lasr_skills/xml_question_answer.py
new file mode 100644
index 000000000..abede91b5
--- /dev/null
+++ b/skills/src/lasr_skills/xml_question_answer.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+
+import rospy
+import smach
+import xml.etree.ElementTree as ET
+
+from lasr_vector_databases_msgs.srv import TxtQuery, TxtQueryRequest
+
+
+def parse_question_xml(xml_file_path: str) -> dict:
+ """Parses the GPSR Q/A xml file and returns a dictionary
+ consisting of two lists, one for questions and one for answers,
+ where the index of each question corresponds to the index of its
+ corresponding answer.
+
+ Args:
+ xml_file_path (str): full path to xml file to parse
+
+ Returns:
+ dict: dictionary with keys "questions" and "answers"
+ each of which is a list of strings.
+ """
+ tree = ET.parse(xml_file_path)
+ root = tree.getroot()
+ parsed_questions = []
+ parsed_answers = []
+ for q_a in root:
+ question = q_a.find("q").text
+ answer = q_a.find("a").text
+ parsed_questions.append(question)
+ parsed_answers.append(answer)
+
+ return {"questions": parsed_questions, "answers": parsed_answers}
+
+
+class XmlQuestionAnswer(smach.State):
+
+ def __init__(self, index_path: str, txt_path: str, xml_path: str, k: int = 1):
+ smach.State.__init__(
+ self,
+ outcomes=["succeeded", "failed"],
+ input_keys=["query_sentence", "k"],
+ output_keys=["closest_answers"],
+ )
+ self.index_path = index_path
+ self.txt_path = txt_path
+ self.xml_path = xml_path
+ self.k = k
+ self.txt_query = rospy.ServiceProxy("/lasr_faiss/txt_query", TxtQuery)
+
+ def execute(self, userdata):
+ rospy.wait_for_service("/lasr_faiss/txt_query")
+ q_a_dict: dict = parse_question_xml(self.xml_path)
+ try:
+ request = TxtQueryRequest(
+ self.txt_path,
+ self.index_path,
+ userdata.query_sentence,
+ self.k,
+ )
+ result = self.txt_query(request)
+ answers = [
+ q_a_dict["answers"][q_a_dict["questions"].index(q)]
+ for q in result.closest_sentences
+ ]
+ userdata.closest_answers = answers
+ return "succeeded"
+        except rospy.ServiceException as e:
+            rospy.logerr(f"Service call failed: {e}")
+            return "failed"
diff --git a/tasks/gpsr/CMakeLists.txt b/tasks/gpsr/CMakeLists.txt
new file mode 100644
index 000000000..b8c8a175f
--- /dev/null
+++ b/tasks/gpsr/CMakeLists.txt
@@ -0,0 +1,200 @@
+cmake_minimum_required(VERSION 3.0.2)
+project(gpsr)
+
+## Compile as C++11, supported in ROS Kinetic and newer
+# add_compile_options(-std=c++11)
+
+## Find catkin macros and libraries
+## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
+## is used, also find other catkin packages
+find_package(catkin REQUIRED)
+
+## System dependencies are found with CMake's conventions
+# find_package(Boost REQUIRED COMPONENTS system)
+
+
+## Uncomment this if the package has a setup.py. This macro ensures
+## modules and global scripts declared therein get installed
+## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html
+catkin_python_setup()
+
+################################################
+## Declare ROS messages, services and actions ##
+################################################
+
+## To declare and build messages, services or actions from within this
+## package, follow these steps:
+## * Let MSG_DEP_SET be the set of packages whose message types you use in
+## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...).
+## * In the file package.xml:
+## * add a build_depend tag for "message_generation"
+## * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET
+## * If MSG_DEP_SET isn't empty the following dependency has been pulled in
+## but can be declared for certainty nonetheless:
+## * add a exec_depend tag for "message_runtime"
+## * In this file (CMakeLists.txt):
+## * add "message_generation" and every package in MSG_DEP_SET to
+## find_package(catkin REQUIRED COMPONENTS ...)
+## * add "message_runtime" and every package in MSG_DEP_SET to
+## catkin_package(CATKIN_DEPENDS ...)
+## * uncomment the add_*_files sections below as needed
+## and list every .msg/.srv/.action file to be processed
+## * uncomment the generate_messages entry below
+## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...)
+
+## Generate messages in the 'msg' folder
+# add_message_files(
+# FILES
+# Message1.msg
+# Message2.msg
+# )
+
+## Generate services in the 'srv' folder
+# add_service_files(
+# FILES
+# )
+
+## Generate actions in the 'action' folder
+# add_action_files(
+# FILES
+# Action1.action
+# Action2.action
+# )
+
+## Generate added messages and services with any dependencies listed here
+# generate_messages(
+# DEPENDENCIES
+# std_msgs
+# )
+
+################################################
+## Declare ROS dynamic reconfigure parameters ##
+################################################
+
+## To declare and build dynamic reconfigure parameters within this
+## package, follow these steps:
+## * In the file package.xml:
+## * add a build_depend and a exec_depend tag for "dynamic_reconfigure"
+## * In this file (CMakeLists.txt):
+## * add "dynamic_reconfigure" to
+## find_package(catkin REQUIRED COMPONENTS ...)
+## * uncomment the "generate_dynamic_reconfigure_options" section below
+## and list every .cfg file to be processed
+
+## Generate dynamic reconfigure parameters in the 'cfg' folder
+# generate_dynamic_reconfigure_options(
+# cfg/DynReconf1.cfg
+# cfg/DynReconf2.cfg
+# )
+
+###################################
+## catkin specific configuration ##
+###################################
+## The catkin_package macro generates cmake config files for your package
+## Declare things to be passed to dependent projects
+## INCLUDE_DIRS: uncomment this if your package contains header files
+## LIBRARIES: libraries you create in this project that dependent projects also need
+## CATKIN_DEPENDS: catkin_packages dependent projects also need
+## DEPENDS: system dependencies of this project that dependent projects also need
+catkin_package(
+# INCLUDE_DIRS include
+# LIBRARIES coffee_shop
+# CATKIN_DEPENDS other_catkin_pkg
+# DEPENDS system_lib
+)
+
+###########
+## Build ##
+###########
+
+## Specify additional locations of header files
+## Your package locations should be listed before other locations
+include_directories(
+# include
+# ${catkin_INCLUDE_DIRS}
+)
+
+## Declare a C++ library
+# add_library(${PROJECT_NAME}
+# src/${PROJECT_NAME}/coffee_shop.cpp
+# )
+
+## Add cmake target dependencies of the library
+## as an example, code may need to be generated before libraries
+## either from message generation or dynamic reconfigure
+# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Declare a C++ executable
+## With catkin_make all packages are built within a single CMake context
+## The recommended prefix ensures that target names across packages don't collide
+# add_executable(${PROJECT_NAME}_node src/coffee_shop_node.cpp)
+
+## Rename C++ executable without prefix
+## The above recommended prefix causes long target names, the following renames the
+## target back to the shorter version for ease of user use
+## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node"
+# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "")
+
+## Add cmake target dependencies of the executable
+## same as for the library above
+# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Specify libraries to link a library or executable target against
+# target_link_libraries(${PROJECT_NAME}_node
+# ${catkin_LIBRARIES}
+# )
+
+#############
+## Install ##
+#############
+
+# all install targets should use catkin DESTINATION variables
+# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html
+
+## Mark executable scripts (Python etc.) for installation
+## in contrast to setup.py, you can choose the destination
+catkin_install_python(PROGRAMS
+ scripts/parse_gpsr_xmls.py
+ nodes/commands/question_answer
+ DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+)
+
+## Mark executables for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html
+# install(TARGETS ${PROJECT_NAME}_node
+# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark libraries for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html
+# install(TARGETS ${PROJECT_NAME}
+# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION}
+# )
+
+## Mark cpp header files for installation
+# install(DIRECTORY include/${PROJECT_NAME}/
+# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}
+# FILES_MATCHING PATTERN "*.h"
+# PATTERN ".svn" EXCLUDE
+# )
+
+## Mark other files for installation (e.g. launch and bag files, etc.)
+# install(FILES
+# requirements.txt
+# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+# )
+
+#############
+## Testing ##
+#############
+
+## Add gtest based cpp test target and link libraries
+# catkin_add_gtest(${PROJECT_NAME}-test test/test_coffee_shop.cpp)
+# if(TARGET ${PROJECT_NAME}-test)
+# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME})
+# endif()
+
+## Add folders to be run by python nosetests
+# catkin_add_nosetests(test)
diff --git a/tasks/gpsr/launch/commands/question_answer.launch b/tasks/gpsr/launch/commands/question_answer.launch
new file mode 100644
index 000000000..2f7855173
--- /dev/null
+++ b/tasks/gpsr/launch/commands/question_answer.launch
@@ -0,0 +1,30 @@
+<launch>
+  <arg name="k" default="1"/>
+  <arg name="index_path"/>
+  <arg name="txt_path"/>
+  <arg name="xml_path"/>
+
+  <node pkg="lasr_vector_databases_faiss" type="txt_query_service" name="txt_query_service" output="screen"/>
+
+  <node pkg="gpsr" type="question_answer" name="question_answer" output="screen"
+        args="--k $(arg k) --index_path $(arg index_path) --txt_path $(arg txt_path) --xml_path $(arg xml_path)"/>
+</launch>
\ No newline at end of file
diff --git a/tasks/gpsr/nodes/commands/question_answer b/tasks/gpsr/nodes/commands/question_answer
new file mode 100644
index 000000000..9c2301c6c
--- /dev/null
+++ b/tasks/gpsr/nodes/commands/question_answer
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+import rospy
+import argparse
+import smach
+from lasr_skills.xml_question_answer import XmlQuestionAnswer
+from lasr_skills.ask_and_listen import AskAndListen
+from lasr_skills.say import Say
+
+
+class QuestionAnswerStateMachine(smach.StateMachine):
+ def __init__(self, input_data: dict):
+ smach.StateMachine.__init__(
+ self,
+ outcomes=["succeeded", "failed"],
+ output_keys=["closest_answers"],
+ )
+ self.userdata.tts_phrase = "I hear you have a question for me; ask away!"
+ with self:
+ smach.StateMachine.add(
+ "GET_QUESTION",
+ AskAndListen(),
+ transitions={"succeeded": "XML_QUESTION_ANSWER", "failed": "failed"},
+ remapping={
+                "tts_phrase": "tts_phrase",
+ "transcribed_speech": "transcribed_speech",
+ },
+ )
+ smach.StateMachine.add(
+ "XML_QUESTION_ANSWER",
+ XmlQuestionAnswer(
+ input_data["index_path"],
+ input_data["txt_path"],
+ input_data["xml_path"],
+ input_data["k"],
+ ),
+ transitions={"succeeded": "SAY_ANSWER", "failed": "failed"},
+ remapping={
+ "query_sentence": "transcribed_speech",
+ "closest_answers": "closest_answers",
+ },
+ )
+ smach.StateMachine.add(
+ "SAY_ANSWER",
+ Say(format_str="The answer to your question is: {}"),
+ transitions={
+ "succeeded": "succeeded",
+ "aborted": "failed",
+ "preempted": "failed",
+ },
+ remapping={"placeholders": "closest_answers"},
+ )
+
+
+def parse_args() -> dict:
+ parser = argparse.ArgumentParser(description="GPSR Question Answer")
+ parser.add_argument(
+ "--k",
+ type=int,
+ help="The number of closest answers to return",
+ required=True,
+ )
+ parser.add_argument(
+ "--index_path",
+ type=str,
+ help="The path to the index file that is populated with the sentences embeddings of the questions",
+ required=True,
+ )
+ parser.add_argument(
+ "--txt_path",
+ type=str,
+ help="The path to the txt file containing a list of questions.",
+ required=True,
+ )
+ parser.add_argument(
+ "--xml_path",
+ type=str,
+ help="The path to the xml file containing question/answer pairs",
+ required=True,
+ )
+ args, _ = parser.parse_known_args()
+ args.k = int(args.k)
+ return vars(args)
+
+
+if __name__ == "__main__":
+ rospy.init_node("gpsr_question_answer")
+ args: dict = parse_args()
+ while not rospy.is_shutdown():
+ q_a_sm = QuestionAnswerStateMachine(args)
+ outcome = q_a_sm.execute()
+ if outcome == "succeeded":
+ rospy.loginfo("Question Answer State Machine succeeded")
+ else:
+ rospy.logerr("Question Answer State Machine failed")
+
+ rospy.spin()
diff --git a/tasks/gpsr/package.xml b/tasks/gpsr/package.xml
new file mode 100644
index 000000000..e94bc707c
--- /dev/null
+++ b/tasks/gpsr/package.xml
@@ -0,0 +1,64 @@
+<?xml version="1.0"?>
+<package format="2">
+  <name>gpsr</name>
+  <version>0.0.0</version>
+  <description>The gpsr task package</description>
+
+  <maintainer email="todo@todo.todo">Matt Barker</maintainer>
+  <maintainer email="todo@todo.todo">Siyao Li</maintainer>
+  <maintainer email="todo@todo.todo">Jared Swift</maintainer>
+  <maintainer email="todo@todo.todo">Nicole Lehchevska</maintainer>
+
+  <license>MIT</license>
+
+  <buildtool_depend>catkin</buildtool_depend>
+  <depend>lasr_vector_databases_faiss</depend>
+  <depend>lasr_vector_databases_msgs</depend>
+  <depend>lasr_speech_recognition_whisper</depend>
+</package>
diff --git a/tasks/gpsr/scripts/parse_gpsr_xmls.py b/tasks/gpsr/scripts/parse_gpsr_xmls.py
new file mode 100644
index 000000000..e85fdf706
--- /dev/null
+++ b/tasks/gpsr/scripts/parse_gpsr_xmls.py
@@ -0,0 +1,27 @@
+import xml.etree.ElementTree as ET
+
+
+def parse_question_xml(xml_file_path: str) -> dict:
+ """Parses the GPSR Q/A xml file and returns a dictionary
+ consisting of two lists, one for questions and one for answers,
+ where the index of each question corresponds to the index of its
+ corresponding answer.
+
+ Args:
+ xml_file_path (str): full path to xml file to parse
+
+ Returns:
+ dict: dictionary with keys "questions" and "answers"
+ each of which is a list of strings.
+ """
+ tree = ET.parse(xml_file_path)
+ root = tree.getroot()
+ parsed_questions = []
+ parsed_answers = []
+ for q_a in root:
+ question = q_a.find("q").text
+ answer = q_a.find("a").text
+ parsed_questions.append(question)
+ parsed_answers.append(answer)
+
+ return {"questions": parsed_questions, "answers": parsed_answers}
diff --git a/tasks/gpsr/setup.py b/tasks/gpsr/setup.py
new file mode 100644
index 000000000..386c709e5
--- /dev/null
+++ b/tasks/gpsr/setup.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python3
+
+from distutils.core import setup
+from catkin_pkg.python_setup import generate_distutils_setup
+
+setup_args = generate_distutils_setup(packages=["gpsr"], package_dir={"": "src"})
+
+setup(**setup_args)