Gpsr_command_similarity_parser #161

Merged: 15 commits, Apr 23, 2024
@@ -6,8 +6,11 @@ from lasr_vector_databases_faiss import (
load_model,
parse_txt_file,
get_sentence_embeddings,
create_vector_database,
construct_faiss_index,
add_vectors_to_index,
save_index_to_disk,
)
from typing import List


class TxtIndexService:
@@ -18,14 +21,67 @@ class TxtIndexService:
rospy.loginfo("Text index service started")

def execute_cb(self, req: TxtIndexRequest):
txt_fp: str = req.txt_path
sentences_to_embed: list[str] = parse_txt_file(txt_fp)
sentence_embeddings: np.ndarray = get_sentence_embeddings(
sentences_to_embed, self._sentence_embedding_model
)
index_path: str = req.index_path
create_vector_database(sentence_embeddings, index_path)
return TxtIndexResponse()
txt_fps: List[str] = req.txt_paths
index_paths: List[str] = req.index_paths
factory_string: str = req.index_factory_string
vecs_per_txt_file: List[int] = []
n_train_vecs = 5000000
if len(index_paths) == 1 and len(txt_fps) > 1:
# Memory-mapped buffer for every embedding across all txt files;
# the shape is hardcoded for the GPSR command dataset
# (11779430 sentences, 384-dim embeddings).
xn = np.memmap(
"/tmp/xn.dat",
dtype="float32",
mode="w+",
shape=(11779430, 384),
)
for i, txt_fp in enumerate(txt_fps):
sentences_to_embed: List[str] = parse_txt_file(txt_fp)
sentence_embeddings: np.ndarray = get_sentence_embeddings(
sentences_to_embed, self._sentence_embedding_model
)
if i == 0:
index = construct_faiss_index(
index_factory_string=factory_string,
vector_dim=sentence_embeddings.shape[1],
)
# Allocate the training buffer once; each txt file contributes
# its first 100000 embeddings (assumes at most 50 files and at
# least 100000 sentences per file).
xt = np.empty(
(n_train_vecs, sentence_embeddings.shape[1]), dtype=np.float32
)
sentences_for_training = sentence_embeddings[:100000]
xt[i * 100000 : (i + 1) * 100000] = sentences_for_training
xn[
i
* sentence_embeddings.shape[0] : (i + 1)
* sentence_embeddings.shape[0],
] = sentence_embeddings
vecs_per_txt_file.append(sentence_embeddings.shape[0])
rospy.loginfo("Training index")
index.train(xt)
rospy.loginfo("Adding vectors to index")
add_vectors_to_index(index, xn)
rospy.loginfo("Saving index to disk")
save_index_to_disk(index, index_paths[0])

elif len(index_paths) != len(txt_fps):
rospy.logerr(
"Number of txt files and index paths must be the same, or only one "
"index path must be provided. "
f"Got {len(txt_fps)} txt files and {len(index_paths)} index paths."
)
else:
for txt_fp, index_path in zip(txt_fps, index_paths):
sentences_to_embed: List[str] = parse_txt_file(txt_fp)
sentence_embeddings: np.ndarray = get_sentence_embeddings(
sentences_to_embed, self._sentence_embedding_model
)
index = construct_faiss_index(
index_factory_string=factory_string,
vector_dim=sentence_embeddings.shape[1],
)
add_vectors_to_index(index, sentence_embeddings)
save_index_to_disk(index, index_path)
vecs_per_txt_file.append(sentence_embeddings.shape[0])

return TxtIndexResponse(vecs_per_txt_file=vecs_per_txt_file)


if __name__ == "__main__":
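The single-index branch above follows FAISS's usual two-step flow for trainable indices: train the index on a subsample, then add the full vector set. A minimal standalone sketch of that pattern in raw FAISS calls, with made-up sizes and an illustrative "IVF1024,Flat" factory string (not necessarily what the service is run with):

    import faiss
    import numpy as np

    d = 384  # embedding dimensionality
    index = faiss.index_factory(d, "IVF1024,Flat", faiss.METRIC_L2)

    xt = np.random.rand(50_000, d).astype("float32")   # training subsample
    xn = np.random.rand(100_000, d).astype("float32")  # full vector set

    index.train(xt)  # learn the IVF coarse quantiser from the subsample
    index.add(xn)    # then add every vector to the trained index
    faiss.write_index(index, "/tmp/example.index")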
@@ -15,6 +15,9 @@ from lasr_vector_databases_faiss import (
query_database,
)

from typing import List
from math import inf


class TxtQueryService:
def __init__(self):
@@ -24,19 +27,67 @@ class TxtQueryService:
rospy.loginfo("Text Query service started")

def execute_cb(self, req: TxtQueryRequest) -> TxtQueryResponse:
txt_fp: str = req.txt_path
index_path: str = req.index_path
txt_fps: List[str] = req.txt_paths
index_paths: List[str] = req.index_paths
query_sentence: str = req.query_sentence
possible_matches: list[str] = parse_txt_file(txt_fp)
query_embedding: np.ndarray = get_sentence_embeddings(
[query_sentence], self._sentence_embedding_model # requires list of strings
)
distances, indices = query_database(index_path, query_embedding, k=req.k)
nearest_matches = [possible_matches[i] for i in indices[0]]
vecs_per_txt_file: List[int] = req.vecs_per_txt_file

if len(index_paths) == 1 and len(txt_fps) > 1:
distances, indices = query_database(
index_paths[0],
get_sentence_embeddings(
[query_sentence],
self._sentence_embedding_model, # requires list of strings
),
k=req.k,
)
closest_sentences: List[str] = []
for index in indices[0]:
# Walk the per-file vector counts to map the global row index
# back to (txt file j, line index within that file).
for j, n_vecs in enumerate(vecs_per_txt_file):
if index < n_vecs:
break
index -= n_vecs
closest_sentences.append(parse_txt_file(txt_fps[j])[index])

return TxtQueryResponse(
closest_sentences=closest_sentences,
cosine_similarities=distances[0].tolist(),
)

elif len(index_paths) != len(txt_fps):
rospy.logerr(
"Number of txt files and index paths must be equal, or exactly "
"one index path must be provided."
)
return TxtQueryResponse()

else:
best_distances: List[float] = [inf] * req.k
best_matches: List[str] = [""] * req.k
for txt_fp, index_path in zip(txt_fps, index_paths):
possible_matches: List[str] = parse_txt_file(txt_fp)
query_embedding: np.ndarray = get_sentence_embeddings(
[query_sentence],
self._sentence_embedding_model, # requires list of strings
)
distances, indices = query_database(
index_path, query_embedding, k=req.k
)
current_nearest_matches = [possible_matches[i] for i in indices[0]]

for i, match in enumerate(current_nearest_matches):
if distances[0][i] < best_distances[-1]:
best_distances[-1] = distances[0][i]
best_matches[-1] = match
# Re-sort immediately so the worst entry of the running top-k
# is always in the last slot for the next comparison.
best_distances, best_matches = map(
list, zip(*sorted(zip(best_distances, best_matches)))
)

return TxtQueryResponse(
closest_sentences=nearest_matches,
cosine_similarities=distances[0].tolist(),
closest_sentences=best_matches,
cosine_similarities=best_distances,
)


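The inner loop in the query service that converts a row of the combined index back into a (txt file, line) pair can also be written with cumulative counts; a compact illustrative equivalent (not part of the PR):

    from bisect import bisect_right
    from itertools import accumulate

    def locate(global_idx, vecs_per_txt_file):
        # Cumulative vector counts, e.g. [100, 250, 400] for files of
        # 100, 150 and 150 vectors.
        cum = list(accumulate(vecs_per_txt_file))
        file_idx = bisect_right(cum, global_idx)  # which txt file
        offset = cum[file_idx - 1] if file_idx else 0
        return file_idx, global_idx - offset      # (file, line within file)

    assert locate(0, [100, 150, 150]) == (0, 0)
    assert locate(120, [100, 150, 150]) == (1, 20)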
@@ -4,11 +4,12 @@

request = TxtIndexRequest()

request.txt_path = (
request.txt_paths = [
"/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.txt"
)
]
request.index_factory_string = "Flat"

request.index_path = (
request.index_paths = [
"/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.index"
)
]
rospy.ServiceProxy("lasr_faiss/txt_index", TxtIndex)(request)
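
With the new array fields, a client can also build one combined index from several txt files by passing a single entry in index_paths; a hedged sketch, with hypothetical chunk paths:

    request = TxtIndexRequest()
    request.txt_paths = [
        "/path/to/commands_chunk_1.txt",  # hypothetical paths
        "/path/to/commands_chunk_2.txt",
    ]
    request.index_factory_string = "Flat"
    request.index_paths = ["/path/to/commands.index"]  # one index for all files
    response = rospy.ServiceProxy("lasr_faiss/txt_index", TxtIndex)(request)
    # response.vecs_per_txt_file records how many vectors came from each file.

Note that the single-index branch of the service memmaps a buffer whose shape is hardcoded for the GPSR dataset, so as written this path only fits that dataset.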
@@ -4,13 +4,13 @@

request = TxtQueryRequest()

request.txt_path = (
request.txt_paths = [
"/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.txt"
)
]

request.index_path = (
request.index_paths = [
"/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.index"
)
]

request.query_sentence = "Do French like snails?"

@@ -1,2 +1,8 @@
from .database_utils import create_vector_database, load_vector_database, query_database
from .database_utils import (
load_vector_database,
query_database,
save_index_to_disk,
add_vectors_to_index,
construct_faiss_index,
)
from .get_sentence_embeddings import get_sentence_embeddings, load_model, parse_txt_file
@@ -3,40 +3,80 @@
import numpy as np
import faiss

from typing import Union

def create_vector_database(

def construct_faiss_index(
index_factory_string: str,
vector_dim: int,
normalise: bool = False,
use_gpu: bool = False,
) -> faiss.Index:
"""Constructs the faiss vector datbase object.

Args:
index_factory_string (str): FAISS index factory string, e.g. "Flat".
vector_dim (int): dimensionality of each vector to be added to the index.
normalise (bool, optional): whether to use inner product instead of Euclidean distance.
Defaults to False.
use_gpu (bool, optional): whether to move the index to the GPU. Defaults to False.

Returns:
faiss.Index: constructed faiss index object.
"""

metric = faiss.METRIC_INNER_PRODUCT if normalise else faiss.METRIC_L2
index = faiss.index_factory(vector_dim, index_factory_string, metric)
if use_gpu:
index = faiss.index_cpu_to_all_gpus(index)
return index


def add_vectors_to_index(
index: faiss.Index,
vectors: np.ndarray,
index_path: str,
overwrite: bool = False,
index_type: str = "Flat",
normalise_vecs: bool = True,
normalise: bool = False,
add_with_ids: bool = False,
) -> Union[None, np.ndarray]:
"""Adds a set of vectors to the index, optionally normalising vectors
or adding them with Ids.

Args:
index (faiss.Index): index to add the vectors to.
vectors (np.ndarray): vectors to add to the index of shape (n_vecs, vec_dim)
normalise (bool, optional): whether to normalise the vectors. Defaults to False.
add_with_ids (bool, optional): whether to add the vectors with ids. Defaults to False.

Returns:
Union[None, np.ndarray]: None or the ids of the vectors added.
"""

if normalise:
faiss.normalize_L2(vectors)
if add_with_ids:
ids = np.arange(index.ntotal, index.ntotal + vectors.shape[0])
index.add_with_ids(vectors, ids)
return ids
else:
index.add(vectors)
return None
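
For the add_with_ids path, note that plain flat indices do not support explicit IDs; an IDMap wrapper is needed. An illustrative call, assuming the helpers above are in scope:

    index = construct_faiss_index(index_factory_string="IDMap,Flat", vector_dim=384)
    vecs = np.random.rand(10, 384).astype("float32")
    ids = add_vectors_to_index(index, vecs, add_with_ids=True)
    # ids == array([0, 1, ..., 9]); a later call continues from index.ntotal.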


def save_index_to_disk(
index: faiss.Index, index_path: str, overwrite: bool = False
) -> None:
"""Creates a FAISS Index using the factory constructor and the given
index type, and adds the given vector to the index, and then saves
it to disk using the given path.
"""Saves the index to disk.

Args:
vectors (np.ndarray): vector of shape (n_vectors, vector_dim)
index (faiss.Index): index to save
index_path (str): path to save the index
overwrite (bool, optional): Whether to replace an existing index
at the same filepath if it exists. Defaults to False.
index_type (str, optional): FAISS Index Factory string. Defaults to "IndexFlatIP".
normalise_vecs (bool, optional): Whether to normalise the vectors before
adding them to the Index. This converts the IP metric to Cosine Similarity.
Defaults to True.
overwrite (bool, optional): whether to overwrite the index if it already exists.
Defaults to False.
"""

if os.path.exists(index_path) and not overwrite:
raise FileExistsError(
f"Index already exists at {index_path}. Set overwrite=True to replace it."
)

index = faiss.index_factory(
vectors.shape[1], index_type, faiss.METRIC_INNER_PRODUCT
)
if normalise_vecs:
faiss.normalize_L2(vectors)
index.add(vectors)
faiss.write_index(index, index_path)
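
Taken together, the three new helpers compose like this; a minimal sketch using the defaults shown above (normalise=False) and made-up vectors:

    import numpy as np
    from lasr_vector_databases_faiss import (
        construct_faiss_index,
        add_vectors_to_index,
        save_index_to_disk,
    )

    vecs = np.random.rand(1000, 384).astype("float32")
    index = construct_faiss_index(index_factory_string="Flat", vector_dim=384)
    add_vectors_to_index(index, vecs)
    save_index_to_disk(index, "/tmp/demo.index", overwrite=True)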


@@ -62,7 +102,7 @@ def load_vector_database(index_path: str, use_gpu: bool = False) -> faiss.Index:
def query_database(
index_path: str,
query_vectors: np.ndarray,
normalise: bool = True,
normalise: bool = False,
k: int = 1,
) -> tuple[np.ndarray, np.ndarray]:
"""Queries the given index with the given query vectors
@@ -0,0 +1,27 @@
import argparse
from typing import Dict


def split_txt_file(input_file, output_file, num_splits):
with open(input_file, "r", encoding="utf8") as src:
lines = src.readlines()
split_size = len(lines) // num_splits
for i in range(num_splits):
# Give the final chunk any remainder lines so none are dropped.
end = (i + 1) * split_size if i < num_splits - 1 else len(lines)
with open(f"{output_file}_chunk_{i+1}.txt", "w", encoding="utf8") as dest:
dest.writelines(lines[i * split_size : end])


def parse_args() -> Dict:
parser = argparse.ArgumentParser(description="Split a txt file into chunks")
parser.add_argument("input_file", type=str, help="Path to the input txt file")
parser.add_argument("output_file", type=str, help="Path to the output txt file")
parser.add_argument(
"num_splits", type=int, help="Number of chunks to split the file into"
)
known, _ = parser.parse_known_args()
return vars(known)


if __name__ == "__main__":
args = parse_args()
split_txt_file(args["input_file"], args["output_file"], args["num_splits"])
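
Illustrative usage, assuming the script is saved as split_txt_file.py (file names hypothetical): splitting a 10,000-line file into four 2,500-line chunks, from the command line or directly:

    # python split_txt_file.py all_commands.txt all_commands 4
    split_txt_file("all_commands.txt", "all_commands", num_splits=4)
    # -> all_commands_chunk_1.txt ... all_commands_chunk_4.txt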
@@ -1,7 +1,14 @@
# Path to input text file
string txt_path
# Paths to the input text files
string[] txt_paths

# Output path to save index
string index_path
# Output paths to save the created indices.
# If multiple text files are provided but only one
# index path is given, that single index will contain
# the vectors from all of the txt files.
string[] index_paths

# Specifies the type of index to create
# see https://github.com/facebookresearch/faiss/wiki/The-index-factory
string index_factory_string
---
int32[] vecs_per_txt_file
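
The returned vecs_per_txt_file counts are what the query service expects back in TxtQueryRequest when several txt files share one index. A hedged end-to-end sketch; the txt_query service name and TxtQuery type are assumptions made by analogy with txt_index:

    # Continuing from an indexing request like the one above.
    index_resp = rospy.ServiceProxy("lasr_faiss/txt_index", TxtIndex)(index_req)

    query_req = TxtQueryRequest()
    query_req.txt_paths = index_req.txt_paths
    query_req.index_paths = index_req.index_paths  # the single combined index
    query_req.vecs_per_txt_file = index_resp.vecs_per_txt_file
    query_req.query_sentence = "Do French like snails?"
    query_req.k = 3
    # Service name and TxtQuery type assumed, not confirmed by this diff.
    query_resp = rospy.ServiceProxy("lasr_faiss/txt_query", TxtQuery)(query_req)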