From 5d6e21c261df689dcd730c36e2379e17f5ce9be1 Mon Sep 17 00:00:00 2001 From: m-barker Date: Sun, 3 Mar 2024 17:54:21 +0000 Subject: [PATCH 01/23] data: add 2023 robocup gpsr xml files --- tasks/gpsr/data/Gestures.xml | 11 +++ tasks/gpsr/data/Locations.xml | 44 ++++++++++ tasks/gpsr/data/Names.xml | 58 ++++++++++++ tasks/gpsr/data/Objects.xml | 71 +++++++++++++++ tasks/gpsr/data/Questions.xml | 160 ++++++++++++++++++++++++++++++++++ 5 files changed, 344 insertions(+) create mode 100644 tasks/gpsr/data/Gestures.xml create mode 100644 tasks/gpsr/data/Locations.xml create mode 100644 tasks/gpsr/data/Names.xml create mode 100644 tasks/gpsr/data/Objects.xml create mode 100644 tasks/gpsr/data/Questions.xml diff --git a/tasks/gpsr/data/Gestures.xml b/tasks/gpsr/data/Gestures.xml new file mode 100644 index 000000000..59617e994 --- /dev/null +++ b/tasks/gpsr/data/Gestures.xml @@ -0,0 +1,11 @@ + + + + + + + + + diff --git a/tasks/gpsr/data/Locations.xml b/tasks/gpsr/data/Locations.xml new file mode 100644 index 000000000..aafef1b99 --- /dev/null +++ b/tasks/gpsr/data/Locations.xml @@ -0,0 +1,44 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tasks/gpsr/data/Names.xml b/tasks/gpsr/data/Names.xml new file mode 100644 index 000000000..cc53d5ea4 --- /dev/null +++ b/tasks/gpsr/data/Names.xml @@ -0,0 +1,58 @@ + + + + Adel + Angel + Axel + Charlie + Jane + Jules + Morgan + Paris + Robin + Simone + Adel + Angel + Axel + Charlie + James + Jules + Morgan + Paris + Robin + Simone + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tasks/gpsr/data/Objects.xml b/tasks/gpsr/data/Objects.xml new file mode 100644 index 000000000..aba1235bd --- /dev/null +++ b/tasks/gpsr/data/Objects.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tasks/gpsr/data/Questions.xml b/tasks/gpsr/data/Questions.xml new file 
mode 100644 index 000000000..1d5d42c83 --- /dev/null +++ b/tasks/gpsr/data/Questions.xml @@ -0,0 +1,160 @@ + + + + + + Do french like snails? + The French eat around 30,000 tons of snails a year. + + + + Would you mind me kissing you on a train? + I would. French law forbids couples from kissing on train platforms. + + + + Which French king ruled the least? + Louis XIX was the king of France for just 20 minutes, the shortest ever reign. + + + + What's the busiest train station in Europe? + Paris Gare du Nord is Europe's busiest railway station. + + + + Which is the highest mountain in Europe? + The highest mountain in Europe is Mont Blank in the French Alps. + + + + Which bread is most french, a croissant or a baguette? + The croissant was actually invented in Austria in the 13th century. + + + + Which is the most visited museum in the wrorld? + The Louvre is the most visited museum in the world. + + + + What's France's cheese production? + France produces around 1.7 million tons of cheese a year in around 1,600 varieties + + + + Which 21-stage, 23-day, 2,200-mile men's bike race is held each summer and ends at the Champs-Élysées? + That would be the Tour de France. + + + + France shares a land border with what country that also immediately follows it on an alphabetical list of the English names of E.U. nations? + I'm sure you're talking about Germany. + + + + What colour features in the national flags of all the countries that border mainland France? + Belgium, Luxemburg, Germany, Switzerland, Italy, and Spain, all have the red color in their flags. + + + + What is Vincenzo Peruggia famous for? + Vincenzo Peruggia is infamous for stealing the Mona Lisa in 1911. + + + + Which airport is the biggest and busiest in France? + The Charles de Gaulle Airport. + + + + Lyon, France is home to what border-spanning law enforcement agency? + Lyon, France is home to the Interpol. + + + + What metallic element gets its name from France's old Latin name? 
+ The gallium element got its name from France's old Latin name + + + + Which major public square is located at the eastern end of the Champs-Elysees + The Place De La Concorde + + + + Which are the five countries that are represented at every modern Olympics since its beginning. + Australia, France, Great Britain, Greece, and Switzerland. + + + + What did Napoleon said in the Waterloo battle? + I surrender. + + + + In what city is the European Disney theme park located? + The European Disney theme park is located in Paris. + + + + How big is a nanobot? + A nanobot is 50-100nm wide. + + + + Why most computerized robot voices tend to be female? + One of the reasons is that females traditionally are lovely and caretaking. + + + + Who is the world's most expensive robot? + Honda's Asimo is the most expensive robot, costing circa $2.5 million USD. + + + + What is the main source of inspiration in robotics. + Nature, contributing to the field of bio-inspired robotics. + + + + Who crafted the word Robot? + The czech writer Karel Čapek in his 1920's play Rossum's Universal Robots + + + + What does the word Robot mean? + Labor or work. That would make me a servant. + + + + Who formulated the principles of Cybernetics in 1948. + Norbert Wiener formulated the principles of Cybernetics in 1948. + + + + Do you like super-hero movies? + Yes, I do. Zack Snyder's are the best and my favorite character is Cyborg. + + + + What did Nikola Tesla demonstrate in 1898? + In 1898, Nikola Tesla demonstrated the first radio-controlled vessel. + + + + What was developed in 1978? + The first object-level robot programming language. + + + + What is the shortest path to the Dark Side? + My A-star algorithm indicates the answer is Fear. Fear leads to anger, anger leads to hate, and hate leads to suffering. 
+ + + + From 4cd6361cf681e4c7456f517466bea3a9a333e2b2 Mon Sep 17 00:00:00 2001 From: m-barker Date: Sun, 3 Mar 2024 18:42:40 +0000 Subject: [PATCH 02/23] feat: xml parser for gpsr q/a task component --- tasks/gpsr/scripts/parse_gpsr_xmls.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tasks/gpsr/scripts/parse_gpsr_xmls.py diff --git a/tasks/gpsr/scripts/parse_gpsr_xmls.py b/tasks/gpsr/scripts/parse_gpsr_xmls.py new file mode 100644 index 000000000..e85fdf706 --- /dev/null +++ b/tasks/gpsr/scripts/parse_gpsr_xmls.py @@ -0,0 +1,27 @@ +import xml.etree.ElementTree as ET + + +def parse_question_xml(xml_file_path: str) -> dict: + """Parses the GPSR Q/A xml file and returns a dictionary + consisting of two lists, one for questions and one for answers, + where the index of each question corresponds to the index of its + corresponding answer. + + Args: + xml_file_path (str): full path to xml file to parse + + Returns: + dict: dictionary with keys "questions" and "answers" + each of which is a list of strings. 
+ """ + tree = ET.parse(xml_file_path) + root = tree.getroot() + parsed_questions = [] + parsed_answers = [] + for q_a in root: + question = q_a.find("q").text + answer = q_a.find("a").text + parsed_questions.append(question) + parsed_answers.append(answer) + + return {"questions": parsed_questions, "answers": parsed_answers} From 4c049d6c3c1aceb5d9b70e745b7d4e0459fbb4bb Mon Sep 17 00:00:00 2001 From: m-barker Date: Sun, 3 Mar 2024 20:23:53 +0000 Subject: [PATCH 03/23] fix: require install of requirements.txt --- .../lasr_vector_databases_faiss/CMakeLists.txt | 9 ++++----- .../lasr_vector_databases_faiss/requirements.in | 4 +++- .../src/lasr_vector_databases_faiss/__init__.py | 6 ++++++ common/vision/lasr_vision_clip/CMakeLists.txt | 9 ++++----- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt index 1d550fbad..4f3065eb3 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt +++ b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt @@ -196,11 +196,10 @@ include_directories( # ) ## Mark other files for installation (e.g. launch and bag files, etc.) 
-# install(FILES -# # myfile1 -# # myfile2 -# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} -# ) +install(FILES + requirements.txt + DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +) ############# ## Testing ## diff --git a/common/vector_databases/lasr_vector_databases_faiss/requirements.in b/common/vector_databases/lasr_vector_databases_faiss/requirements.in index 14955d38d..3259a62c9 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/requirements.in +++ b/common/vector_databases/lasr_vector_databases_faiss/requirements.in @@ -1 +1,3 @@ -faiss-cpu \ No newline at end of file +faiss-cpu +sentence-transformers +torch \ No newline at end of file diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py index e69de29bb..20fcf0f6d 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py +++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py @@ -0,0 +1,6 @@ +from .command_similarity import ( + get_sentence_embeddings, + create_vector_database, + faiss, + SentenceTransformer, +) diff --git a/common/vision/lasr_vision_clip/CMakeLists.txt b/common/vision/lasr_vision_clip/CMakeLists.txt index a13eb6f2a..c2ce23209 100644 --- a/common/vision/lasr_vision_clip/CMakeLists.txt +++ b/common/vision/lasr_vision_clip/CMakeLists.txt @@ -196,11 +196,10 @@ include_directories( # ) ## Mark other files for installation (e.g. launch and bag files, etc.) 
-# install(FILES -# # myfile1 -# # myfile2 -# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} -# ) +install(FILES + requirements.txt + DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +) ############# ## Testing ## From 0673914e8ca1ac7939f431a16cb3c8054a758548 Mon Sep 17 00:00:00 2001 From: m-barker Date: Mon, 4 Mar 2024 18:45:37 +0000 Subject: [PATCH 04/23] feat: add backend for creating & querying vector db --- .../database_utils.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py new file mode 100644 index 000000000..68bcf7bd3 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +import os +import numpy as np +import faiss + + +def create_vector_database( + vectors: np.ndarray, + index_path: str, + overwrite: bool = False, + index_type: str = "IndexFlatIP", + normalise_vecs: bool = True, +) -> None: + """Creates a FAISS Index using the factor constructor and the given + index type, and adds the given vector to the index, and then saves + it to disk using the given path. + + Args: + vectors (np.ndarray): vector of shape (n_vectors, vector_dim) + index_path (str): path to save the index + overwrite (bool, optional): Whether to replace an existing index + at the same filepath if it exists. Defaults to False. + index_type (str, optional): FAISS Index Factory string. Defaults to "IndexFlatIP". + normalise_vecs (bool, optional): Whether to normalise the vectors before + adding them to the Index. This converts the IP metric to Cosine Similarity. + Defaults to True. 
+ """ + + if os.path.exists(index_path) and not overwrite: + raise FileExistsError( + f"Index already exists at {index_path}. Set overwrite=True to replace it." + ) + + index = faiss.index_factory(vectors.shape[1], index_type) + if normalise_vecs: + faiss.normalize_L2(vectors) + index.add(vectors) + faiss.write_index(index, index_path) + + +def load_vector_database(index_path: str, use_gpu: bool = False) -> faiss.Index: + """Loads a FAISS Index from the given filepath + + Args: + index_path (str): path to the index file + use_gpu (bool, optional): Whether to load the index onto the GPU. + Defaults to False. + + Returns: + faiss.Index: FAISS Index object + """ + index = faiss.read_index(index_path) + if use_gpu: + index = faiss.index_cpu_to_all_gpus(index) + return index + + +def query_database( + index_path: str, query_vectors: np.ndarray, normalise: bool = True +) -> tuple[np.ndarray, np.ndarray]: + """Queries the given index with the given query vectors + + Args: + index_path (str): path to the index file + query_vectors (np.ndarray): query vectors of shape (n_queries, vector_dim) + normalise (bool, optional): Whether to normalise the query vectors. + Defaults to True. 
+ + Returns: + tuple[np.ndarray, np.ndarray]: (distances, indices) of the nearest neighbours + each of shape (n_queries, n_neighbours) + """ + index = load_vector_database(index_path) + if normalise: + faiss.normalize_L2(query_vectors) + distances, indices = index.search(query_vectors, 1) + return distances, indices From e8d898a2d116e639e3846d0631e169593b9488ca Mon Sep 17 00:00:00 2001 From: m-barker Date: Mon, 4 Mar 2024 19:05:16 +0000 Subject: [PATCH 05/23] feat: add sentence embedding utils --- .../get_sentence_embeddings.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py new file mode 100644 index 000000000..3e25293f0 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +import torch +import numpy as np +from sentence_transformers import SentenceTransformer + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +def load_model(model_name: str) -> SentenceTransformer: + """Loads the sentence transformer model + Args: + model_name (str): name of the model to load + Returns: + sentence_transformers.SentenceTransformer: the loaded model + """ + return SentenceTransformer(model_name, device=DEVICE) + + +def parse_txt_file(fp: str) -> list[str]: + """Parses a txt file into a list of strings, + where each element is a line in the txt file with the + newline char stripped. 
+ Args: + fp (str): path to the txt file to load + Returns: + list[str]: list of strings where each element is a line in the txt file + """ + sentences = [] + with open(fp, "r", encoding="utf8") as src: + for line in src: + # Strip newline char. + sentences.append(line[:-1]) + return sentences + + +def get_sentence_embeddings( + sentence_list: list[str], model: SentenceTransformer +) -> np.ndarray: + """Converts the list of string sentences into an array of sentence + embeddings + Args: + sentece_list (list[str]): list of string sentences, where each + entry in the list is assumed to be a separate sentence + model (SentenceTransformer): model used to perform the embedding. + Assumes a method called encode that takes a list of strings + as input. + Returns: + np.ndarray: array of shape (n_commands, embedding_dim) + """ + + return model.encode( + sentence_list, + convert_to_numpy=True, + show_progress_bar=True, + batch_size=256, + device=DEVICE, + ) From b4a4ce8a6747d374c97ce24ccfce002d658186d9 Mon Sep 17 00:00:00 2001 From: m-barker Date: Tue, 5 Mar 2024 10:34:06 +0000 Subject: [PATCH 06/23] feat: service to create text FAISS index --- .../nodes/txt_index_service | 31 ++++ .../nodes/txt_query_service | 1 + .../lasr_vector_databases_faiss/__init__.py | 8 +- .../command_similarity.py | 157 ------------------ .../get_sentence_embeddings.py | 2 +- .../srv/TxtIndex.srv | 7 + 6 files changed, 42 insertions(+), 164 deletions(-) create mode 100644 common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service create mode 100644 common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service delete mode 100755 common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py create mode 100644 common/vector_databases/lasr_vector_databases_faiss/srv/TxtIndex.srv diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service 
b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service new file mode 100644 index 000000000..19686176c --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service @@ -0,0 +1,31 @@ +#!/usr/bin/env python +import rospy +import numpy as np +from lasr_vector_databases_faiss.srv import TxtIndexRequest, TxtIndexResponse +from lasr_vector_databases_faiss import ( + load_model, + parse_txt_file, + get_sentence_embeddings, + create_vector_database, +) + + +class TxtIndexService: + def __init__(self): + rospy.Service("lasr_faiss/txt_index", TxtIndexResponse, self.execute_cb) + self._sentence_embedding_model = load_model() + rospy.loginfo("Text index service started") + rospy.spin() + + def execute_cb(self, req: TxtIndexRequest): + txt_fp: str = req.txt_path + sentences_to_embed: list[str] = parse_txt_file(txt_fp) + sentence_embeddings: np.ndarray = get_sentence_embeddings( + sentences_to_embed, self._sentence_embedding_model + ) + index_path: str = req.index_path + create_vector_database(sentence_embeddings, index_path) + + +if __name__ == "__main__": + TxtIndexService() diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service new file mode 100644 index 000000000..5f7ce86af --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service @@ -0,0 +1 @@ +#!/usr/bin/env python3 \ No newline at end of file diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py index 20fcf0f6d..698927d9c 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py +++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/__init__.py @@ -1,6 +1,2 @@ -from 
.command_similarity import ( - get_sentence_embeddings, - create_vector_database, - faiss, - SentenceTransformer, -) +from .database_utils import create_vector_database, load_vector_database, query_database +from .get_sentence_embeddings import get_sentence_embeddings, load_model, parse_txt_file diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py deleted file mode 100755 index f7dbcbe24..000000000 --- a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/command_similarity.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 -import os -import torch -import numpy as np - -# import rospy -import faiss # type: ignore -from sentence_transformers import SentenceTransformer # type: ignore -from typing import Optional - -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" - - -def load_commands(command_path: str) -> list[str]: - """Loads the commands stored in the given txt file - into a list of string commands - Args: - command_path (str): path to the txt file containing - the commands -- assumes one command per line. - Returns: - list[str]: list of string commands where each entry in the - list is a command - """ - command_list = [] - with open(command_path, "r", encoding="utf8") as src: - for command in src: - # Strip newline char. - command_list.append(command[:-1]) - return command_list - - -def get_sentence_embeddings( - command_list: list[str], model: SentenceTransformer -) -> np.ndarray: - """Converts the list of command strings into an array of sentence - embeddings (where each command is a sentence and each sentence - is converted to a vector) - Args: - command_list (list[str]): list of string commands, where each - entry in the list is assumed to be a separate command - model (SentenceTransformer): model used to perform the embedding. 
- Assumes a method called encode that takes a list of strings - as input. - Returns: - np.ndarray: array of shape (n_commands, embedding_dim) - """ - - return model.encode( - command_list, - convert_to_numpy=True, - show_progress_bar=True, - batch_size=256, - device=DEVICE, - ) - - -def create_vector_database(vectors: np.ndarray) -> faiss.IndexFlatIP: - """Creates a vector database from an array of vectors of the same dimensionality - Args: - vectors (np.ndarray): shape (n_vectors, vector_dim) - - Returns: - faiss.IndexFlatIP: Flat index containing the vectors - """ - print("Creating vector database") - index_flat = faiss.IndexFlatIP(vectors.shape[1]) - faiss.normalize_L2(vectors) - index_flat.add(vectors) - print("Finished creating vector database") - return index_flat - - -def get_command_database( - index_path: str, command_path: Optional[str] = None -) -> faiss.IndexFlatL2: - """Gets a vector database containing a list of embedded commands. Creates the database - if the path does not exist, else, loads it into memory. - - Args: - index_path (str): Path to an existing faiss Index, or where to save a new one. - command_path (str, optional): Path of text file containing commands. - Only required if creating a new database. Defaults to None. - - Returns: - faiss.IndexFlatL2: faiss Index object containing the embedded commands. 
- """ - - if not os.path.exists(f"{index_path}.index"): - # rospy.loginfo("Creating new command vector database") - assert command_path is not None - command_list = load_commands(command_path) - model = SentenceTransformer("all-MiniLM-L6-v2") - command_embeddings = get_sentence_embeddings(command_list, model) - print(command_embeddings.shape) - command_database = create_vector_database(command_embeddings) - faiss.write_index(command_database, f"{index_path}.index") - # rospy.loginfo("Finished creating vector database") - else: - command_database = faiss.read_index(f"{index_path}.index") - - return command_database - - -def get_similar_commands( - command: str, - index_path: str, - command_path: str, - n_similar_commands: int = 100, - return_embeddings: bool = False, -) -> tuple[list[str], list[float]]: - """Gets the most similar commands to the given command string - Args: - command (str): command to compare against the database - index_path (str): path to the location to create or retrieve - the faiss index containing the embedded commands. - command_path (str): path to the txt file containing the commands - n_similar_commands (int, optional): number of similar commands to - return. Defaults to 100. 
- Returns: - list[str]: list of string commands, where each entry in the - list is a similar command - """ - command_database = get_command_database(index_path, command_path) - command_list = load_commands(command_path) - model = SentenceTransformer("all-MiniLM-L6-v2") - command_embedding = get_sentence_embeddings([command], model) - faiss.normalize_L2(command_embedding) - command_distances, command_indices = command_database.search( - command_embedding, n_similar_commands - ) - nearest_commands = [command_list[i] for i in command_indices[0]] - - if return_embeddings: - all_command_embeddings = get_sentence_embeddings(command_list, model) - # filter for only nererst commands - all_command_embeddings = all_command_embeddings[command_indices[0]] - return ( - nearest_commands, - list(command_distances[0]), - all_command_embeddings, - command_embedding, - ) - - return nearest_commands, list(command_distances[0]) - - -if __name__ == "__main__": - """Example usage of using this to find similar commands""" - command = "find Jared and asks if he needs help" - result, distances, command_embeddings, query_embedding = get_similar_commands( - command, - "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/qualification/data/command_index", - "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/qualification/data/command_list.txt", - n_similar_commands=1000, - return_embeddings=True, - ) - print(result) diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py index 3e25293f0..e28189e5c 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py +++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py @@ -6,7 +6,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -def 
load_model(model_name: str) -> SentenceTransformer: +def load_model(model_name: str = "all-MiniLM-L6-v2") -> SentenceTransformer: """Loads the sentence transformer model Args: model_name (str): name of the model to load diff --git a/common/vector_databases/lasr_vector_databases_faiss/srv/TxtIndex.srv b/common/vector_databases/lasr_vector_databases_faiss/srv/TxtIndex.srv new file mode 100644 index 000000000..79ac01654 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/srv/TxtIndex.srv @@ -0,0 +1,7 @@ +# Path to input text file +string txt_path + +# Output path to save index +string index_path + +--- From 3d3007c906ee27767de1a3bea8382f3aef8a0b3e Mon Sep 17 00:00:00 2001 From: m-barker Date: Tue, 5 Mar 2024 11:02:07 +0000 Subject: [PATCH 07/23] feat: text query service for FAISS --- .../nodes/txt_index_service | 1 + .../nodes/txt_query_service | 41 ++++++++++++++++++- .../database_utils.py | 8 +++- .../srv/TxtQuery.srv | 19 +++++++++ 4 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 common/vector_databases/lasr_vector_databases_faiss/srv/TxtQuery.srv diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service index 19686176c..56fc511ad 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service +++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service @@ -25,6 +25,7 @@ class TxtIndexService: ) index_path: str = req.index_path create_vector_database(sentence_embeddings, index_path) + return TextIndexResponse() if __name__ == "__main__": diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service index 5f7ce86af..7ffe384b7 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service +++ 
b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service @@ -1 +1,40 @@ -#!/usr/bin/env python3 \ No newline at end of file +#!/usr/bin/env python +import rospy +import numpy as np +from lasr_vector_databases_faiss.srv import TxtQueryRequest, TxtQueryResponse +from lasr_vector_databases_faiss import ( + load_model, + parse_txt_file, + get_sentence_embeddings, + load_vector_database, + query_database, +) + + +class TxtQueryService: + def __init__(self): + rospy.Service("lasr_faiss/txt_query", TxtQueryResponse, self.execute_cb) + self._sentence_embedding_model = load_model() + rospy.loginfo("Text Query service started") + rospy.spin() + + def execute_cb(self, req: TxtQueryRequest) -> TxtQueryResponse: + txt_fp: str = req.txt_path + index_path: str = req.index_path + index = load_vector_database(index_path) + query_sentence: str = req.query_sentence + possible_matches: list[str] = parse_txt_file(txt_fp) + query_embedding: np.ndarray = get_sentence_embeddings( + [query_sentence], self._sentence_embedding_model # requires list of strings + ) + distances, indices = query_database(index, query_embedding, req.k) + nearest_matches = [possible_matches[i] for i in indices[0]] + + return TxtQueryResponse( + closest_sentences=nearest_matches, + distances=distances[0].tolist(), + ) + + +if __name__ == "__main__": + TxtQueryService() diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py index 68bcf7bd3..acb826434 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py +++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py @@ -56,7 +56,10 @@ def load_vector_database(index_path: str, use_gpu: bool = False) -> faiss.Index: def query_database( - index_path: str, query_vectors: np.ndarray, 
normalise: bool = True + index_path: str, + query_vectors: np.ndarray, + normalise: bool = True, + k: int = 1, ) -> tuple[np.ndarray, np.ndarray]: """Queries the given index with the given query vectors @@ -65,6 +68,7 @@ def query_database( query_vectors (np.ndarray): query vectors of shape (n_queries, vector_dim) normalise (bool, optional): Whether to normalise the query vectors. Defaults to True. + k (int, optional): Number of nearest neighbours to return. Defaults to 1. Returns: tuple[np.ndarray, np.ndarray]: (distances, indices) of the nearest neighbours @@ -73,5 +77,5 @@ def query_database( index = load_vector_database(index_path) if normalise: faiss.normalize_L2(query_vectors) - distances, indices = index.search(query_vectors, 1) + distances, indices = index.search(query_vectors, k) return distances, indices diff --git a/common/vector_databases/lasr_vector_databases_faiss/srv/TxtQuery.srv b/common/vector_databases/lasr_vector_databases_faiss/srv/TxtQuery.srv new file mode 100644 index 000000000..bb61ab204 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/srv/TxtQuery.srv @@ -0,0 +1,19 @@ +# Path to input text file +string txt_path + +# Path to index file to load +string index_path + +# Sentence to query index with +string query_sentence + +# Number of nearest sentences to return +int k + +--- +# Nearest sentence +string [] closest_sentences + +# Cosine similarity of distances +float32 [] cosine_similarities + From 898ed1dc6c98ec7c1ed03ca98a5cde9cf22710b4 Mon Sep 17 00:00:00 2001 From: m-barker Date: Tue, 5 Mar 2024 11:44:29 +0000 Subject: [PATCH 08/23] feat: working FAISS text services --- .../CMakeLists.txt | 42 +++++++------------ .../nodes/txt_index_service | 9 ++-- .../nodes/txt_query_service | 7 ++-- .../srv/TxtQuery.srv | 6 +-- 4 files changed, 27 insertions(+), 37 deletions(-) diff --git a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt index 
4f3065eb3..fc06e72e6 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt +++ b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt @@ -7,7 +7,7 @@ project(lasr_vector_databases_faiss) ## Find catkin macros and libraries ## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) ## is used, also find other catkin packages -find_package(catkin REQUIRED catkin_virtualenv) +find_package(catkin REQUIRED catkin_virtualenv genmsg std_msgs) ## System dependencies are found with CMake's conventions # find_package(Boost REQUIRED COMPONENTS system) @@ -53,11 +53,11 @@ catkin_generate_virtualenv( # ) ## Generate services in the 'srv' folder -# add_service_files( -# FILES -# Service1.srv -# Service2.srv -# ) +add_service_files( + FILES + TxtIndex.srv + TxtQuery.srv +) # Generate actions in the 'action' folder # add_action_files( @@ -66,11 +66,10 @@ catkin_generate_virtualenv( # ) # Generate added messages and services with any dependencies listed here -# generate_messages( -# DEPENDENCIES -# actionlib_msgs -# geometry_msgs -# ) +generate_messages( + DEPENDENCIES + std_msgs +) ################################################ ## Declare ROS dynamic reconfigure parameters ## @@ -157,22 +156,11 @@ include_directories( ## Mark executable scripts (Python etc.) 
for installation ## in contrast to setup.py, you can choose the destination -# catkin_install_python(PROGRAMS -# nodes/qualification -# nodes/actions/wait_greet -# nodes/actions/identify -# nodes/actions/greet -# nodes/actions/get_name -# nodes/actions/learn_face -# nodes/actions/get_command -# nodes/actions/guide -# nodes/actions/find_person -# nodes/actions/detect_people -# nodes/actions/receive_object -# nodes/actions/handover_object -# nodes/better_qualification -# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} -# ) +catkin_install_python(PROGRAMS + nodes/txt_index_service + nodes/txt_query_service + DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) ## Mark executables for installation ## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service index 56fc511ad..ed658b9be 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service +++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service @@ -1,7 +1,7 @@ #!/usr/bin/env python import rospy import numpy as np -from lasr_vector_databases_faiss.srv import TxtIndexRequest, TxtIndexResponse +from lasr_vector_databases_faiss.srv import TxtIndexRequest, TxtIndexResponse, TxtIndex from lasr_vector_databases_faiss import ( load_model, parse_txt_file, @@ -12,10 +12,10 @@ from lasr_vector_databases_faiss import ( class TxtIndexService: def __init__(self): - rospy.Service("lasr_faiss/txt_index", TxtIndexResponse, self.execute_cb) + rospy.init_node("txt_index_service") + rospy.Service("lasr_faiss/txt_index", TxtIndex, self.execute_cb) self._sentence_embedding_model = load_model() rospy.loginfo("Text index service started") - rospy.spin() def execute_cb(self, req: TxtIndexRequest): txt_fp: str = req.txt_path @@ -25,8 +25,9 @@ class TxtIndexService: ) index_path: str = req.index_path 
create_vector_database(sentence_embeddings, index_path) - return TextIndexResponse() + return TxtIndexResponse() if __name__ == "__main__": TxtIndexService() + rospy.spin() diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service index 7ffe384b7..25e3b63cf 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service +++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service @@ -1,7 +1,7 @@ #!/usr/bin/env python import rospy import numpy as np -from lasr_vector_databases_faiss.srv import TxtQueryRequest, TxtQueryResponse +from lasr_vector_databases_faiss.srv import TxtQueryRequest, TxtQueryResponse, TxtQuery from lasr_vector_databases_faiss import ( load_model, parse_txt_file, @@ -13,10 +13,10 @@ from lasr_vector_databases_faiss import ( class TxtQueryService: def __init__(self): - rospy.Service("lasr_faiss/txt_query", TxtQueryResponse, self.execute_cb) + rospy.init_node("txt_query_service") + rospy.Service("lasr_faiss/txt_query", TxtQuery, self.execute_cb) self._sentence_embedding_model = load_model() rospy.loginfo("Text Query service started") - rospy.spin() def execute_cb(self, req: TxtQueryRequest) -> TxtQueryResponse: txt_fp: str = req.txt_path @@ -38,3 +38,4 @@ class TxtQueryService: if __name__ == "__main__": TxtQueryService() + rospy.spin() diff --git a/common/vector_databases/lasr_vector_databases_faiss/srv/TxtQuery.srv b/common/vector_databases/lasr_vector_databases_faiss/srv/TxtQuery.srv index bb61ab204..bbcb04613 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/srv/TxtQuery.srv +++ b/common/vector_databases/lasr_vector_databases_faiss/srv/TxtQuery.srv @@ -8,12 +8,12 @@ string index_path string query_sentence # Number of nearest sentences to return -int k +uint8 k --- # Nearest sentence -string [] closest_sentences +string[] closest_sentences # Cosine similarity of distances 
-float32 [] cosine_similarities +float32[] cosine_similarities From 04ddae3c62c222cd70892324b55b010a8a0c89e1 Mon Sep 17 00:00:00 2001 From: m-barker Date: Tue, 5 Mar 2024 14:52:26 +0000 Subject: [PATCH 09/23] feat: test scripts for FAISS services --- .../CMakeLists.txt | 8 ++++++- .../nodes/txt_index_service | 2 +- .../nodes/txt_query_service | 7 +++--- .../lasr_vector_databases_faiss/package.xml | 2 ++ .../scripts/test_index_service.py | 14 ++++++++++++ .../scripts/test_query_service.py | 22 +++++++++++++++++++ .../database_utils.py | 8 +++++-- 7 files changed, 55 insertions(+), 8 deletions(-) create mode 100644 common/vector_databases/lasr_vector_databases_faiss/scripts/test_index_service.py create mode 100644 common/vector_databases/lasr_vector_databases_faiss/scripts/test_query_service.py diff --git a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt index fc06e72e6..12fe169e8 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt +++ b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt @@ -7,7 +7,11 @@ project(lasr_vector_databases_faiss) ## Find catkin macros and libraries ## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) ## is used, also find other catkin packages -find_package(catkin REQUIRED catkin_virtualenv genmsg std_msgs) +find_package(catkin REQUIRED catkin_virtualenv COMPONENTS +rospy +std_msgs +message_generation +) ## System dependencies are found with CMake's conventions # find_package(Boost REQUIRED COMPONENTS system) @@ -159,6 +163,8 @@ include_directories( catkin_install_python(PROGRAMS nodes/txt_index_service nodes/txt_query_service + scripts/test_index_service.py + scripts/test_query_service.py DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} ) diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service 
b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service index ed658b9be..2d7ce7949 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service +++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import rospy import numpy as np from lasr_vector_databases_faiss.srv import TxtIndexRequest, TxtIndexResponse, TxtIndex diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service index 25e3b63cf..dae0970a2 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service +++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import rospy import numpy as np from lasr_vector_databases_faiss.srv import TxtQueryRequest, TxtQueryResponse, TxtQuery @@ -21,18 +21,17 @@ class TxtQueryService: def execute_cb(self, req: TxtQueryRequest) -> TxtQueryResponse: txt_fp: str = req.txt_path index_path: str = req.index_path - index = load_vector_database(index_path) query_sentence: str = req.query_sentence possible_matches: list[str] = parse_txt_file(txt_fp) query_embedding: np.ndarray = get_sentence_embeddings( [query_sentence], self._sentence_embedding_model # requires list of strings ) - distances, indices = query_database(index, query_embedding, req.k) + distances, indices = query_database(index_path, query_embedding, k=req.k) nearest_matches = [possible_matches[i] for i in indices[0]] return TxtQueryResponse( closest_sentences=nearest_matches, - distances=distances[0].tolist(), + cosine_similarities=distances[0].tolist(), ) diff --git a/common/vector_databases/lasr_vector_databases_faiss/package.xml b/common/vector_databases/lasr_vector_databases_faiss/package.xml index f8128ea56..55594d8d3 100644 --- 
a/common/vector_databases/lasr_vector_databases_faiss/package.xml +++ b/common/vector_databases/lasr_vector_databases_faiss/package.xml @@ -50,6 +50,8 @@ catkin catkin_virtualenv + message_generation + message_runtime diff --git a/common/vector_databases/lasr_vector_databases_faiss/scripts/test_index_service.py b/common/vector_databases/lasr_vector_databases_faiss/scripts/test_index_service.py new file mode 100644 index 000000000..cc7f12f3c --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/scripts/test_index_service.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +import rospy +from lasr_vector_databases_faiss.srv import TxtIndex, TxtIndexRequest + +request = TxtIndexRequest() + +request.txt_path = ( + "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.txt" +) + +request.index_path = ( + "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.index" +) +rospy.ServiceProxy("lasr_faiss/txt_index", TxtIndex)(request) diff --git a/common/vector_databases/lasr_vector_databases_faiss/scripts/test_query_service.py b/common/vector_databases/lasr_vector_databases_faiss/scripts/test_query_service.py new file mode 100644 index 000000000..4ae89e530 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/scripts/test_query_service.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +import rospy +from lasr_vector_databases_faiss.srv import TxtQuery, TxtQueryRequest + +request = TxtQueryRequest() + +request.txt_path = ( + "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.txt" +) + +request.index_path = ( + "/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.index" +) + +request.query_sentence = "Do French like snails?" 
+ +request.k = 3 + +response = rospy.ServiceProxy("lasr_faiss/txt_query", TxtQuery)(request) + +print(response.closest_sentences) +print(response.cosine_similarities) diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py index acb826434..7a901ff3f 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py +++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py @@ -8,7 +8,7 @@ def create_vector_database( vectors: np.ndarray, index_path: str, overwrite: bool = False, - index_type: str = "IndexFlatIP", + index_type: str = "Flat", normalise_vecs: bool = True, ) -> None: """Creates a FAISS Index using the factor constructor and the given @@ -31,7 +31,9 @@ def create_vector_database( f"Index already exists at {index_path}. Set overwrite=True to replace it." 
) - index = faiss.index_factory(vectors.shape[1], index_type) + index = faiss.index_factory( + vectors.shape[1], index_type, faiss.METRIC_INNER_PRODUCT + ) if normalise_vecs: faiss.normalize_L2(vectors) index.add(vectors) @@ -49,7 +51,9 @@ def load_vector_database(index_path: str, use_gpu: bool = False) -> faiss.Index: Returns: faiss.Index: FAISS Index object """ + print("Loading index from", index_path) index = faiss.read_index(index_path) + print("Loaded index with ntotal:", index.ntotal) if use_gpu: index = faiss.index_cpu_to_all_gpus(index) return index From 73216966d83f70553589ee05bc4286ff8ff21670 Mon Sep 17 00:00:00 2001 From: m-barker Date: Tue, 5 Mar 2024 16:33:45 +0000 Subject: [PATCH 10/23] docs: add documenation for FAISS vector service --- .../doc/TECHNICAL.md | 0 .../lasr_vector_databases_faiss/doc/USAGE.md | 33 +++++++++++++++++++ 2 files changed, 33 insertions(+) create mode 100644 common/vector_databases/lasr_vector_databases_faiss/doc/TECHNICAL.md create mode 100644 common/vector_databases/lasr_vector_databases_faiss/doc/USAGE.md diff --git a/common/vector_databases/lasr_vector_databases_faiss/doc/TECHNICAL.md b/common/vector_databases/lasr_vector_databases_faiss/doc/TECHNICAL.md new file mode 100644 index 000000000..e69de29bb diff --git a/common/vector_databases/lasr_vector_databases_faiss/doc/USAGE.md b/common/vector_databases/lasr_vector_databases_faiss/doc/USAGE.md new file mode 100644 index 000000000..d17914026 --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_faiss/doc/USAGE.md @@ -0,0 +1,33 @@ +This package currently contains two services `txt_index_service` and `txt_query_service`. These services are used to create and search (respectively) a vector database of natural language sentence embeddings. + +# Index Service +The Index service is used to create a [FAISS](https://github.com/facebookresearch/faiss) index object containing a set of sentence embeddings, where each sentence is assumed to be a line in a given `.txt` file. 
This Index object is saved to disk at a specified location, and can be thought of as a Vector Database. + +## Request +The request takes two string parameters: `txt_path` which is the path to the `.txt` file we wish to create sentence embeddings for, where each line in this file is treated as a sentence; and `index_path` which is the path to a `.index` file that will be created by the Service. + +## Response +No response is given from this service. + +## Example Usage +Please see the `scripts/test_index_service.py` script for a simple example of sending a request to the service. + +# Query Service +The query service is used to search the `.index` file created by the Index Service to find the most similar sentences given an input query sentence. + +## Request +The request requires four fields: + +1. `txt_path` -- this is a `string` that is the path to the txt file that contains the original sentences that the `.index` file was populated with. +2. `index_path` -- this is a `string` that is the path to the `.index` file that was created with the Index Service, on the same txt file as the `txt_path`. +3. `query_sentence` -- this is a `string` that is the sentence that you wish to query the index with and find the most similar sentence. +4. `k` -- this is a `uint8` that is the number of closest sentences you wish to return. + +## Response +The response contains two fields: + +1. `closest_sentences` -- this is an ordered list of `string`s that contain the closest sentences to the given query sentence. +2. `cosine_similarities` -- this is an ordered list of `float32`s that contain the cosine similarity scores of the closest sentences. + +## Example Usage +Please see the `scripts/test_query_service.py` script for a simple example of sending a request to the service. 
\ No newline at end of file From 86b964b88210d3c59453322130d5e4a2597b7274 Mon Sep 17 00:00:00 2001 From: m-barker Date: Tue, 5 Mar 2024 16:37:55 +0000 Subject: [PATCH 11/23] chore: remove xml files --- tasks/gpsr/data/Gestures.xml | 11 --- tasks/gpsr/data/Locations.xml | 44 ---------- tasks/gpsr/data/Names.xml | 58 ------------ tasks/gpsr/data/Objects.xml | 71 --------------- tasks/gpsr/data/Questions.xml | 160 ---------------------------------- 5 files changed, 344 deletions(-) delete mode 100644 tasks/gpsr/data/Gestures.xml delete mode 100644 tasks/gpsr/data/Locations.xml delete mode 100644 tasks/gpsr/data/Names.xml delete mode 100644 tasks/gpsr/data/Objects.xml delete mode 100644 tasks/gpsr/data/Questions.xml diff --git a/tasks/gpsr/data/Gestures.xml b/tasks/gpsr/data/Gestures.xml deleted file mode 100644 index 59617e994..000000000 --- a/tasks/gpsr/data/Gestures.xml +++ /dev/null @@ -1,11 +0,0 @@ - - - - - - - - - diff --git a/tasks/gpsr/data/Locations.xml b/tasks/gpsr/data/Locations.xml deleted file mode 100644 index aafef1b99..000000000 --- a/tasks/gpsr/data/Locations.xml +++ /dev/null @@ -1,44 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tasks/gpsr/data/Names.xml b/tasks/gpsr/data/Names.xml deleted file mode 100644 index cc53d5ea4..000000000 --- a/tasks/gpsr/data/Names.xml +++ /dev/null @@ -1,58 +0,0 @@ - - - - Adel - Angel - Axel - Charlie - Jane - Jules - Morgan - Paris - Robin - Simone - Adel - Angel - Axel - Charlie - James - Jules - Morgan - Paris - Robin - Simone - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tasks/gpsr/data/Objects.xml b/tasks/gpsr/data/Objects.xml deleted file mode 100644 index aba1235bd..000000000 --- a/tasks/gpsr/data/Objects.xml +++ /dev/null @@ -1,71 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tasks/gpsr/data/Questions.xml 
b/tasks/gpsr/data/Questions.xml deleted file mode 100644 index 1d5d42c83..000000000 --- a/tasks/gpsr/data/Questions.xml +++ /dev/null @@ -1,160 +0,0 @@ - - - - - - Do french like snails? - The French eat around 30,000 tons of snails a year. - - - - Would you mind me kissing you on a train? - I would. French law forbids couples from kissing on train platforms. - - - - Which French king ruled the least? - Louis XIX was the king of France for just 20 minutes, the shortest ever reign. - - - - What's the busiest train station in Europe? - Paris Gare du Nord is Europe's busiest railway station. - - - - Which is the highest mountain in Europe? - The highest mountain in Europe is Mont Blank in the French Alps. - - - - Which bread is most french, a croissant or a baguette? - The croissant was actually invented in Austria in the 13th century. - - - - Which is the most visited museum in the wrorld? - The Louvre is the most visited museum in the world. - - - - What's France's cheese production? - France produces around 1.7 million tons of cheese a year in around 1,600 varieties - - - - Which 21-stage, 23-day, 2,200-mile men's bike race is held each summer and ends at the Champs-Élysées? - That would be the Tour de France. - - - - France shares a land border with what country that also immediately follows it on an alphabetical list of the English names of E.U. nations? - I'm sure you're talking about Germany. - - - - What colour features in the national flags of all the countries that border mainland France? - Belgium, Luxemburg, Germany, Switzerland, Italy, and Spain, all have the red color in their flags. - - - - What is Vincenzo Peruggia famous for? - Vincenzo Peruggia is infamous for stealing the Mona Lisa in 1911. - - - - Which airport is the biggest and busiest in France? - The Charles de Gaulle Airport. - - - - Lyon, France is home to what border-spanning law enforcement agency? - Lyon, France is home to the Interpol. 
- - - - What metallic element gets its name from France's old Latin name? - The gallium element got its name from France's old Latin name - - - - Which major public square is located at the eastern end of the Champs-Elysees - The Place De La Concorde - - - - Which are the five countries that are represented at every modern Olympics since its beginning. - Australia, France, Great Britain, Greece, and Switzerland. - - - - What did Napoleon said in the Waterloo battle? - I surrender. - - - - In what city is the European Disney theme park located? - The European Disney theme park is located in Paris. - - - - How big is a nanobot? - A nanobot is 50-100nm wide. - - - - Why most computerized robot voices tend to be female? - One of the reasons is that females traditionally are lovely and caretaking. - - - - Who is the world's most expensive robot? - Honda's Asimo is the most expensive robot, costing circa $2.5 million USD. - - - - What is the main source of inspiration in robotics. - Nature, contributing to the field of bio-inspired robotics. - - - - Who crafted the word Robot? - The czech writer Karel Čapek in his 1920's play Rossum's Universal Robots - - - - What does the word Robot mean? - Labor or work. That would make me a servant. - - - - Who formulated the principles of Cybernetics in 1948. - Norbert Wiener formulated the principles of Cybernetics in 1948. - - - - Do you like super-hero movies? - Yes, I do. Zack Snyder's are the best and my favorite character is Cyborg. - - - - What did Nikola Tesla demonstrate in 1898? - In 1898, Nikola Tesla demonstrated the first radio-controlled vessel. - - - - What was developed in 1978? - The first object-level robot programming language. - - - - What is the shortest path to the Dark Side? - My A-star algorithm indicates the answer is Fear. Fear leads to anger, anger leads to hate, and hate leads to suffering. 
- - - - From db0750f767a34d9717012ebe9792b584ac9234f7 Mon Sep 17 00:00:00 2001 From: m-barker Date: Wed, 6 Mar 2024 09:52:50 +0000 Subject: [PATCH 12/23] feat: xml question answer state --- skills/src/lasr_skills/xml_question_answer.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 skills/src/lasr_skills/xml_question_answer.py diff --git a/skills/src/lasr_skills/xml_question_answer.py b/skills/src/lasr_skills/xml_question_answer.py new file mode 100644 index 000000000..338abf98e --- /dev/null +++ b/skills/src/lasr_skills/xml_question_answer.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 + +import rospy +import smach +import xml.etree.ElementTree as ET + +from lasr_vector_databases_faiss.srv import TxtQuery, TxtQueryRequest + + +def parse_question_xml(xml_file_path: str) -> dict: + """Parses the GPSR Q/A xml file and returns a dictionary + consisting of two lists, one for questions and one for answers, + where the index of each question corresponds to the index of its + corresponding answer. + + Args: + xml_file_path (str): full path to xml file to parse + + Returns: + dict: dictionary with keys "questions" and "answers" + each of which is a list of strings. 
+ """ + tree = ET.parse(xml_file_path) + root = tree.getroot() + parsed_questions = [] + parsed_answers = [] + for q_a in root: + question = q_a.find("q").text + answer = q_a.find("a").text + parsed_questions.append(question) + parsed_answers.append(answer) + + return {"questions": parsed_questions, "answers": parsed_answers} + + +class XmlQuestionAnswer(smach.State): + + def __init__(self): + smach.State.__init__( + self, + outcomes=["succeeded", "failed"], + input_keys=["query_sentence", "k", "index_path", "txt_path", "xml_path"], + output_keys=["closest_answers"], + ) + self.txt_query = rospy.ServiceProxy("/lasr_faiss/txt_query", TxtQuery) + + def execute(self, userdata): + q_a_dict: dict = parse_question_xml(userdata.xml_path) + try: + request = TxtQueryRequest( + userdata.txt_path, + userdata.index_path, + userdata.query_sentence, + userdata.k, + ) + result = self.txt_query(request) + answers = [ + q_a_dict["answers"][q_a_dict["questions"].index(q)] + for q in result.closest_sentences + ] + userdata.closest_answers = answers + return "succeeded" + except rospy.ServiceException as e: + rospy.logwarn(f"Unable to perform Index Query. 
({str(e)})") + return "failed" From 9d8a9f1c3b4141299a64c7964fdc768b19b9d095 Mon Sep 17 00:00:00 2001 From: m-barker Date: Wed, 6 Mar 2024 13:58:04 +0000 Subject: [PATCH 13/23] feat: separate package for vector db messages --- .../CMakeLists.txt | 19 +- .../nodes/txt_index_service | 2 +- .../nodes/txt_query_service | 10 +- .../lasr_vector_databases_faiss/package.xml | 3 +- .../lasr_vector_databases_msgs/CMakeLists.txt | 194 ++++++++++++++++++ .../lasr_vector_databases_msgs/package.xml | 61 ++++++ .../srv/TxtIndex.srv | 0 .../srv/TxtQuery.srv | 0 8 files changed, 273 insertions(+), 16 deletions(-) create mode 100644 common/vector_databases/lasr_vector_databases_msgs/CMakeLists.txt create mode 100644 common/vector_databases/lasr_vector_databases_msgs/package.xml rename common/vector_databases/{lasr_vector_databases_faiss => lasr_vector_databases_msgs}/srv/TxtIndex.srv (100%) rename common/vector_databases/{lasr_vector_databases_faiss => lasr_vector_databases_msgs}/srv/TxtQuery.srv (100%) diff --git a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt index 12fe169e8..48659fa8c 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt +++ b/common/vector_databases/lasr_vector_databases_faiss/CMakeLists.txt @@ -9,8 +9,7 @@ project(lasr_vector_databases_faiss) ## is used, also find other catkin packages find_package(catkin REQUIRED catkin_virtualenv COMPONENTS rospy -std_msgs -message_generation +lasr_vector_databases_msgs ) ## System dependencies are found with CMake's conventions @@ -57,11 +56,9 @@ catkin_generate_virtualenv( # ) ## Generate services in the 'srv' folder -add_service_files( - FILES - TxtIndex.srv - TxtQuery.srv -) +# add_service_files( +# FILES +# ) # Generate actions in the 'action' folder # add_action_files( @@ -70,10 +67,10 @@ add_service_files( # ) # Generate added messages and services with any dependencies listed here -generate_messages( - 
DEPENDENCIES - std_msgs -) +# generate_messages( +# DEPENDENCIES +# std_msgs +# ) ################################################ ## Declare ROS dynamic reconfigure parameters ## diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service index 2d7ce7949..a9c5c85b6 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service +++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_index_service @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import rospy import numpy as np -from lasr_vector_databases_faiss.srv import TxtIndexRequest, TxtIndexResponse, TxtIndex +from lasr_vector_databases_msgs.srv import TxtIndexRequest, TxtIndexResponse, TxtIndex from lasr_vector_databases_faiss import ( load_model, parse_txt_file, diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service index dae0970a2..8dec15270 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service +++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service @@ -1,7 +1,12 @@ #!/usr/bin/env python3 import rospy import numpy as np -from lasr_vector_databases_faiss.srv import TxtQueryRequest, TxtQueryResponse, TxtQuery + +from lasr_vector_databases_msgs.srv import ( + TxtQueryRequest, + TxtQueryResponse, + TxtQuery, +) from lasr_vector_databases_faiss import ( load_model, parse_txt_file, @@ -14,8 +19,9 @@ from lasr_vector_databases_faiss import ( class TxtQueryService: def __init__(self): rospy.init_node("txt_query_service") - rospy.Service("lasr_faiss/txt_query", TxtQuery, self.execute_cb) self._sentence_embedding_model = load_model() + print(self._sentence_embedding_model) + rospy.Service("lasr_faiss/txt_query", TxtQuery, self.execute_cb) rospy.loginfo("Text Query service started") def execute_cb(self, req: 
TxtQueryRequest) -> TxtQueryResponse: diff --git a/common/vector_databases/lasr_vector_databases_faiss/package.xml b/common/vector_databases/lasr_vector_databases_faiss/package.xml index 55594d8d3..1546f0a38 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/package.xml +++ b/common/vector_databases/lasr_vector_databases_faiss/package.xml @@ -50,8 +50,7 @@ catkin catkin_virtualenv - message_generation - message_runtime + lasr_vector_databases_msgs diff --git a/common/vector_databases/lasr_vector_databases_msgs/CMakeLists.txt b/common/vector_databases/lasr_vector_databases_msgs/CMakeLists.txt new file mode 100644 index 000000000..63fd162dc --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_msgs/CMakeLists.txt @@ -0,0 +1,194 @@ +cmake_minimum_required(VERSION 3.0.2) +project(lasr_vector_databases_msgs) + +## Compile as C++11, supported in ROS Kinetic and newer +# add_compile_options(-std=c++11) + +## Find catkin macros and libraries +## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) +## is used, also find other catkin packages +find_package(catkin REQUIRED COMPONENTS message_generation message_runtime) + +## System dependencies are found with CMake's conventions +# find_package(Boost REQUIRED COMPONENTS system) + + +## Uncomment this if the package has a setup.py. This macro ensures +## modules and global scripts declared therein get installed +## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html +# catkin_python_setup() + +################################################ +## Declare ROS messages, services and actions ## +################################################ + +## To declare and build messages, services or actions from within this +## package, follow these steps: +## * Let MSG_DEP_SET be the set of packages whose message types you use in +## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...). 
+## * In the file package.xml: +## * add a build_depend tag for "message_generation" +## * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET +## * If MSG_DEP_SET isn't empty the following dependency has been pulled in +## but can be declared for certainty nonetheless: +## * add a exec_depend tag for "message_runtime" +## * In this file (CMakeLists.txt): +## * add "message_generation" and every package in MSG_DEP_SET to +## find_package(catkin REQUIRED COMPONENTS ...) +## * add "message_runtime" and every package in MSG_DEP_SET to +## catkin_package(CATKIN_DEPENDS ...) +## * uncomment the add_*_files sections below as needed +## and list every .msg/.srv/.action file to be processed +## * uncomment the generate_messages entry below +## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...) + +## Generate messages in the 'msg' folder +# add_message_files( +# FILES +# ) + +## Generate services in the 'srv' folder +add_service_files( + FILES + TxtIndex.srv + TxtQuery.srv +) + +# Generate actions in the 'action' folder + +## Generate added messages and services with any dependencies listed here +generate_messages( + DEPENDENCIES +) + +################################################ +## Declare ROS dynamic reconfigure parameters ## +################################################ + +## To declare and build dynamic reconfigure parameters within this +## package, follow these steps: +## * In the file package.xml: +## * add a build_depend and a exec_depend tag for "dynamic_reconfigure" +## * In this file (CMakeLists.txt): +## * add "dynamic_reconfigure" to +## find_package(catkin REQUIRED COMPONENTS ...) 
+## * uncomment the "generate_dynamic_reconfigure_options" section below +## and list every .cfg file to be processed + +## Generate dynamic reconfigure parameters in the 'cfg' folder +# generate_dynamic_reconfigure_options( +# cfg/DynReconf1.cfg +# cfg/DynReconf2.cfg +# ) + +################################### +## catkin specific configuration ## +################################### +## The catkin_package macro generates cmake config files for your package +## Declare things to be passed to dependent projects +## INCLUDE_DIRS: uncomment this if your package contains header files +## LIBRARIES: libraries you create in this project that dependent projects also need +## CATKIN_DEPENDS: catkin_packages dependent projects also need +## DEPENDS: system dependencies of this project that dependent projects also need +catkin_package( +# INCLUDE_DIRS include +# LIBRARIES lasr_vision_msgs +# CATKIN_DEPENDS other_catkin_pkg +# DEPENDS system_lib +) + +########### +## Build ## +########### + +## Specify additional locations of header files +## Your package locations should be listed before other locations +include_directories( +# include +# ${catkin_INCLUDE_DIRS} +) + +## Declare a C++ library +# add_library(${PROJECT_NAME} +# src/${PROJECT_NAME}/lasr_vision_msgs.cpp +# ) + +## Add cmake target dependencies of the library +## as an example, code may need to be generated before libraries +## either from message generation or dynamic reconfigure +# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Declare a C++ executable +## With catkin_make all packages are built within a single CMake context +## The recommended prefix ensures that target names across packages don't collide +# add_executable(${PROJECT_NAME}_node src/lasr_vision_msgs_node.cpp) + +## Rename C++ executable without prefix +## The above recommended prefix causes long target names, the following renames the +## target back to the shorter version for ease of user 
use +## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node" +# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "") + +## Add cmake target dependencies of the executable +## same as for the library above +# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Specify libraries to link a library or executable target against +# target_link_libraries(${PROJECT_NAME}_node +# ${catkin_LIBRARIES} +# ) + +############# +## Install ## +############# + +# all install targets should use catkin DESTINATION variables +# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html + +## Mark executable scripts (Python etc.) for installation +## in contrast to setup.py, you can choose the destination +# catkin_install_python(PROGRAMS +# scripts/my_python_script +# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark executables for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html +# install(TARGETS ${PROJECT_NAME}_node +# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark libraries for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html +# install(TARGETS ${PROJECT_NAME} +# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION} +# ) + +## Mark cpp header files for installation +# install(DIRECTORY include/${PROJECT_NAME}/ +# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} +# FILES_MATCHING PATTERN "*.h" +# PATTERN ".svn" EXCLUDE +# ) + +## Mark other files for installation (e.g. launch and bag files, etc.) 
+# install(FILES +# # myfile1 +# # myfile2 +# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +# ) + +############# +## Testing ## +############# + +## Add gtest based cpp test target and link libraries +# catkin_add_gtest(${PROJECT_NAME}-test test/test_lasr_vision_msgs.cpp) +# if(TARGET ${PROJECT_NAME}-test) +# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME}) +# endif() + +## Add folders to be run by python nosetests +# catkin_add_nosetests(test) diff --git a/common/vector_databases/lasr_vector_databases_msgs/package.xml b/common/vector_databases/lasr_vector_databases_msgs/package.xml new file mode 100644 index 000000000..0fab0d65a --- /dev/null +++ b/common/vector_databases/lasr_vector_databases_msgs/package.xml @@ -0,0 +1,61 @@ + + + lasr_vector_databases_msgs + 0.0.0 + Messages required for vector database + + + + + Paul Makles + + + + + + MIT + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + catkin + message_generation + message_runtime + + + + + + + + diff --git a/common/vector_databases/lasr_vector_databases_faiss/srv/TxtIndex.srv b/common/vector_databases/lasr_vector_databases_msgs/srv/TxtIndex.srv similarity index 100% rename from common/vector_databases/lasr_vector_databases_faiss/srv/TxtIndex.srv rename to common/vector_databases/lasr_vector_databases_msgs/srv/TxtIndex.srv diff --git a/common/vector_databases/lasr_vector_databases_faiss/srv/TxtQuery.srv b/common/vector_databases/lasr_vector_databases_msgs/srv/TxtQuery.srv similarity index 100% rename from common/vector_databases/lasr_vector_databases_faiss/srv/TxtQuery.srv rename to common/vector_databases/lasr_vector_databases_msgs/srv/TxtQuery.srv From b012e2399aeccd2e0c5ab7f94253373078f39feb Mon Sep 17 00:00:00 2001 From: m-barker Date: Wed, 6 Mar 2024 13:58:46 +0000 Subject: [PATCH 14/23] feat: working state machine for gpsr QA task --- skills/CMakeLists.txt | 1 + skills/package.xml | 1 + skills/src/lasr_skills/xml_question_answer.py | 4 +- tasks/gpsr/CMakeLists.txt | 
200 ++++++++++++++++++ tasks/gpsr/launch/question_answer.launch | 23 ++ tasks/gpsr/nodes/question_answer | 87 ++++++++ tasks/gpsr/package.xml | 60 ++++++ tasks/gpsr/setup.py | 8 + 8 files changed, 383 insertions(+), 1 deletion(-) create mode 100644 tasks/gpsr/CMakeLists.txt create mode 100644 tasks/gpsr/launch/question_answer.launch create mode 100644 tasks/gpsr/nodes/question_answer create mode 100644 tasks/gpsr/package.xml create mode 100644 tasks/gpsr/setup.py diff --git a/skills/CMakeLists.txt b/skills/CMakeLists.txt index cc256ca3d..6a5978ed6 100644 --- a/skills/CMakeLists.txt +++ b/skills/CMakeLists.txt @@ -10,6 +10,7 @@ project(lasr_skills) find_package(catkin REQUIRED COMPONENTS rospy lasr_vision_msgs + lasr_vector_databases_msgs ) ## System dependencies are found with CMake's conventions diff --git a/skills/package.xml b/skills/package.xml index be2d32977..afee2817b 100644 --- a/skills/package.xml +++ b/skills/package.xml @@ -53,6 +53,7 @@ rospy rospy lasr_vision_msgs + lasr_vector_databases_msgs diff --git a/skills/src/lasr_skills/xml_question_answer.py b/skills/src/lasr_skills/xml_question_answer.py index 338abf98e..7736ffc70 100644 --- a/skills/src/lasr_skills/xml_question_answer.py +++ b/skills/src/lasr_skills/xml_question_answer.py @@ -4,7 +4,7 @@ import smach import xml.etree.ElementTree as ET -from lasr_vector_databases_faiss.srv import TxtQuery, TxtQueryRequest +from lasr_vector_databases_msgs.srv import TxtQuery, TxtQueryRequest def parse_question_xml(xml_file_path: str) -> dict: @@ -45,6 +45,7 @@ def __init__(self): self.txt_query = rospy.ServiceProxy("/lasr_faiss/txt_query", TxtQuery) def execute(self, userdata): + rospy.wait_for_service("/lasr_faiss/txt_query") q_a_dict: dict = parse_question_xml(userdata.xml_path) try: request = TxtQueryRequest( @@ -62,4 +63,5 @@ def execute(self, userdata): return "succeeded" except rospy.ServiceException as e: rospy.logwarn(f"Unable to perform Index Query. 
({str(e)})") + userdata.closest_answers = [] return "failed" diff --git a/tasks/gpsr/CMakeLists.txt b/tasks/gpsr/CMakeLists.txt new file mode 100644 index 000000000..37cbd376b --- /dev/null +++ b/tasks/gpsr/CMakeLists.txt @@ -0,0 +1,200 @@ +cmake_minimum_required(VERSION 3.0.2) +project(gpsr) + +## Compile as C++11, supported in ROS Kinetic and newer +# add_compile_options(-std=c++11) + +## Find catkin macros and libraries +## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) +## is used, also find other catkin packages +find_package(catkin REQUIRED) + +## System dependencies are found with CMake's conventions +# find_package(Boost REQUIRED COMPONENTS system) + + +## Uncomment this if the package has a setup.py. This macro ensures +## modules and global scripts declared therein get installed +## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html +catkin_python_setup() + +################################################ +## Declare ROS messages, services and actions ## +################################################ + +## To declare and build messages, services or actions from within this +## package, follow these steps: +## * Let MSG_DEP_SET be the set of packages whose message types you use in +## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...). +## * In the file package.xml: +## * add a build_depend tag for "message_generation" +## * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET +## * If MSG_DEP_SET isn't empty the following dependency has been pulled in +## but can be declared for certainty nonetheless: +## * add a exec_depend tag for "message_runtime" +## * In this file (CMakeLists.txt): +## * add "message_generation" and every package in MSG_DEP_SET to +## find_package(catkin REQUIRED COMPONENTS ...) +## * add "message_runtime" and every package in MSG_DEP_SET to +## catkin_package(CATKIN_DEPENDS ...) 
+## * uncomment the add_*_files sections below as needed +## and list every .msg/.srv/.action file to be processed +## * uncomment the generate_messages entry below +## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...) + +## Generate messages in the 'msg' folder +# add_message_files( +# FILES +# Message1.msg +# Message2.msg +# ) + +## Generate services in the 'srv' folder +# add_service_files( +# FILES +# ) + +## Generate actions in the 'action' folder +# add_action_files( +# FILES +# Action1.action +# Action2.action +# ) + +## Generate added messages and services with any dependencies listed here +# generate_messages( +# DEPENDENCIES +# std_msgs +# ) + +################################################ +## Declare ROS dynamic reconfigure parameters ## +################################################ + +## To declare and build dynamic reconfigure parameters within this +## package, follow these steps: +## * In the file package.xml: +## * add a build_depend and a exec_depend tag for "dynamic_reconfigure" +## * In this file (CMakeLists.txt): +## * add "dynamic_reconfigure" to +## find_package(catkin REQUIRED COMPONENTS ...) 
+## * uncomment the "generate_dynamic_reconfigure_options" section below +## and list every .cfg file to be processed + +## Generate dynamic reconfigure parameters in the 'cfg' folder +# generate_dynamic_reconfigure_options( +# cfg/DynReconf1.cfg +# cfg/DynReconf2.cfg +# ) + +################################### +## catkin specific configuration ## +################################### +## The catkin_package macro generates cmake config files for your package +## Declare things to be passed to dependent projects +## INCLUDE_DIRS: uncomment this if your package contains header files +## LIBRARIES: libraries you create in this project that dependent projects also need +## CATKIN_DEPENDS: catkin_packages dependent projects also need +## DEPENDS: system dependencies of this project that dependent projects also need +catkin_package( +# INCLUDE_DIRS include +# LIBRARIES coffee_shop +# CATKIN_DEPENDS other_catkin_pkg +# DEPENDS system_lib +) + +########### +## Build ## +########### + +## Specify additional locations of header files +## Your package locations should be listed before other locations +include_directories( +# include +# ${catkin_INCLUDE_DIRS} +) + +## Declare a C++ library +# add_library(${PROJECT_NAME} +# src/${PROJECT_NAME}/coffee_shop.cpp +# ) + +## Add cmake target dependencies of the library +## as an example, code may need to be generated before libraries +## either from message generation or dynamic reconfigure +# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Declare a C++ executable +## With catkin_make all packages are built within a single CMake context +## The recommended prefix ensures that target names across packages don't collide +# add_executable(${PROJECT_NAME}_node src/coffee_shop_node.cpp) + +## Rename C++ executable without prefix +## The above recommended prefix causes long target names, the following renames the +## target back to the shorter version for ease of user use +## e.g. 
"rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node" +# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "") + +## Add cmake target dependencies of the executable +## same as for the library above +# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Specify libraries to link a library or executable target against +# target_link_libraries(${PROJECT_NAME}_node +# ${catkin_LIBRARIES} +# ) + +############# +## Install ## +############# + +# all install targets should use catkin DESTINATION variables +# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html + +## Mark executable scripts (Python etc.) for installation +## in contrast to setup.py, you can choose the destination +catkin_install_python(PROGRAMS + scripts/parse_gpsr_xmls.py + nodes/question_answer + DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + +## Mark executables for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html +# install(TARGETS ${PROJECT_NAME}_node +# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark libraries for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html +# install(TARGETS ${PROJECT_NAME} +# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION} +# ) + +## Mark cpp header files for installation +# install(DIRECTORY include/${PROJECT_NAME}/ +# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} +# FILES_MATCHING PATTERN "*.h" +# PATTERN ".svn" EXCLUDE +# ) + +## Mark other files for installation (e.g. launch and bag files, etc.) 
+# install(FILES +# requirements.txt +# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +# ) + +############# +## Testing ## +############# + +## Add gtest based cpp test target and link libraries +# catkin_add_gtest(${PROJECT_NAME}-test test/test_coffee_shop.cpp) +# if(TARGET ${PROJECT_NAME}-test) +# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME}) +# endif() + +## Add folders to be run by python nosetests +# catkin_add_nosetests(test) diff --git a/tasks/gpsr/launch/question_answer.launch b/tasks/gpsr/launch/question_answer.launch new file mode 100644 index 000000000..60128141b --- /dev/null +++ b/tasks/gpsr/launch/question_answer.launch @@ -0,0 +1,23 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/tasks/gpsr/nodes/question_answer b/tasks/gpsr/nodes/question_answer new file mode 100644 index 000000000..9c715a460 --- /dev/null +++ b/tasks/gpsr/nodes/question_answer @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +import rospy +import argparse +import smach +from lasr_skills.xml_question_answer import XmlQuestionAnswer + + +class QuestionAnswerStateMachine(smach.StateMachine): + def __init__(self, input_data: dict): + smach.StateMachine.__init__( + self, + outcomes=["succeeded", "failed"], + output_keys=["closest_answers"], + ) + self.userdata.query_sentence = input_data["question"] + self.userdata.k = input_data["k"] + self.userdata.index_path = input_data["index_path"] + self.userdata.txt_path = input_data["txt_path"] + self.userdata.xml_path = input_data["xml_path"] + print(self.userdata) + + with self: + smach.StateMachine.add( + "XML_QUESTION_ANSWER", + XmlQuestionAnswer(), + transitions={"succeeded": "succeeded", "failed": "failed"}, + remapping={ + "query_sentence": "query_sentence", + "k": "k", + "index_path": "index_path", + "txt_path": "txt_path", + "xml_path": "xml_path", + "closest_answers": "closest_answers", + }, + ) + + +def parse_args() -> dict: + parser = argparse.ArgumentParser(description="GPSR Question Answer") + 
parser.add_argument( + "--question", + type=str, + help="The question to query", + required=True, + ) + parser.add_argument( + "--k", + type=int, + help="The number of closest answers to return", + required=True, + ) + parser.add_argument( + "--index_path", + type=str, + help="The path to the index file that is populated with the sentences embeddings of the questions", + required=True, + ) + parser.add_argument( + "--txt_path", + type=str, + help="The path to the txt file containing a list of questions.", + required=True, + ) + parser.add_argument( + "--xml_path", + type=str, + help="The path to the xml file containing question/answer pairs", + required=True, + ) + args, _ = parser.parse_known_args() + args.k = int(args.k) + return vars(args) + + +if __name__ == "__main__": + rospy.init_node("gpsr_question_answer") + args: dict = parse_args() + print(args) + q_a_sm = QuestionAnswerStateMachine(args) + outcome = q_a_sm.execute() + if outcome == "succeeded": + rospy.loginfo(f"Question: {args['question']}") + rospy.loginfo(f"Closest Answers: {q_a_sm.userdata.closest_answers}") + else: + rospy.logerr("Question Answer State Machine failed") + + rospy.spin() diff --git a/tasks/gpsr/package.xml b/tasks/gpsr/package.xml new file mode 100644 index 000000000..d9d8689e5 --- /dev/null +++ b/tasks/gpsr/package.xml @@ -0,0 +1,60 @@ + + + gpsr + 0.0.0 + The gpsr task package + + + + + Matt Barker + + + + + + MIT + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + catkin + lasr_vector_databases_faiss + lasr_vector_databases_msgs + + + + + + + diff --git a/tasks/gpsr/setup.py b/tasks/gpsr/setup.py new file mode 100644 index 000000000..386c709e5 --- /dev/null +++ b/tasks/gpsr/setup.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 + +from distutils.core import setup +from catkin_pkg.python_setup import generate_distutils_setup + +setup_args = generate_distutils_setup(packages=["gpsr"], package_dir={"": "src"}) + +setup(**setup_args) From 
014c6e9fcdcb591e0526b78fc0e030bd50f91e98 Mon Sep 17 00:00:00 2001 From: m-barker Date: Wed, 6 Mar 2024 14:30:37 +0000 Subject: [PATCH 15/23] feat: tts question answering --- skills/src/lasr_skills/xml_question_answer.py | 5 +++ tasks/gpsr/launch/question_answer.launch | 3 +- tasks/gpsr/nodes/question_answer | 22 +++++++----- tasks/gpsr/src/gpsr/states/get_question.py | 35 +++++++++++++++++++ 4 files changed, 55 insertions(+), 10 deletions(-) create mode 100644 tasks/gpsr/src/gpsr/states/get_question.py diff --git a/skills/src/lasr_skills/xml_question_answer.py b/skills/src/lasr_skills/xml_question_answer.py index 7736ffc70..ca58bc64b 100644 --- a/skills/src/lasr_skills/xml_question_answer.py +++ b/skills/src/lasr_skills/xml_question_answer.py @@ -3,6 +3,7 @@ import rospy import smach import xml.etree.ElementTree as ET +from lasr_voice import Voice from lasr_vector_databases_msgs.srv import TxtQuery, TxtQueryRequest @@ -64,4 +65,8 @@ def execute(self, userdata): except rospy.ServiceException as e: rospy.logwarn(f"Unable to perform Index Query. ({str(e)})") userdata.closest_answers = [] + voice = Voice() + voice.sync_tts( + "I'm sorry, I couldn't find an answer to your question. Please ask me another question." 
+ ) return "failed" diff --git a/tasks/gpsr/launch/question_answer.launch b/tasks/gpsr/launch/question_answer.launch index 60128141b..87de56930 100644 --- a/tasks/gpsr/launch/question_answer.launch +++ b/tasks/gpsr/launch/question_answer.launch @@ -1,5 +1,4 @@ - @@ -17,7 +16,7 @@ type="question_answer" name="question_answer" output="screen" - args="--question $(arg question) --k $(arg k) --index_path $(arg index_path) --txt_path $(arg text_path) --xml_path $(arg xml_path)" + args="--k $(arg k) --index_path $(arg index_path) --txt_path $(arg text_path) --xml_path $(arg xml_path)" /> \ No newline at end of file diff --git a/tasks/gpsr/nodes/question_answer b/tasks/gpsr/nodes/question_answer index 9c715a460..fe00ce725 100644 --- a/tasks/gpsr/nodes/question_answer +++ b/tasks/gpsr/nodes/question_answer @@ -2,7 +2,9 @@ import rospy import argparse import smach +from lasr_voice import Voice from lasr_skills.xml_question_answer import XmlQuestionAnswer +from gpsr.states.get_question import GetQuestion class QuestionAnswerStateMachine(smach.StateMachine): @@ -12,7 +14,6 @@ class QuestionAnswerStateMachine(smach.StateMachine): outcomes=["succeeded", "failed"], output_keys=["closest_answers"], ) - self.userdata.query_sentence = input_data["question"] self.userdata.k = input_data["k"] self.userdata.index_path = input_data["index_path"] self.userdata.txt_path = input_data["txt_path"] @@ -20,10 +21,16 @@ class QuestionAnswerStateMachine(smach.StateMachine): print(self.userdata) with self: + smach.StateMachine.add( + "GET_QUESTION", + GetQuestion(), + transitions={"succeeded": "XML_QUESTION_ANSWER", "failed": "failed"}, + remapping={"question": "query_sentence"}, + ) smach.StateMachine.add( "XML_QUESTION_ANSWER", XmlQuestionAnswer(), - transitions={"succeeded": "succeeded", "failed": "failed"}, + transitions={"succeeded": "succeeded", "failed": "GET_QUESTION"}, remapping={ "query_sentence": "query_sentence", "k": "k", @@ -37,12 +44,6 @@ class 
QuestionAnswerStateMachine(smach.StateMachine): def parse_args() -> dict: parser = argparse.ArgumentParser(description="GPSR Question Answer") - parser.add_argument( - "--question", - type=str, - help="The question to query", - required=True, - ) parser.add_argument( "--k", type=int, @@ -78,10 +79,15 @@ if __name__ == "__main__": print(args) q_a_sm = QuestionAnswerStateMachine(args) outcome = q_a_sm.execute() + voice = Voice() if outcome == "succeeded": rospy.loginfo(f"Question: {args['question']}") rospy.loginfo(f"Closest Answers: {q_a_sm.userdata.closest_answers}") + voice.sync_tts( + f"The answer to your question is: {q_a_sm.userdata.closest_answers[0]}" + ) else: rospy.logerr("Question Answer State Machine failed") + voice.sync_tts(f"Sorry, I wasn't able to find an answer to your question") rospy.spin() diff --git a/tasks/gpsr/src/gpsr/states/get_question.py b/tasks/gpsr/src/gpsr/states/get_question.py new file mode 100644 index 000000000..e7719034c --- /dev/null +++ b/tasks/gpsr/src/gpsr/states/get_question.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +import smach +import rospy +import actionlib +from lasr_voice import Voice +from lasr_speech_recognition_msgs.msg import ( + TranscribeSpeechAction, + TranscribeSpeechGoal, +) + + +class GetQuestion(smach.State): + def __init__(self): + smach.State.__init__( + self, outcomes=["succeeded", "failed"], output_keys=["question"] + ) + self.voice = Voice() + self.client = actionlib.SimpleActionClient( + "transcribe_speech", TranscribeSpeechAction + ) + + def execute(self, userdata): + try: + self.client.wait_for_server() + self.voice.sync_tts("Hello, I hear you have a question for me, ask away!") + goal = TranscribeSpeechGoal() + self.client.send_goal(goal) + self.client.wait_for_result() + result = self.client.get_result() + text = result.sequence + userdata.question = text + return "succeeded" + except Exception as e: + rospy.loginfo(f"Failed to get question: {e}") + return "failed" From 
e234f9dd62c61e69befa33d349f6edce2d15cb22 Mon Sep 17 00:00:00 2001 From: m-barker Date: Wed, 6 Mar 2024 15:57:43 +0000 Subject: [PATCH 16/23] feat: working launch file and question answering with TTS --- tasks/gpsr/launch/question_answer.launch | 8 ++++++++ tasks/gpsr/nodes/question_answer | 25 ++++++++++++------------ tasks/gpsr/package.xml | 1 + 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/tasks/gpsr/launch/question_answer.launch b/tasks/gpsr/launch/question_answer.launch index 87de56930..79c3a6fb2 100644 --- a/tasks/gpsr/launch/question_answer.launch +++ b/tasks/gpsr/launch/question_answer.launch @@ -19,4 +19,12 @@ args="--k $(arg k) --index_path $(arg index_path) --txt_path $(arg text_path) --xml_path $(arg xml_path)" /> + + \ No newline at end of file diff --git a/tasks/gpsr/nodes/question_answer b/tasks/gpsr/nodes/question_answer index fe00ce725..e7541a5b0 100644 --- a/tasks/gpsr/nodes/question_answer +++ b/tasks/gpsr/nodes/question_answer @@ -76,18 +76,17 @@ def parse_args() -> dict: if __name__ == "__main__": rospy.init_node("gpsr_question_answer") args: dict = parse_args() - print(args) - q_a_sm = QuestionAnswerStateMachine(args) - outcome = q_a_sm.execute() - voice = Voice() - if outcome == "succeeded": - rospy.loginfo(f"Question: {args['question']}") - rospy.loginfo(f"Closest Answers: {q_a_sm.userdata.closest_answers}") - voice.sync_tts( - f"The answer to your question is: {q_a_sm.userdata.closest_answers[0]}" - ) - else: - rospy.logerr("Question Answer State Machine failed") - voice.sync_tts(f"Sorry, I wasn't able to find an answer to your question") + while not rospy.is_shutdown(): + q_a_sm = QuestionAnswerStateMachine(args) + outcome = q_a_sm.execute() + voice = Voice() + if outcome == "succeeded": + rospy.loginfo(f"Closest Answers: {q_a_sm.userdata.closest_answers}") + voice.sync_tts( + f"The answer to your question is: {q_a_sm.userdata.closest_answers[0]}" + ) + else: + rospy.logerr("Question Answer State Machine failed") +
voice.sync_tts(f"Sorry, I wasn't able to find an answer to your question") rospy.spin() diff --git a/tasks/gpsr/package.xml b/tasks/gpsr/package.xml index d9d8689e5..cf974db17 100644 --- a/tasks/gpsr/package.xml +++ b/tasks/gpsr/package.xml @@ -51,6 +51,7 @@ catkin lasr_vector_databases_faiss lasr_vector_databases_msgs + lasr_speech_recognition_whisper From d72654256d6e2871ff633e4675ee3f40c23c6e5c Mon Sep 17 00:00:00 2001 From: m-barker Date: Mon, 11 Mar 2024 14:53:37 +0000 Subject: [PATCH 17/23] fix: integrate @jws-1 review suggestions --- .../nodes/txt_query_service | 1 - .../database_utils.py | 2 +- .../get_sentence_embeddings.py | 3 +-- .../lasr_vector_databases_msgs/package.xml | 4 ++-- skills/src/lasr_skills/xml_question_answer.py | 20 +++++++++---------- tasks/gpsr/package.xml | 3 +++ 6 files changed, 16 insertions(+), 17 deletions(-) diff --git a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service index 8dec15270..e45610fd3 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service +++ b/common/vector_databases/lasr_vector_databases_faiss/nodes/txt_query_service @@ -20,7 +20,6 @@ class TxtQueryService: def __init__(self): rospy.init_node("txt_query_service") self._sentence_embedding_model = load_model() - print(self._sentence_embedding_model) rospy.Service("lasr_faiss/txt_query", TxtQuery, self.execute_cb) rospy.loginfo("Text Query service started") diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py index 7a901ff3f..6d0139072 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py +++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/database_utils.py @@ -11,7 +11,7 @@ 
def create_vector_database( index_type: str = "Flat", normalise_vecs: bool = True, ) -> None: - """Creates a FAISS Index using the factor constructor and the given + """Creates a FAISS Index using the factory constructor and the given index type, and adds the given vector to the index, and then saves it to disk using the given path. diff --git a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py index e28189e5c..1ba43b48c 100644 --- a/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py +++ b/common/vector_databases/lasr_vector_databases_faiss/src/lasr_vector_databases_faiss/get_sentence_embeddings.py @@ -39,7 +39,7 @@ def get_sentence_embeddings( """Converts the list of string sentences into an array of sentence embeddings Args: - sentece_list (list[str]): list of string sentences, where each + sentence_list (list[str]): list of string sentences, where each entry in the list is assumed to be a separate sentence model (SentenceTransformer): model used to perform the embedding. 
Assumes a method called encode that takes a list of strings @@ -47,7 +47,6 @@ def get_sentence_embeddings( Returns: np.ndarray: array of shape (n_commands, embedding_dim) """ - return model.encode( sentence_list, convert_to_numpy=True, diff --git a/common/vector_databases/lasr_vector_databases_msgs/package.xml b/common/vector_databases/lasr_vector_databases_msgs/package.xml index 0fab0d65a..5f4e45e9f 100644 --- a/common/vector_databases/lasr_vector_databases_msgs/package.xml +++ b/common/vector_databases/lasr_vector_databases_msgs/package.xml @@ -2,12 +2,12 @@ lasr_vector_databases_msgs 0.0.0 - Messages required for vector database + Messages required for vector databases - Paul Makles + Matt Barker diff --git a/skills/src/lasr_skills/xml_question_answer.py b/skills/src/lasr_skills/xml_question_answer.py index ca58bc64b..df8e7564f 100644 --- a/skills/src/lasr_skills/xml_question_answer.py +++ b/skills/src/lasr_skills/xml_question_answer.py @@ -36,13 +36,17 @@ def parse_question_xml(xml_file_path: str) -> dict: class XmlQuestionAnswer(smach.State): - def __init__(self): + def __init__(self, index_path: str, txt_path: str, xml_path: str, k: int = 1): smach.State.__init__( self, outcomes=["succeeded", "failed"], - input_keys=["query_sentence", "k", "index_path", "txt_path", "xml_path"], + input_keys=["query_sentence", "k"], output_keys=["closest_answers"], ) + self.index_path = index_path + self.txt_path = txt_path + self.xml_path = xml_path + self.k = k self.txt_query = rospy.ServiceProxy("/lasr_faiss/txt_query", TxtQuery) def execute(self, userdata): @@ -50,10 +54,10 @@ def execute(self, userdata): q_a_dict: dict = parse_question_xml(userdata.xml_path) try: request = TxtQueryRequest( - userdata.txt_path, - userdata.index_path, + self.txt_path, + self.index_path, userdata.query_sentence, - userdata.k, + self.k, ) result = self.txt_query(request) answers = [ @@ -63,10 +67,4 @@ def execute(self, userdata): userdata.closest_answers = answers return "succeeded" except 
rospy.ServiceException as e: - rospy.logwarn(f"Unable to perform Index Query. ({str(e)})") - userdata.closest_answers = [] - voice = Voice() - voice.sync_tts( - "I'm sorry, I couldn't find an answer to your question. Please ask me another question." - ) return "failed" diff --git a/tasks/gpsr/package.xml b/tasks/gpsr/package.xml index cf974db17..e94bc707c 100644 --- a/tasks/gpsr/package.xml +++ b/tasks/gpsr/package.xml @@ -8,6 +8,9 @@ Matt Barker + Siyao Li + Jared Swift + Nicole Lehchevska From 397ac6fecb5d2ffbb17935ccbed4a3ca23858bed Mon Sep 17 00:00:00 2001 From: m-barker Date: Mon, 11 Mar 2024 15:36:02 +0000 Subject: [PATCH 18/23] feat: add listen state --- skills/src/lasr_skills/listen.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 skills/src/lasr_skills/listen.py diff --git a/skills/src/lasr_skills/listen.py b/skills/src/lasr_skills/listen.py new file mode 100644 index 000000000..0e595f59f --- /dev/null +++ b/skills/src/lasr_skills/listen.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 +import smach_ros +from lasr_speech_recognition_msgs.msg import ( + TranscribeSpeechAction, + TranscribeSpeechGoal, +) + + +class Listen(smach_ros.SimpleActionState): + def __init__(self): + smach_ros.SimpleActionState.__init__( + self, + "transcribe_speech", + TranscribeSpeechAction, + goal=TranscribeSpeechGoal(), + result_slots=["sequence"], + ) From 85f543933e91eabfc497af8d30b03ca9cc7fac69 Mon Sep 17 00:00:00 2001 From: m-barker Date: Mon, 11 Mar 2024 16:21:23 +0000 Subject: [PATCH 19/23] chore: create gpsr commands sub folder --- tasks/gpsr/launch/{ => commands}/question_answer.launch | 0 tasks/gpsr/nodes/{ => commands}/question_answer | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tasks/gpsr/launch/{ => commands}/question_answer.launch (100%) rename tasks/gpsr/nodes/{ => commands}/question_answer (100%) diff --git a/tasks/gpsr/launch/question_answer.launch b/tasks/gpsr/launch/commands/question_answer.launch similarity index 100% 
rename from tasks/gpsr/launch/question_answer.launch rename to tasks/gpsr/launch/commands/question_answer.launch diff --git a/tasks/gpsr/nodes/question_answer b/tasks/gpsr/nodes/commands/question_answer similarity index 100% rename from tasks/gpsr/nodes/question_answer rename to tasks/gpsr/nodes/commands/question_answer From 2c8f257a145e98e51888cc7c3fde168ed089f177 Mon Sep 17 00:00:00 2001 From: m-barker Date: Mon, 11 Mar 2024 16:46:00 +0000 Subject: [PATCH 20/23] chore: properly fix ALSA errors --- .../nodes/transcribe_microphone_server | 100 +++++------------- .../requirements.in | 3 +- .../requirements.txt | 81 +++++++------- 3 files changed, 69 insertions(+), 115 deletions(-) diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server index ed1c01d25..afa55f215 100644 --- a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server +++ b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server @@ -1,6 +1,6 @@ #!/usr/bin/env python3 - import os +import sounddevice # needed to remove ALSA error messages import argparse from typing import Optional from dataclasses import dataclass @@ -16,29 +16,6 @@ import lasr_speech_recognition_msgs.msg # type: ignore from std_msgs.msg import String # type: ignore from lasr_speech_recognition_whisper import load_model # type: ignore -# Error handler to remove ALSA error messages taken from: -# https://stackoverflow.com/questions/7088672/pyaudio-working-but-spits-out-error-messages-each-time/17673011#17673011 - -from ctypes import * -from contextlib import contextmanager - -ERROR_HANDLER_FUNC = CFUNCTYPE(None, c_char_p, c_int, c_char_p, c_int, c_char_p) - - -def py_error_handler(filename, line, function, err, fmt): - pass - - -c_error_handler = ERROR_HANDLER_FUNC(py_error_handler) - - -@contextmanager -def noalsaerr(): - asound = cdll.LoadLibrary("libasound.so") - 
asound.snd_lib_error_set_handler(c_error_handler) - yield - asound.snd_lib_error_set_handler(None) - @dataclass class speech_model_params: @@ -87,40 +64,25 @@ class TranscribeSpeechAction(object): self._transcription_server = rospy.Publisher( "/live_speech_transcription", String, queue_size=10 ) - with noalsaerr(): - self._model = load_model( - self._model_params.model_name, - self._model_params.device, - self._model_params.warmup, - ) - # Configure the speech recogniser object and adjust for ambient noise - self.recogniser = self._configure_recogniser(ambient_adj=True) - # Setup the action server and register execution callback - self._action_server = actionlib.SimpleActionServer( - self._action_name, - lasr_speech_recognition_msgs.msg.TranscribeSpeechAction, - execute_cb=self.execute_cb, - auto_start=False, - ) - self._action_server.register_preempt_callback(self.prempt_cb) - # Setup the timer for adjusting the microphone for ambient noise every x seconds - self._timer_duration = self._model_params.timer_duration - self._timer = rospy.Timer( - rospy.Duration(self._timer_duration), self._timer_cb - ) - self._listening = False - self._action_server.start() + self._model = load_model( + self._model_params.model_name, + self._model_params.device, + self._model_params.warmup, + ) + # Configure the speech recogniser object and adjust for ambient noise + self.recogniser = self._configure_recogniser(ambient_adj=True) + # Setup the action server and register execution callback + self._action_server = actionlib.SimpleActionServer( + self._action_name, + lasr_speech_recognition_msgs.msg.TranscribeSpeechAction, + execute_cb=self.execute_cb, + auto_start=False, + ) + self._action_server.register_preempt_callback(self.prempt_cb) + self._listening = False - def _timer_cb(self, _) -> None: - return - """Adjusts the microphone for ambient noise, unless the action server is listening.""" - if self._listening: - return - rospy.loginfo("Adjusting microphone for ambient noise...") - 
with noalsaerr(): - with self._configure_microphone() as source: - self.recogniser.adjust_for_ambient_noise(source) + self._action_server.start() def _reset_timer(self) -> None: """Resets the timer for adjusting the microphone for ambient noise.""" @@ -194,17 +156,13 @@ class TranscribeSpeechAction(object): rospy.loginfo("Request Received") if self._action_server.is_preempt_requested(): return - # Since we are about to listen, reset the timer for adjusting the microphone for ambient noise - # as this assumes self_timer_duration seconds of silence before adjusting - self._reset_timer() - with noalsaerr(): - with self._configure_microphone() as src: - self._listening = True - wav_data = self.recogniser.listen( - src, - timeout=self._model_params.start_timeout, - phrase_time_limit=self._model_params.end_timeout, - ).get_wav_data() + with self._configure_microphone() as src: + self._listening = True + wav_data = self.recogniser.listen( + src, + timeout=self._model_params.start_timeout, + phrase_time_limit=self._model_params.end_timeout, + ).get_wav_data() # Magic number 32768.0 is the maximum value of a 16-bit signed integer float_data = ( np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") @@ -293,12 +251,6 @@ def parse_args() -> dict: default=None, help="Microphone device index or name", ) - parser.add_argument( - "--timer_duration", - type=int, - default=20, - help="Number of seconds of silence before the ambient noise adjustment is called.", - ) parser.add_argument( "--no_warmup", action="store_true", @@ -331,8 +283,6 @@ def configure_model_params(config: dict) -> speech_model_params: model_params.sample_rate = config["sample_rate"] if config["mic_device"]: model_params.mic_device = config["mic_device"] - if config["timer_duration"]: - model_params.timer_duration = config["timer_duration"] if config["no_warmup"]: model_params.warmup = False diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.in 
b/common/speech/lasr_speech_recognition_whisper/requirements.in index 8209d34fc..da48c5086 100644 --- a/common/speech/lasr_speech_recognition_whisper/requirements.in +++ b/common/speech/lasr_speech_recognition_whisper/requirements.in @@ -1,5 +1,6 @@ SpeechRecognition==3.10.0 -openai-whisper==20230314 +sounddevice==0.4.6 +openai-whisper==20231117 PyAudio==0.2.13 PyYaml==6.0.1 rospkg==1.5.0 diff --git a/common/speech/lasr_speech_recognition_whisper/requirements.txt b/common/speech/lasr_speech_recognition_whisper/requirements.txt index a6d9bdebe..1cace21e8 100644 --- a/common/speech/lasr_speech_recognition_whisper/requirements.txt +++ b/common/speech/lasr_speech_recognition_whisper/requirements.txt @@ -1,51 +1,54 @@ -catkin-pkg==0.5.2 # via rospkg -certifi==2023.7.22 # via requests -charset-normalizer==3.2.0 # via requests -cmake==3.27.2 # via triton -distro==1.8.0 # via rospkg +--extra-index-url https://pypi.ngc.nvidia.com +--trusted-host pypi.ngc.nvidia.com + +catkin-pkg==1.0.0 # via rospkg +certifi==2024.2.2 # via requests +cffi==1.16.0 # via sounddevice +charset-normalizer==3.3.2 # via requests +distro==1.9.0 # via rospkg docutils==0.20.1 # via catkin-pkg -ffmpeg-python==0.2.0 # via openai-whisper -filelock==3.12.2 # via torch, triton -future==0.18.3 # via ffmpeg-python -idna==3.4 # via requests -jinja2==3.1.2 # via torch -lit==16.0.6 # via triton -llvmlite==0.40.1 # via numba -markupsafe==2.1.3 # via jinja2 -more-itertools==10.1.0 # via openai-whisper +filelock==3.13.1 # via torch, triton +fsspec==2024.2.0 # via torch +idna==3.6 # via requests +jinja2==3.1.3 # via torch +llvmlite==0.42.0 # via numba +markupsafe==2.1.5 # via jinja2 +more-itertools==10.2.0 # via openai-whisper mpmath==1.3.0 # via sympy -networkx==3.1 # via torch -numba==0.57.1 # via openai-whisper -numpy==1.24.4 # via numba, openai-whisper -nvidia-cublas-cu11==11.10.3.66 # via nvidia-cudnn-cu11, nvidia-cusolver-cu11, torch -nvidia-cuda-cupti-cu11==11.7.101 # via torch 
-nvidia-cuda-nvrtc-cu11==11.7.99 # via torch -nvidia-cuda-runtime-cu11==11.7.99 # via torch -nvidia-cudnn-cu11==8.5.0.96 # via torch -nvidia-cufft-cu11==10.9.0.58 # via torch -nvidia-curand-cu11==10.2.10.91 # via torch -nvidia-cusolver-cu11==11.4.0.1 # via torch -nvidia-cusparse-cu11==11.7.4.91 # via torch -nvidia-nccl-cu11==2.14.3 # via torch -nvidia-nvtx-cu11==11.7.91 # via torch -openai-whisper==20230314 # via -r requirements.in +networkx==3.2.1 # via torch +numba==0.59.0 # via openai-whisper +numpy==1.26.4 # via numba, openai-whisper +nvidia-cublas-cu12==12.1.3.1 # via nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch +nvidia-cuda-cupti-cu12==12.1.105 # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 # via torch +nvidia-cuda-runtime-cu12==12.1.105 # via torch +nvidia-cudnn-cu12==8.9.2.26 # via torch +nvidia-cufft-cu12==11.0.2.54 # via torch +nvidia-curand-cu12==10.3.2.106 # via torch +nvidia-cusolver-cu12==11.4.5.107 # via torch +nvidia-cusparse-cu12==12.1.0.106 # via nvidia-cusolver-cu12, torch +nvidia-nccl-cu12==2.19.3 # via torch +nvidia-nvjitlink-cu12==12.4.99 # via nvidia-cusolver-cu12, nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 # via torch +openai-whisper==20231117 # via -r requirements.in pyaudio==0.2.13 # via -r requirements.in -pyparsing==3.1.1 # via catkin-pkg -python-dateutil==2.8.2 # via catkin-pkg +pycparser==2.21 # via cffi +pyparsing==3.1.2 # via catkin-pkg +python-dateutil==2.9.0.post0 # via catkin-pkg pyyaml==6.0.1 # via -r requirements.in, rospkg -regex==2023.8.8 # via tiktoken +regex==2023.12.25 # via tiktoken requests==2.31.0 # via speechrecognition, tiktoken rospkg==1.5.0 # via -r requirements.in six==1.16.0 # via python-dateutil +sounddevice==0.4.6 # via -r requirements.in speechrecognition==3.10.0 # via -r requirements.in sympy==1.12 # via torch -tiktoken==0.3.1 # via openai-whisper -torch==2.0.1 # via openai-whisper, triton -tqdm==4.66.1 # via openai-whisper -triton==2.0.0 # via openai-whisper, torch -typing-extensions==4.7.1 # via torch 
-urllib3==2.0.4 # via requests -wheel==0.41.1 # via nvidia-cublas-cu11, nvidia-cuda-cupti-cu11, nvidia-cuda-runtime-cu11, nvidia-curand-cu11, nvidia-cusparse-cu11, nvidia-nvtx-cu11 +tiktoken==0.6.0 # via openai-whisper +torch==2.2.1 # via openai-whisper +tqdm==4.66.2 # via openai-whisper +triton==2.2.0 # via openai-whisper, torch +typing-extensions==4.10.0 # via torch +urllib3==2.2.1 # via requests # The following packages are considered to be unsafe in a requirements file: # setuptools From d6dde827f93fd757b125cab4d51af48df2cddf44 Mon Sep 17 00:00:00 2001 From: m-barker Date: Mon, 11 Mar 2024 16:46:18 +0000 Subject: [PATCH 21/23] feat: speech and voice skills --- skills/src/lasr_skills/ask_and_listen.py | 26 ++++++++++++++ skills/src/lasr_skills/listen.py | 2 -- skills/src/lasr_skills/say.py | 45 ++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 skills/src/lasr_skills/ask_and_listen.py create mode 100644 skills/src/lasr_skills/say.py diff --git a/skills/src/lasr_skills/ask_and_listen.py b/skills/src/lasr_skills/ask_and_listen.py new file mode 100644 index 000000000..15f341cd8 --- /dev/null +++ b/skills/src/lasr_skills/ask_and_listen.py @@ -0,0 +1,26 @@ +import smach +from listen import Listen +from skills.src.lasr_skills.say import Say + + +class AskAndListen(smach.StateMachine): + def __init__(self): + smach.StateMachine.__init__( + self, + outcomes=["succeeded", "failed"], + output_keys=["transcribed_speech"], + input_keys=["tts_phrase"], + ) + with self: + smach.StateMachine.add( + "SAY", + Say(), + transitions={"succeeded": "LISTEN", "failed": "failed"}, + remapping={"tts_phrase": "text"}, + ) + smach.StateMachine.add( + "LISTEN", + Listen(), + transitions={"succeeded": "succeeded", "failed": "failed"}, + remapping={"transcribed_speech": "transcribed_speech"}, + ) diff --git a/skills/src/lasr_skills/listen.py b/skills/src/lasr_skills/listen.py index 0e595f59f..272f24b47 100644 --- a/skills/src/lasr_skills/listen.py 
+++ b/skills/src/lasr_skills/listen.py @@ -2,7 +2,6 @@ import smach_ros from lasr_speech_recognition_msgs.msg import ( TranscribeSpeechAction, - TranscribeSpeechGoal, ) @@ -12,6 +11,5 @@ def __init__(self): self, "transcribe_speech", TranscribeSpeechAction, - goal=TranscribeSpeechGoal(), result_slots=["sequence"], ) diff --git a/skills/src/lasr_skills/say.py b/skills/src/lasr_skills/say.py new file mode 100644 index 000000000..5e0bddb8d --- /dev/null +++ b/skills/src/lasr_skills/say.py @@ -0,0 +1,45 @@ +import smach_ros + +from pal_interaction_msgs.msg import TtsGoal, TtsAction, TtsText + +from typing import Union + + +class Say(smach_ros.SimpleActionState): + def __init__( + self, text: Union[str, None] = None, format_str: Union[str, None] = None + ): + if text is not None: + super(Say, self).__init__( + "tts", + TtsAction, + goal=TtsGoal(rawtext=TtsText(text=text, lang_id="en_GB")), + ) + elif format_str is not None: + super(Say, self).__init__( + "tts", + TtsAction, + goal_cb=lambda ud, _: ( + TtsGoal( + rawtext=TtsText( + text=format_str.format(*ud.placeholders), lang_id="en_GB" + ) + ) + if isinstance(ud.placeholders, (list, tuple)) + else TtsGoal( + rawtext=TtsText( + text=format_str.format(ud.placeholders), lang_id="en_GB" + ) + ) + ), + input_keys=["placeholders"], + ) + else: + super(Say, self).__init__( + "tts", + TtsAction, + goal_cb=lambda ud, _: TtsGoal( + rawtext=TtsText(text=ud.text, lang_id="en_GB") + ), + input_keys=["text"], + ) From e41b4ddb5caaf37a58e56ed4ebb92d1b75ae4c5d Mon Sep 17 00:00:00 2001 From: m-barker Date: Mon, 11 Mar 2024 16:46:35 +0000 Subject: [PATCH 22/23] feat: Q/A skill using new speech/voice skills --- tasks/gpsr/nodes/commands/question_answer | 44 +++++++++++----------- tasks/gpsr/src/gpsr/states/get_question.py | 35 ----------------- 2 files changed, 22 insertions(+), 57 deletions(-) delete mode 100644 tasks/gpsr/src/gpsr/states/get_question.py diff --git a/tasks/gpsr/nodes/commands/question_answer 
b/tasks/gpsr/nodes/commands/question_answer index e7541a5b0..fea59e120 100644 --- a/tasks/gpsr/nodes/commands/question_answer +++ b/tasks/gpsr/nodes/commands/question_answer @@ -2,9 +2,9 @@ import rospy import argparse import smach -from lasr_voice import Voice from lasr_skills.xml_question_answer import XmlQuestionAnswer -from gpsr.states.get_question import GetQuestion +from lasr_skills.ask_and_listen import AskAndListen +from lasr_skills.say import Say class QuestionAnswerStateMachine(smach.StateMachine): @@ -14,32 +14,37 @@ class QuestionAnswerStateMachine(smach.StateMachine): outcomes=["succeeded", "failed"], output_keys=["closest_answers"], ) - self.userdata.k = input_data["k"] - self.userdata.index_path = input_data["index_path"] - self.userdata.txt_path = input_data["txt_path"] - self.userdata.xml_path = input_data["xml_path"] - print(self.userdata) - + self.userdata.tts_phrase = "I hear you have a question for me; ask away!" with self: smach.StateMachine.add( "GET_QUESTION", - GetQuestion(), + AskAndListen(), transitions={"succeeded": "XML_QUESTION_ANSWER", "failed": "failed"}, - remapping={"question": "query_sentence"}, + remapping={ + "tts_phrase:": "tts_phrase", + "transcribed_speech": "query_sentence", + }, ) smach.StateMachine.add( "XML_QUESTION_ANSWER", - XmlQuestionAnswer(), - transitions={"succeeded": "succeeded", "failed": "GET_QUESTION"}, + XmlQuestionAnswer( + input_data["index_path"], + input_data["txt_path"], + input_data["xml_path"], + input_data["k"], + ), + transitions={"succeeded": "succeeded", "failed": "failed"}, remapping={ "query_sentence": "query_sentence", - "k": "k", - "index_path": "index_path", - "txt_path": "txt_path", - "xml_path": "xml_path", "closest_answers": "closest_answers", }, ) + smach.StateMachine.add( + "SAY_ANSWER", + Say(format_str="The answer to your question is: {}"), + transitions={"succeeded": "succeeded", "failed": "failed"}, + remapping={"placeholders": "closest_answers"}, + ) def parse_args() -> dict: @@ 
-79,14 +84,9 @@ if __name__ == "__main__": while not rospy.is_shutdown(): q_a_sm = QuestionAnswerStateMachine(args) outcome = q_a_sm.execute() - voice = Voice() if outcome == "succeeded": - rospy.loginfo(f"Closest Answers: {q_a_sm.userdata.closest_answers}") - voice.sync_tts( - f"The answer to your question is: {q_a_sm.userdata.closest_answers[0]}" - ) + rospy.loginfo("Question Answer State Machine succeeded") else: rospy.logerr("Question Answer State Machine failed") - voice.sync_tts(f"Sorry, I wasn't able to find an answer to your question") rospy.spin() diff --git a/tasks/gpsr/src/gpsr/states/get_question.py b/tasks/gpsr/src/gpsr/states/get_question.py deleted file mode 100644 index e7719034c..000000000 --- a/tasks/gpsr/src/gpsr/states/get_question.py +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env python3 -import smach -import rospy -import actionlib -from lasr_voice import Voice -from lasr_speech_recognition_msgs.msg import ( - TranscribeSpeechAction, - TranscribeSpeechGoal, -) - - -class GetQuestion(smach.State): - def __init__(self): - smach.State.__init__( - self, outcomes=["succeeded", "failed"], output_keys=["question"] - ) - self.voice = Voice() - self.client = actionlib.SimpleActionClient( - "transcribe_speech", TranscribeSpeechAction - ) - - def execute(self, userdata): - try: - self.client.wait_for_server() - self.voice.sync_tts("Hello, I hear you have a question for me, ask away!") - goal = TranscribeSpeechGoal() - self.client.send_goal(goal) - self.client.wait_for_result() - result = self.client.get_result() - text = result.sequence - userdata.question = text - return "succeeded" - except Exception as e: - rospy.loginfo(f"Failed to get question: {e}") - return "failed" From d7a9b22afbbe1d6c7722752dfd5e3b4b73a37cc1 Mon Sep 17 00:00:00 2001 From: m-barker Date: Mon, 11 Mar 2024 17:28:27 +0000 Subject: [PATCH 23/23] fix: working new question answer state machine --- skills/src/lasr_skills/__init__.py | 4 +++- skills/src/lasr_skills/ask_and_listen.py | 
20 +++++++++++++------ skills/src/lasr_skills/xml_question_answer.py | 2 +- tasks/gpsr/CMakeLists.txt | 2 +- .../launch/commands/question_answer.launch | 2 +- tasks/gpsr/nodes/commands/question_answer | 12 +++++++---- 6 files changed, 28 insertions(+), 14 deletions(-) diff --git a/skills/src/lasr_skills/__init__.py b/skills/src/lasr_skills/__init__.py index b278ee554..1b89feb05 100755 --- a/skills/src/lasr_skills/__init__.py +++ b/skills/src/lasr_skills/__init__.py @@ -10,4 +10,6 @@ from .look_to_point import LookToPoint from .look_to_point import LookToPoint from .go_to_location import GoToLocation -from .go_to_semantic_location import GoToSemanticLocation \ No newline at end of file +from .go_to_semantic_location import GoToSemanticLocation +from .listen import Listen +from .say import Say diff --git a/skills/src/lasr_skills/ask_and_listen.py b/skills/src/lasr_skills/ask_and_listen.py index 15f341cd8..6f110c40a 100644 --- a/skills/src/lasr_skills/ask_and_listen.py +++ b/skills/src/lasr_skills/ask_and_listen.py @@ -1,6 +1,6 @@ import smach -from listen import Listen -from skills.src.lasr_skills.say import Say +from lasr_skills import Listen +from lasr_skills import Say class AskAndListen(smach.StateMachine): @@ -15,12 +15,20 @@ def __init__(self): smach.StateMachine.add( "SAY", Say(), - transitions={"succeeded": "LISTEN", "failed": "failed"}, - remapping={"tts_phrase": "text"}, + transitions={ + "succeeded": "LISTEN", + "aborted": "failed", + "preempted": "failed", + }, + remapping={"text": "tts_phrase"}, ) smach.StateMachine.add( "LISTEN", Listen(), - transitions={"succeeded": "succeeded", "failed": "failed"}, - remapping={"transcribed_speech": "transcribed_speech"}, + transitions={ + "succeeded": "succeeded", + "aborted": "failed", + "preempted": "failed", + }, + remapping={"sequence": "transcribed_speech"}, ) diff --git a/skills/src/lasr_skills/xml_question_answer.py b/skills/src/lasr_skills/xml_question_answer.py index df8e7564f..abede91b5 100644 --- 
a/skills/src/lasr_skills/xml_question_answer.py +++ b/skills/src/lasr_skills/xml_question_answer.py @@ -51,7 +51,7 @@ def __init__(self, index_path: str, txt_path: str, xml_path: str, k: int = 1): def execute(self, userdata): rospy.wait_for_service("/lasr_faiss/txt_query") - q_a_dict: dict = parse_question_xml(userdata.xml_path) + q_a_dict: dict = parse_question_xml(self.xml_path) try: request = TxtQueryRequest( self.txt_path, diff --git a/tasks/gpsr/CMakeLists.txt b/tasks/gpsr/CMakeLists.txt index 37cbd376b..b8c8a175f 100644 --- a/tasks/gpsr/CMakeLists.txt +++ b/tasks/gpsr/CMakeLists.txt @@ -155,7 +155,7 @@ include_directories( ## in contrast to setup.py, you can choose the destination catkin_install_python(PROGRAMS scripts/parse_gpsr_xmls.py - nodes/question_answer + nodes/commands/question_answer DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} ) diff --git a/tasks/gpsr/launch/commands/question_answer.launch b/tasks/gpsr/launch/commands/question_answer.launch index 79c3a6fb2..2f7855173 100644 --- a/tasks/gpsr/launch/commands/question_answer.launch +++ b/tasks/gpsr/launch/commands/question_answer.launch @@ -24,7 +24,7 @@ type="transcribe_microphone_server" name="transcribe_speech" output="screen" - args="--mic_device 5" + args="--mic_device 9" /> \ No newline at end of file diff --git a/tasks/gpsr/nodes/commands/question_answer b/tasks/gpsr/nodes/commands/question_answer index fea59e120..9c2301c6c 100644 --- a/tasks/gpsr/nodes/commands/question_answer +++ b/tasks/gpsr/nodes/commands/question_answer @@ -22,7 +22,7 @@ class QuestionAnswerStateMachine(smach.StateMachine): transitions={"succeeded": "XML_QUESTION_ANSWER", "failed": "failed"}, remapping={ "tts_phrase:": "tts_phrase", - "transcribed_speech": "query_sentence", + "transcribed_speech": "transcribed_speech", }, ) smach.StateMachine.add( @@ -33,16 +33,20 @@ class QuestionAnswerStateMachine(smach.StateMachine): input_data["xml_path"], input_data["k"], ), - transitions={"succeeded": "succeeded", "failed": 
"failed"}, + transitions={"succeeded": "SAY_ANSWER", "failed": "failed"}, remapping={ - "query_sentence": "query_sentence", + "query_sentence": "transcribed_speech", "closest_answers": "closest_answers", }, ) smach.StateMachine.add( "SAY_ANSWER", Say(format_str="The answer to your question is: {}"), - transitions={"succeeded": "succeeded", "failed": "failed"}, + transitions={ + "succeeded": "succeeded", + "aborted": "failed", + "preempted": "failed", + }, remapping={"placeholders": "closest_answers"}, )