
Commit

made changes to tests
AmoghTantradi committed Jan 22, 2025
1 parent 820f3be commit a0a70d2
Showing 10 changed files with 270 additions and 69 deletions.
153 changes: 152 additions & 1 deletion .github/tests/rm_tests.py
@@ -5,6 +5,7 @@

import lotus
from lotus.models import CrossEncoderReranker, LiteLLMRM, SentenceTransformersRM
from lotus.vector_store import ChromaVS, PineconeVS, QdrantVS, WeaviateVS

################################################################################
# Setup
@@ -30,6 +31,13 @@
"text-embedding-3-small": LiteLLMRM,
}

VECTOR_STORE_TO_CLS = {
    "weaviate": WeaviateVS,
    "pinecone": PineconeVS,
    "chroma": ChromaVS,
    "qdrant": QdrantVS,
}


def get_enabled(*candidate_models: str) -> list[str]:
return [model for model in candidate_models if model in ENABLED_MODEL_NAMES]
@@ -48,7 +56,13 @@ def setup_models():

@pytest.fixture(scope='session')
def setup_vs():
pass
vs_and_embed_model = {}

for vs in VECTOR_STORE_TO_CLS:
for model_name in ENABLED_MODEL_NAMES:
vs_and_embed_model[(vs, model_name)] = VECTOR_STORE_TO_CLS[vs](embedding_model=model_name)

return vs_and_embed_model
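
For reference, a minimal sketch of how a test consumes this fixture (keys are (vector store name, embedding model name) tuples), mirroring the parametrized VS tests further down:

# Illustrative only; this simply mirrors the pattern used by the VS tests below.
def test_example(setup_vs):
    my_vs = setup_vs[("chroma", "intfloat/e5-small-v2")]
    lotus.settings.configure(vs=my_vs)
    # ... run semantic operators against the configured vector store ...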

################################################################################
# RM Only Tests
@@ -131,6 +145,143 @@ def test_sim_join(setup_models, model):
def test_dedup(setup_models):
rm = setup_models["intfloat/e5-small-v2"]
lotus.settings.configure(rm=rm)
data = {
"Text": [
"Probability and Random Processes",
"Probability and Markov Chains",
"Harry Potter",3
+
































+
"Harry James Potter",
]
}
df = pd.DataFrame(data)
df = df.sem_index("Text", "index_dir").sem_dedup("Text", threshold=0.85)
kept = df["Text"].tolist()
kept.sort()
assert len(kept) == 2, kept
assert "Harry" in kept[0], kept
assert "Probability" in kept[1], kept



################################################################################
# VS Only Tests
################################################################################


@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys())
@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small"))
def test_vs_cluster_by(setup_vs, vs, model):
my_vs = setup_vs[(vs, model)]
lotus.settings.configure(vs=my_vs)

data = {
"Course Name": [
"Probability and Random Processes",
"Cooking",
"Food Sciences",
"Optimization Methods in Engineering",
]
}
df = pd.DataFrame(data)
df = df.sem_index("Course Name", "index_dir")
df = df.sem_cluster_by("Course Name", 2)
groups = df.groupby("cluster_id")["Course Name"].apply(set).to_dict()
assert len(groups) == 2, groups
if "Cooking" in groups[0]:
cooking_group = groups[0]
probability_group = groups[1]
else:
cooking_group = groups[1]
probability_group = groups[0]

assert cooking_group == {"Cooking", "Food Sciences"}, groups
assert probability_group == {"Probability and Random Processes", "Optimization Methods in Engineering"}, groups

@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys())
@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small"))
def test_vs_search_rm_only(setup_vs, vs, model):
my_vs = setup_vs[(vs, model)]
lotus.settings.configure(vs=my_vs)

data = {
"Course Name": [
"Probability and Random Processes",
"Cooking",
"Food Sciences",
"Optimization Methods in Engineering",
]
}
df = pd.DataFrame(data)
df = df.sem_index("Course Name", "index_dir")
df = df.sem_search("Course Name", "Optimization", K=1)
assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"]

@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys())
@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small"))
def test_vs_sim_join(setup_vs, vs, model):
my_vs = setup_vs[(vs, model)]
lotus.settings.configure(vs=my_vs)

data1 = {
"Course Name": [
"History of the Atlantic World",
"Riemannian Geometry",
]
}

data2 = {"Skill": ["Math", "History"]}

df1 = pd.DataFrame(data1)
df2 = pd.DataFrame(data2).sem_index("Skill", "index_dir")
joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1)
joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"]))
expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")}
assert joined_pairs == expected_pairs, joined_pairs


# TODO: threshold is hardcoded for intfloat/e5-small-v2
@pytest.mark.skipif(
"intfloat/e5-small-v2" not in ENABLED_MODEL_NAMES,
reason="Skipping test because intfloat/e5-small-v2 is not enabled",
)
@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys())
def test_vs_dedup(setup_vs, vs):
my_vs = setup_vs[(vs, "intfloat/e5-small-v2")]
lotus.settings.configure(vs=my_vs)
data = {
"Text": [
"Probability and Random Processes",
2 changes: 1 addition & 1 deletion lotus/sem_ops/sem_index.py
@@ -35,7 +35,7 @@ def __call__(self, col_name: str, index_dir: str) -> pd.DataFrame:
"The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()"
)

rm = lotus.settings.rm
rm = lotus.settings.get_rm_or_vs()
rm.index(self._obj[col_name], index_dir)
self._obj.attrs["index_dirs"][col_name] = index_dir
return self._obj
4 changes: 2 additions & 2 deletions lotus/sem_ops/sem_search.py
@@ -47,10 +47,10 @@ def __call__(
assert not (K is None and n_rerank is None), "K or n_rerank must be provided"
if K is not None:
# get retriever model and index
rm = lotus.settings.rm
rm = lotus.settings.get_rm_or_vs()
if rm is None:
raise ValueError(
"The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()"
"The retrieval model must be an instance of RM or VS. Please configure a valid retrieval model pr vector store using lotus.settings.configure()"
)

col_index_dir = self._obj.attrs["index_dirs"][col_name]
7 changes: 4 additions & 3 deletions lotus/sem_ops/sem_sim_join.py
@@ -6,6 +6,7 @@
from lotus.cache import operator_cache
from lotus.models import RM
from lotus.types import RMOutput
from lotus.vector_store import VS


@pd.api.extensions.register_dataframe_accessor("sem_sim_join")
@@ -51,10 +52,10 @@ def __call__(
raise ValueError("Other Series must have a name")
other = pd.DataFrame({other.name: other})

rm = lotus.settings.rm
if not isinstance(rm, RM):
rm = lotus.settings.get_rm_or_vs()
if not isinstance(rm, RM) and not isinstance(rm, VS):
raise ValueError(
"The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()"
"The retrieval model must be an instance of RM or VS. Please configure a valid retrieval model or vector store using lotus.settings.configure()"
)

# load query embeddings from index if they exist
7 changes: 7 additions & 0 deletions lotus/settings.py
@@ -13,6 +13,7 @@ class Settings:
reranker: lotus.models.Reranker | None = None
vs: lotus.vector_store.VS | None = None


# Cache settings
enable_cache: bool = False

@@ -26,10 +27,16 @@ def configure(self, **kwargs):
for key, value in kwargs.items():
if not hasattr(self, key):
raise ValueError(f"Invalid setting: {key}")
if (key == "vs" and self.rm is not None) or (key == "rm" and self.vs is not None):
    raise ValueError("Invalid settings: you can only set a retrieval model or a vector store, not both")

setattr(self, key, value)

def __str__(self):
return str(vars(self))

def get_rm_or_vs(self):
return self.rm or self.vs


settings = Settings()
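
A minimal usage sketch of the new settings surface (it assumes the ChromaVS constructor introduced in the next file; the snippet itself is not part of this commit):

import lotus
from lotus.vector_store import ChromaVS

# Configure a vector store instead of a retrieval model; the new guard in
# configure() rejects setting both rm and vs.
lotus.settings.configure(vs=ChromaVS(embedding_model="intfloat/e5-small-v2"))

# Semantic operators resolve the retriever through the fallback helper:
retriever = lotus.settings.get_rm_or_vs()  # settings.rm if set, otherwise settings.vs
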
33 changes: 19 additions & 14 deletions lotus/vector_store/chroma_vs.py
@@ -10,7 +10,7 @@
from lotus.vector_store.vs import VS

try:
from chromadb import ClientAPI
from chromadb import Client, ClientAPI
from chromadb.api import Collection
from chromadb.api.types import IncludeEnum
except ImportError as err:
@@ -19,22 +19,27 @@
) from err

class ChromaVS(VS):
def __init__(self, client: ClientAPI, embedding_model: str, max_batch_size: int = 64):
def __init__(self, embedding_model: str, max_batch_size: int = 64):
    """Initialize with ChromaDB client and embedding model"""
    client: ClientAPI = Client()
    super().__init__(embedding_model)
self.client = client
self.collection: Collection | None = None
self.collection_name = None
self.index_dir = None
self.max_batch_size = max_batch_size

def __del__(self):
return

def index(self, docs: pd.Series, collection_name: str):
def index(self, docs: pd.Series, index_dir: str):
"""Create a collection and add documents with their embeddings"""
self.collection_name = collection_name
self.index_dir = index_dir

# Create collection without embedding function (we'll provide embeddings directly)
self.collection = self.client.create_collection(
name=collection_name,
name=index_dir,
metadata={"hnsw:space": "cosine"} # Use cosine similarity for consistency
)

@@ -59,13 +64,13 @@ def index(self, docs: pd.Series, collection_name: str):
metadatas=metadatas[i:end_idx]
)

def load_index(self, collection_name: str):
def load_index(self, index_dir: str):
"""Load an existing collection"""
try:
self.collection = self.client.get_collection(collection_name)
self.collection_name = collection_name
self.collection = self.client.get_collection(index_dir)
self.index_dir = index_dir
except ValueError as e:
raise ValueError(f"Collection {collection_name} not found") from e
raise ValueError(f"Collection {index_dir} not found") from e

def __call__(
self,
@@ -126,14 +131,14 @@ def __call__(
indices=np.array(all_indices, dtype=np.int64).tolist()
)

def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]:
"""Retrieve vectors for specific document IDs"""
if self.collection is None or self.collection_name != collection_name:
self.load_index(collection_name)
if self.collection is None or self.index_dir != index_dir:
self.load_index(index_dir)


if self.collection is None: # Add this check after load_index
raise ValueError(f"Failed to load collection {collection_name}")
raise ValueError(f"Failed to load collection {index_dir}")


# Convert integer ids to strings for ChromaDB
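
A small usage sketch of the updated ChromaVS interface (assuming a pandas Series of documents, as in the tests above): the Chroma client is now created internally, and collections are keyed by index_dir rather than a caller-supplied collection name.

import pandas as pd

from lotus.vector_store import ChromaVS

docs = pd.Series(["Probability and Random Processes", "Cooking"])
vs = ChromaVS(embedding_model="intfloat/e5-small-v2")  # no ClientAPI argument anymore
vs.index(docs, "index_dir")    # creates a Chroma collection named "index_dir"
vs.load_index("index_dir")     # reopens that collection later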
