Gpsr_command_similarity_parser (#161)
* fix: incorrect remapping keys for handover

* fix: remove speech server in launch file

* feat: handle multiple txt files and indices in faiss vector db

* feat: util to split a large text file into a set number of chunks

* feat: command similarity matcher state

* feat: working similarity state, but need to speed up

* feat: handle different index types and multiple data sources

* feat: querying with smart lookup

* feat: working blazingly fast command similarity matcher

* Update tasks/gpsr/states/command_similarity_matcher.py

Co-authored-by: Jared Swift <[email protected]>

* Update tasks/gpsr/states/command_similarity_matcher.py

Co-authored-by: Jared Swift <[email protected]>

* Update tasks/gpsr/states/command_similarity_matcher.py

Co-authored-by: Jared Swift <[email protected]>

---------

Co-authored-by: Jared Swift <[email protected]>
m-barker and jws-1 authored Apr 23, 2024
1 parent 6b03a06 commit a6bf97c
Showing 11 changed files with 364 additions and 60 deletions.
@@ -6,8 +6,11 @@ from lasr_vector_databases_faiss import (
load_model,
parse_txt_file,
get_sentence_embeddings,
create_vector_database,
construct_faiss_index,
add_vectors_to_index,
save_index_to_disk,
)
from typing import List


class TxtIndexService:
@@ -18,14 +21,67 @@ class TxtIndexService:
rospy.loginfo("Text index service started")

def execute_cb(self, req: TxtIndexRequest):
txt_fp: str = req.txt_path
sentences_to_embed: list[str] = parse_txt_file(txt_fp)
sentence_embeddings: np.ndarray = get_sentence_embeddings(
sentences_to_embed, self._sentence_embedding_model
)
index_path: str = req.index_path
create_vector_database(sentence_embeddings, index_path)
return TxtIndexResponse()
txt_fps: List[str] = req.txt_paths
index_paths: List[str] = req.index_paths
factory_string: str = req.index_factory_string
vecs_per_txt_file: List[int] = []
n_train_vecs = 5000000
if len(index_paths) == 1 and len(txt_fps) > 1:
xn = np.memmap(
f"/tmp/xn.dat",
dtype="float32",
mode="w+",
shape=(11779430, 384),
)
for i, txt_fp in enumerate(txt_fps):
sentences_to_embed: List[str] = parse_txt_file(txt_fp)
sentence_embeddings: np.ndarray = get_sentence_embeddings(
sentences_to_embed, self._sentence_embedding_model
)
if i == 0:
index = construct_faiss_index(
index_factory_string=factory_string,
vector_dim=sentence_embeddings.shape[1],
)
xt = np.empty(
(n_train_vecs, sentence_embeddings.shape[1]), dtype=np.float32
)
sentences_for_training = sentence_embeddings[:100000]
xt[i * 100000 : (i + 1) * 100000] = sentences_for_training
xn[
i
* sentence_embeddings.shape[0] : (i + 1)
* sentence_embeddings.shape[0],
] = sentence_embeddings
vecs_per_txt_file.append(sentence_embeddings.shape[0])
rospy.loginfo("Training index")
index.train(xt)
rospy.loginfo("Adding vectors to index")
add_vectors_to_index(index, xn)
rospy.loginfo("Saving index to disk")
save_index_to_disk(index, index_paths[0])

elif len(index_paths) != len(txt_fps):
rospy.logerr(
"Number of txt files and index paths must be the same, or only one index "
"path must be provided."
f"Got {len(txt_fps)} txt files and {len(index_paths)} index paths."
)
else:
for txt_fp, index_path in zip(txt_fps, index_paths):
sentences_to_embed: list[str] = parse_txt_file(txt_fp)
sentence_embeddings: np.ndarray = get_sentence_embeddings(
sentences_to_embed, self._sentence_embedding_model
)
index = construct_faiss_index(
index_factory_string=factory_string,
vector_dim=sentence_embeddings.shape[1],
)
add_vectors_to_index(index, sentence_embeddings)
save_index_to_disk(index, index_path)
vecs_per_txt_file.append(sentence_embeddings.shape[0])

return TxtIndexResponse(vecs_per_txt_file=vecs_per_txt_file)


if __name__ == "__main__":
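A note on the single-index branch above: the /tmp/xn.dat memmap is sized up front for the full corpus (11,779,430 embeddings of dimension 384), and the slice arithmetic i * sentence_embeddings.shape[0] assumes every txt file yields the same number of sentences. A minimal sketch of the same packing step with an explicit running offset, which drops that assumption (illustrative only, not the committed code; names and the temp path are placeholders):

import numpy as np
from typing import List, Tuple

def pack_embeddings_into_memmap(
    embeddings_per_file: List[np.ndarray],  # one (n_i, dim) float32 array per txt file
    memmap_path: str = "/tmp/xn_sketch.dat",
) -> Tuple[np.memmap, List[int]]:
    """Write each file's embeddings into one flat memmap and record per-file counts."""
    total_rows = sum(e.shape[0] for e in embeddings_per_file)
    dim = embeddings_per_file[0].shape[1]
    xn = np.memmap(memmap_path, dtype="float32", mode="w+", shape=(total_rows, dim))
    vecs_per_txt_file: List[int] = []
    offset = 0
    for emb in embeddings_per_file:
        xn[offset : offset + emb.shape[0]] = emb
        offset += emb.shape[0]
        vecs_per_txt_file.append(emb.shape[0])
    return xn, vecs_per_txt_file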
@@ -15,6 +15,9 @@ from lasr_vector_databases_faiss import (
query_database,
)

from typing import List
from math import inf


class TxtQueryService:
def __init__(self):
@@ -24,19 +27,67 @@ class TxtQueryService:
rospy.loginfo("Text Query service started")

def execute_cb(self, req: TxtQueryRequest) -> TxtQueryResponse:
txt_fp: str = req.txt_path
index_path: str = req.index_path
txt_fps: List[str] = req.txt_paths
index_paths: List[str] = req.index_paths
query_sentence: str = req.query_sentence
possible_matches: list[str] = parse_txt_file(txt_fp)
query_embedding: np.ndarray = get_sentence_embeddings(
[query_sentence], self._sentence_embedding_model # requires list of strings
)
distances, indices = query_database(index_path, query_embedding, k=req.k)
nearest_matches = [possible_matches[i] for i in indices[0]]
vecs_per_txt_file: List[int] = req.vecs_per_txt_file

if len(index_paths) == 1 and len(txt_fps) > 1:
distances, indices = query_database(
index_paths[0],
get_sentence_embeddings(
[query_sentence],
self._sentence_embedding_model, # requires list of strings
),
k=req.k,
)
closest_sentences: List[str] = []
for i, index in enumerate(indices[0]):
for j, n_vecs in enumerate(vecs_per_txt_file):
if index < n_vecs:
break
index -= n_vecs
closest_sentences.append(parse_txt_file(txt_fps[j])[index])

return TxtQueryResponse(
closest_sentences=closest_sentences,
cosine_similarities=distances[0],
)

elif len(index_paths) != len(txt_fps):
rospy.logerr(
"Number of txt files and index files must be equal or index files must be 1"
)
return TxtQueryResponse()

else:
best_distances: list[float] = [inf] * req.k
best_matches: list[str] = [""] * req.k
for txt_fp, index_path in zip(txt_fps, index_paths):
possible_matches: list[str] = parse_txt_file(txt_fp)
query_embedding: np.ndarray = get_sentence_embeddings(
[query_sentence],
self._sentence_embedding_model, # requires list of strings
)
distances, indices = query_database(
index_path, query_embedding, k=req.k
)
current_nearest_matches = [possible_matches[i] for i in indices[0]]

for i, match in enumerate(current_nearest_matches):
if distances[0][i] < best_distances[-1]:
best_distances[-1] = distances[0][i]
best_matches[-1] = match
best_distances, best_matches = zip(
*sorted(zip(best_distances, best_matches))
)
best_distances = list(best_distances)
best_matches = list(best_matches)
best_distances.sort()

return TxtQueryResponse(
closest_sentences=nearest_matches,
cosine_similarities=distances[0].tolist(),
closest_sentences=best_matches,
cosine_similarities=best_distances,
)


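The lookup in the single-index branch above walks vecs_per_txt_file to translate a row of the combined FAISS index back into the source txt file and the line within it. A standalone sketch of that decode step (illustrative; it assumes vecs_per_txt_file is ordered the same way the files were indexed):

from typing import List, Tuple

def locate_match(flat_index: int, vecs_per_txt_file: List[int]) -> Tuple[int, int]:
    """Map a row of the combined index to (file_number, line_within_file)."""
    for file_number, n_vecs in enumerate(vecs_per_txt_file):
        if flat_index < n_vecs:
            return file_number, flat_index
        flat_index -= n_vecs
    raise IndexError("flat_index exceeds the total number of indexed vectors")

# Example: with vecs_per_txt_file = [100, 50], row 120 maps to file 1, line 20.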
@@ -4,11 +4,12 @@

request = TxtIndexRequest()

request.txt_path = (
request.txt_paths = [
"/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.txt"
)
]
request.index_factory_string = "Flat"

request.index_path = (
request.index_paths = [
"/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.index"
)
]
rospy.ServiceProxy("lasr_faiss/txt_index", TxtIndex)(request)
@@ -4,13 +4,13 @@

request = TxtQueryRequest()

request.txt_path = (
request.txt_paths = [
"/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.txt"
)
]

request.index_path = (
request.index_paths = [
"/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.index"
)
]

request.query_sentence = "Do French like snails?"

@@ -1,2 +1,8 @@
from .database_utils import create_vector_database, load_vector_database, query_database
from .database_utils import (
load_vector_database,
query_database,
save_index_to_disk,
add_vectors_to_index,
construct_faiss_index,
)
from .get_sentence_embeddings import get_sentence_embeddings, load_model, parse_txt_file
@@ -3,40 +3,80 @@
import numpy as np
import faiss

from typing import Union

def create_vector_database(

def construct_faiss_index(
index_factory_string: str,
vector_dim: int,
normalise: bool = False,
use_gpu: bool = False,
) -> faiss.Index:
"""Constructs the faiss vector datbase object.
Args:
index_factory_string (str): Index factory string
vector_dim (int): constant dim of each vector to be added to the db.
normalise (bool, optional): whether to use inner product instead of Euclidean distance.
Defaults to False.
use_gpu (bool, optional): whether to move the index to the GPU. Defaults to False.
Returns:
faiss.Index: constructed faiss index object.
"""

metric = faiss.METRIC_INNER_PRODUCT if normalise else faiss.METRIC_L2
index = faiss.index_factory(vector_dim, index_factory_string, metric)
if use_gpu:
index = faiss.index_cpu_to_all_gpus(index)
return index


def add_vectors_to_index(
index: faiss.Index,
vectors: np.ndarray,
index_path: str,
overwrite: bool = False,
index_type: str = "Flat",
normalise_vecs: bool = True,
normalise: bool = False,
add_with_ids: bool = False,
) -> Union[None, np.ndarray]:
"""Adds a set of vectors to the index, optionally normalising vectors
or adding them with Ids.
Args:
index (faiss.Index): index to add the vectors to.
vectors (np.ndarray): vectors to add to the index of shape (n_vecs, vec_dim)
normalise (bool, optional): whether to normalise the vectors. Defaults to False.
add_with_ids (bool, optional): whether to add the vectors with ids. Defaults to False.
Returns:
Union[None, np.ndarray]: None or the ids of the vectors added.
"""

if normalise:
faiss.normalize_L2(vectors)
if add_with_ids:
ids = np.arange(index.ntotal, index.ntotal + vectors.shape[0])
index.add_with_ids(vectors, ids)
return ids
else:
index.add(vectors)
return None


def save_index_to_disk(
index: faiss.Index, index_path: str, overwrite: bool = False
) -> None:
"""Creates a FAISS Index using the factory constructor and the given
index type, and adds the given vector to the index, and then saves
it to disk using the given path.
"""Saves the index to disk.
Args:
vectors (np.ndarray): vector of shape (n_vectors, vector_dim)
index (faiss.Index): index to save
index_path (str): path to save the index
overwrite (bool, optional): Whether to replace an existing index
at the same filepath if it exists. Defaults to False.
index_type (str, optional): FAISS Index Factory string. Defaults to "IndexFlatIP".
normalise_vecs (bool, optional): Whether to normalise the vectors before
adding them to the Index. This converts the IP metric to Cosine Similarity.
Defaults to True.
overwrite (bool, optional): whether to overwrite the index if it already exists.
Defaults to False.
"""

if os.path.exists(index_path) and not overwrite:
raise FileExistsError(
f"Index already exists at {index_path}. Set overwrite=True to replace it."
)

index = faiss.index_factory(
vectors.shape[1], index_type, faiss.METRIC_INNER_PRODUCT
)
if normalise_vecs:
faiss.normalize_L2(vectors)
index.add(vectors)
faiss.write_index(index, index_path)


@@ -62,7 +102,7 @@ def load_vector_database(index_path: str, use_gpu: bool = False) -> faiss.Index:
def query_database(
index_path: str,
query_vectors: np.ndarray,
normalise: bool = True,
normalise: bool = False,
k: int = 1,
) -> tuple[np.ndarray, np.ndarray]:
"""Queries the given index with the given query vectors
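The refactor replaces the old create_vector_database with three smaller steps: construct an index from a factory string, add vectors, and save it to disk. A minimal end-to-end sketch using the signatures shown above (the random embeddings and /tmp paths are placeholders, not data from the repository):

import numpy as np
from lasr_vector_databases_faiss import (
    construct_faiss_index,
    add_vectors_to_index,
    save_index_to_disk,
    query_database,
)

embeddings = np.random.rand(1000, 384).astype("float32")  # placeholder vectors

index = construct_faiss_index(index_factory_string="Flat", vector_dim=384)
add_vectors_to_index(index, embeddings)  # a "Flat" index needs no training before adding
save_index_to_disk(index, "/tmp/example.index", overwrite=True)

# query_database loads the saved index and returns (distances, indices) arrays
distances, indices = query_database("/tmp/example.index", embeddings[:1], k=5)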
@@ -0,0 +1,27 @@
import argparse
from typing import Dict


def split_txt_file(input_file, output_file, num_splits):
with open(input_file, "r", encoding="utf8") as src:
lines = src.readlines()
split_size = len(lines) // num_splits
for i in range(num_splits):
with open(f"{output_file}_chunk_{i+1}.txt", "w", encoding="utf8") as dest:
dest.writelines(lines[i * split_size : (i + 1) * split_size])


def parse_args() -> Dict:
parser = argparse.ArgumentParser(description="Split a txt file into chunks")
parser.add_argument("input_file", type=str, help="Path to the input txt file")
parser.add_argument("output_file", type=str, help="Path to the output txt file")
parser.add_argument(
"num_splits", type=int, help="Number of chunks to split the file into"
)
known, _ = parser.parse_known_args()
return vars(known)


if __name__ == "__main__":
args = parse_args()
split_txt_file(args["input_file"], args["output_file"], args["num_splits"])
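A hedged usage sketch for the splitter: run it over a large command file, then hand the resulting chunk paths to the index service (the script filename, paths, and chunk count here are illustrative, not taken from the repository):

# Assumed invocation:
#   python split_txt_file.py all_commands.txt all_commands 8
# which writes all_commands_chunk_1.txt ... all_commands_chunk_8.txt

chunk_paths = [f"all_commands_chunk_{i}.txt" for i in range(1, 9)]
# these paths can then be passed as txt_paths to the lasr_faiss/txt_index service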
@@ -1,7 +1,14 @@
# Path to input text file
string txt_path
# Path to input text files
string[] txt_paths

# Output path to save index
string index_path
# Output path to save created indices
# If multiple text files are provided, but one
# index file path is provided, this index will contain
# all of the vectors from all of the txt files.
string[] index_paths

# Specifies the type of index to create
# see https://github.com/facebookresearch/faiss/wiki/The-index-factory
string index_factory_string
---
int32[] vecs_per_txt_file
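Putting the two services together: the vecs_per_txt_file array returned here is what the query service later uses to map hits in a combined index back to their source files. A round-trip sketch (the srv import path, the query service name, and all file paths are assumptions for illustration, following the example clients above):

import rospy
# assumed import path for the generated service types
from lasr_vector_databases_msgs.srv import (
    TxtIndex,
    TxtIndexRequest,
    TxtQuery,
    TxtQueryRequest,
)

rospy.init_node("faiss_round_trip_example")

index_req = TxtIndexRequest()
index_req.txt_paths = ["/data/commands_a.txt", "/data/commands_b.txt"]
index_req.index_paths = ["/data/commands.index"]  # one index for both files
index_req.index_factory_string = "Flat"
index_resp = rospy.ServiceProxy("lasr_faiss/txt_index", TxtIndex)(index_req)

query_req = TxtQueryRequest()
query_req.txt_paths = index_req.txt_paths
query_req.index_paths = index_req.index_paths
query_req.query_sentence = "bring me a cup from the kitchen"
query_req.k = 3
query_req.vecs_per_txt_file = index_resp.vecs_per_txt_file  # from the index step
query_resp = rospy.ServiceProxy("lasr_faiss/txt_query", TxtQuery)(query_req)  # assumed service name
print(query_resp.closest_sentences, query_resp.cosine_similarities)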