From 0fa7d0189666fa8a3cf699d5e9f9d697274481d0 Mon Sep 17 00:00:00 2001 From: Sid Jha <45739834+sidjha1@users.noreply.github.com> Date: Tue, 5 Nov 2024 17:30:41 -0800 Subject: [PATCH] Refactor RM and Reranker (#28) Refactors RM and Reranker classes. Also changes every parameter of `k` to `K` to be more consistent. No tests for `ColBERTv2RM` yet since it's not so easy to just exchange it for other models. --- .github/tests/rm_tests.py | 152 ++++++++++++------ .github/workflows/tests.yml | 9 +- docs/quickstart.rst | 4 +- examples/op_examples/agg.py | 4 +- examples/op_examples/cluster.py | 4 +- examples/op_examples/dedup.py | 4 +- examples/op_examples/partition.py | 4 +- examples/op_examples/search.py | 6 +- examples/op_examples/sim_join.py | 5 +- lotus/models/__init__.py | 14 +- .../{colbertv2_model.py => colbertv2_rm.py} | 38 ++--- lotus/models/cross_encoder_model.py | 29 ---- lotus/models/cross_encoder_reranker.py | 28 ++++ lotus/models/e5_model.py | 141 ---------------- lotus/models/faiss_rm.py | 62 +++++++ lotus/models/litellm_rm.py | 29 ++++ lotus/models/reranker.py | 8 +- lotus/models/rm.py | 14 +- lotus/models/sentence_transformers_rm.py | 36 +++++ lotus/sem_ops/sem_search.py | 10 +- lotus/sem_ops/sem_sim_join.py | 11 +- lotus/sem_ops/sem_topk.py | 28 ++-- lotus/types.py | 64 +++++--- 23 files changed, 391 insertions(+), 313 deletions(-) rename lotus/models/{colbertv2_model.py => colbertv2_rm.py} (63%) delete mode 100644 lotus/models/cross_encoder_model.py create mode 100644 lotus/models/cross_encoder_reranker.py delete mode 100644 lotus/models/e5_model.py create mode 100644 lotus/models/faiss_rm.py create mode 100644 lotus/models/litellm_rm.py create mode 100644 lotus/models/sentence_transformers_rm.py diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index 3940944a..2c00e116 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -1,23 +1,55 @@ +import os + import pandas as pd import pytest import lotus -from lotus.models 
import CrossEncoderModel, E5Model +from lotus.models import CrossEncoderReranker, LiteLLMRM, SentenceTransformersRM +################################################################################ +# Setup +################################################################################ # Set logger level to DEBUG lotus.logger.setLevel("DEBUG") +# Environment flags to enable/disable tests +ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true" +ENABLE_LOCAL_TESTS = os.getenv("ENABLE_LOCAL_TESTS", "false").lower() == "true" + +# TODO: Add colbertv2 tests +MODEL_NAME_TO_ENABLED = { + "intfloat/e5-small-v2": ENABLE_LOCAL_TESTS, + "mixedbread-ai/mxbai-rerank-xsmall-v1": ENABLE_LOCAL_TESTS, + "text-embedding-3-small": ENABLE_OPENAI_TESTS, +} +ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled]) + +MODEL_NAME_TO_CLS = { + "intfloat/e5-small-v2": SentenceTransformersRM, + "mixedbread-ai/mxbai-rerank-xsmall-v1": CrossEncoderReranker, + "text-embedding-3-small": LiteLLMRM, +} + + +def get_enabled(*candidate_models: str) -> list[str]: + return [model for model in candidate_models if model in ENABLED_MODEL_NAMES] -@pytest.fixture + +@pytest.fixture(scope="session") def setup_models(): - # Set up embedder and reranker model - rm = E5Model(model="intfloat/e5-small-v2") - reranker = CrossEncoderModel(model="mixedbread-ai/mxbai-rerank-xsmall-v1") - return rm, reranker + models = {} + + for model_name in ENABLED_MODEL_NAMES: + models[model_name] = MODEL_NAME_TO_CLS[model_name](model=model_name) + return models -def test_cluster_by(setup_models): - rm, _ = setup_models +################################################################################ +# RM Only Tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_cluster_by(setup_models, model): + rm = 
setup_models[model] lotus.settings.configure(rm=rm) data = { @@ -44,8 +76,9 @@ def test_cluster_by(setup_models): assert probability_group == {"Probability and Random Processes", "Optimization Methods in Engineering"}, groups -def test_search_rm_only(setup_models): - rm, _ = setup_models +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_search_rm_only(setup_models, model): + rm = setup_models[model] lotus.settings.configure(rm=rm) data = { @@ -62,43 +95,35 @@ def test_search_rm_only(setup_models): assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] -def test_search_reranker_only(setup_models): - _, reranker = setup_models - lotus.settings.configure(reranker=reranker) +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_sim_join(setup_models, model): + rm = setup_models[model] + lotus.settings.configure(rm=rm) - data = { + data1 = { "Course Name": [ - "Probability and Random Processes", - "Cooking", - "Food Sciences", - "Optimization Methods in Engineering", + "History of the Atlantic World", + "Riemannian Geometry", ] } - df = pd.DataFrame(data) - df = df.sem_search("Course Name", "Optimization", n_rerank=2) - assert df["Course Name"].tolist() == ["Optimization Methods in Engineering", "Probability and Random Processes"] + data2 = {"Skill": ["Math", "History"]} -def test_search(setup_models): - rm, reranker = setup_models - lotus.settings.configure(rm=rm, reranker=reranker) - - data = { - "Course Name": [ - "Probability and Random Processes", - "Cooking", - "Food Sciences", - "Optimization Methods in Engineering", - ] - } - df = pd.DataFrame(data) - df = df.sem_index("Course Name", "index_dir") - df = df.sem_search("Course Name", "Optimization", K=2, n_rerank=1) - assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] + df1 = pd.DataFrame(data1) + df2 = pd.DataFrame(data2).sem_index("Skill", "index_dir") + 
joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1) + joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"])) + expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")} + assert joined_pairs == expected_pairs, joined_pairs +# TODO: threshold is hardcoded for intfloat/e5-small-v2 +@pytest.mark.skipif( + "intfloat/e5-small-v2" not in ENABLED_MODEL_NAMES, + reason="Skipping test because intfloat/e5-small-v2 is not enabled", +) def test_dedup(setup_models): - rm, _ = setup_models + rm = setup_models["intfloat/e5-small-v2"] lotus.settings.configure(rm=rm) data = { "Text": [ @@ -117,22 +142,47 @@ def test_dedup(setup_models): assert "Probability" in kept[1], kept -def test_sim_join(setup_models): - rm, _ = setup_models - lotus.settings.configure(rm=rm) +################################################################################ +# Reranker Only Tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("mixedbread-ai/mxbai-rerank-xsmall-v1")) +def test_search_reranker_only(setup_models, model): + reranker = setup_models[model] + lotus.settings.configure(reranker=reranker) - data1 = { + data = { "Course Name": [ - "History of the Atlantic World", - "Riemannian Geometry", + "Probability and Random Processes", + "Cooking", + "Food Sciences", + "Optimization Methods in Engineering", ] } + df = pd.DataFrame(data) + df = df.sem_search("Course Name", "Optimization", n_rerank=2) + assert df["Course Name"].tolist() == ["Optimization Methods in Engineering", "Probability and Random Processes"] - data2 = {"Skill": ["Math", "History"]} - df1 = pd.DataFrame(data1) - df2 = pd.DataFrame(data2).sem_index("Skill", "index_dir") - joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1) - joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"])) - expected_pairs = {("History of the Atlantic 
World", "History"), ("Riemannian Geometry", "Math")} - assert joined_pairs == expected_pairs, joined_pairs +################################################################################ +# Combined Tests +################################################################################ +# TODO: Figure out how to parameterize pairs of models +@pytest.mark.skipif(not ENABLE_LOCAL_TESTS, reason="Skipping test because local tests are not enabled") +def test_search(setup_models): + models = setup_models + rm = models["intfloat/e5-small-v2"] + reranker = models["mixedbread-ai/mxbai-rerank-xsmall-v1"] + lotus.settings.configure(rm=rm, reranker=reranker) + + data = { + "Course Name": [ + "Probability and Random Processes", + "Cooking", + "Food Sciences", + "Optimization Methods in Engineering", + ] + } + df = pd.DataFrame(data) + df = df.sem_index("Course Name", "index_dir") + df = df.sem_search("Course Name", "Optimization", K=2, n_rerank=1) + assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2527c886..07a9f3ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -149,5 +149,12 @@ jobs: pip install -e . 
pip install pytest + - name: Set OpenAI API Key + run: echo "OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> $GITHUB_ENV + - name: Run RM tests - run: pytest .github/tests/rm_tests.py \ No newline at end of file + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + ENABLE_OPENAI_TESTS: true + ENABLE_LOCAL_TESTS: true + run: pytest .github/tests/rm_tests.py diff --git a/docs/quickstart.rst b/docs/quickstart.rst index e194177f..2a9f2761 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -50,11 +50,11 @@ This can be achieved by applying a semantic filter followed by a semantic aggreg import pandas as pd import lotus - from lotus.models import E5Model, LM + from lotus.models import SentenceTransformersRM, LM # Configure models for LOTUS lm = LM() - rm = E5Model() + rm = SentenceTransformersRM() lotus.settings.configure(lm=lm, rm=rm) diff --git a/examples/op_examples/agg.py b/examples/op_examples/agg.py index 6f6e14b0..6ad9356f 100644 --- a/examples/op_examples/agg.py +++ b/examples/op_examples/agg.py @@ -1,10 +1,10 @@ import pandas as pd import lotus -from lotus.models import LM, E5Model +from lotus.models import LM, SentenceTransformersRM lm = LM() -rm = E5Model() +rm = SentenceTransformersRM() lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/cluster.py b/examples/op_examples/cluster.py index 2d7af6f1..9c6697ad 100644 --- a/examples/op_examples/cluster.py +++ b/examples/op_examples/cluster.py @@ -1,10 +1,10 @@ import pandas as pd import lotus -from lotus.models import LM, E5Model +from lotus.models import LM, SentenceTransformersRM lm = LM() -rm = E5Model() +rm = SentenceTransformersRM() lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/dedup.py b/examples/op_examples/dedup.py index 5d21087f..8c89aebd 100644 --- a/examples/op_examples/dedup.py +++ b/examples/op_examples/dedup.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import E5Model +from lotus.models import 
SentenceTransformersRM -rm = E5Model() +rm = SentenceTransformersRM() lotus.settings.configure(rm=rm) data = { diff --git a/examples/op_examples/partition.py b/examples/op_examples/partition.py index 91fa185b..c1c7174e 100644 --- a/examples/op_examples/partition.py +++ b/examples/op_examples/partition.py @@ -1,10 +1,10 @@ import pandas as pd import lotus -from lotus.models import LM, E5Model +from lotus.models import LM, SentenceTransformersRM lm = LM(max_tokens=2048) -rm = E5Model() +rm = SentenceTransformersRM() lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/search.py b/examples/op_examples/search.py index 21c7fb5e..60f04190 100644 --- a/examples/op_examples/search.py +++ b/examples/op_examples/search.py @@ -1,11 +1,11 @@ import pandas as pd import lotus -from lotus.models import LM, CrossEncoderModel, E5Model +from lotus.models import LM, CrossEncoderReranker, SentenceTransformersRM lm = LM() -rm = E5Model() -reranker = CrossEncoderModel() +rm = SentenceTransformersRM() +reranker = CrossEncoderReranker() lotus.settings.configure(lm=lm, rm=rm, reranker=reranker) data = { diff --git a/examples/op_examples/sim_join.py b/examples/op_examples/sim_join.py index beaea582..200a7c43 100644 --- a/examples/op_examples/sim_join.py +++ b/examples/op_examples/sim_join.py @@ -1,10 +1,11 @@ import pandas as pd import lotus -from lotus.models import LM, E5Model +from lotus.models import LM, LiteLLMRM lm = LM() -rm = E5Model() +# rm = SentenceTransformersRM() +rm = LiteLLMRM() lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/lotus/models/__init__.py b/lotus/models/__init__.py index 4477c6e2..f88f1dd4 100644 --- a/lotus/models/__init__.py +++ b/lotus/models/__init__.py @@ -1,15 +1,17 @@ -from lotus.models.colbertv2_model import ColBERTv2Model -from lotus.models.cross_encoder_model import CrossEncoderModel -from lotus.models.e5_model import E5Model +from lotus.models.cross_encoder_reranker import CrossEncoderReranker from 
lotus.models.lm import LM from lotus.models.reranker import Reranker from lotus.models.rm import RM +from lotus.models.litellm_rm import LiteLLMRM +from lotus.models.sentence_transformers_rm import SentenceTransformersRM +from lotus.models.colbertv2_rm import ColBERTv2RM __all__ = [ - "E5Model", - "ColBERTv2Model", - "CrossEncoderModel", + "CrossEncoderReranker", "LM", "RM", "Reranker", + "LiteLLMRM", + "SentenceTransformersRM", + "ColBERTv2RM", ] diff --git a/lotus/models/colbertv2_model.py b/lotus/models/colbertv2_rm.py similarity index 63% rename from lotus/models/colbertv2_model.py rename to lotus/models/colbertv2_rm.py index 51bcd7cb..018af594 100644 --- a/lotus/models/colbertv2_model.py +++ b/lotus/models/colbertv2_rm.py @@ -5,32 +5,28 @@ from numpy.typing import NDArray from lotus.models.rm import RM +from lotus.types import RMOutput +try: + from colbert import Indexer, Searcher + from colbert.infra import ColBERTConfig, Run, RunConfig +except ImportError: + pass -class ColBERTv2Model(RM): - """ColBERTv2 Model""" +class ColBERTv2RM(RM): def __init__(self) -> None: self.docs: list[str] | None = None self.kwargs: dict[str, Any] = {"doc_maxlen": 300, "nbits": 2} self.index_dir: str | None = None - from colbert import Indexer, Searcher - from colbert.infra import ColBERTConfig, Run, RunConfig - - self.Indexer = Indexer - self.Searcher = Searcher - self.ColBERTConfig = ColBERTConfig - self.Run = Run - self.RunConfig = RunConfig - def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: kwargs = {**self.kwargs, **kwargs} checkpoint = "colbert-ir/colbertv2.0" - with self.Run().context(self.RunConfig(nranks=1, experiment="lotus")): - config = self.ColBERTConfig(doc_maxlen=kwargs["doc_maxlen"], nbits=kwargs["nbits"], kmeans_niters=4) - indexer = self.Indexer(checkpoint=checkpoint, config=config) + with Run().context(RunConfig(nranks=1, experiment="lotus")): + config = ColBERTConfig(doc_maxlen=kwargs["doc_maxlen"], nbits=kwargs["nbits"], 
kmeans_niters=4) + indexer = Indexer(checkpoint=checkpoint, config=config) indexer.index(name=f"{index_dir}/index", collection=docs, overwrite=True) with open(f"experiments/lotus/indexes/{index_dir}/index/docs", "wb") as fp: @@ -45,25 +41,25 @@ def load_index(self, index_dir: str) -> None: self.docs = pickle.load(fp) def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: - raise NotImplementedError("This method is not implemented for ColBERTv2Model") + raise NotImplementedError("This method is not implemented for ColBERTv2RM") def __call__( self, queries: str | list[str] | NDArray[np.float64], - k: int, + K: int, **kwargs: dict[str, Any], - ) -> tuple[list[list[float]], list[list[int]]]: + ) -> RMOutput: if isinstance(queries, str): queries = [queries] - with self.Run().context(self.RunConfig(experiment="lotus")): - searcher = self.Searcher(index=f"{self.index_dir}/index", collection=self.docs) + with Run().context(RunConfig(experiment="lotus")): + searcher = Searcher(index=f"{self.index_dir}/index", collection=self.docs) # make queries a dict with keys as query ids queries_dict = {i: q for i, q in enumerate(queries)} - all_results = searcher.search_all(queries_dict, k=k).todict() + all_results = searcher.search_all(queries_dict, k=K).todict() indices = [[result[0] for result in all_results[qid]] for qid in all_results.keys()] distances = [[result[2] for result in all_results[qid]] for qid in all_results.keys()] - return distances, indices + return RMOutput(distances=distances, indices=indices) diff --git a/lotus/models/cross_encoder_model.py b/lotus/models/cross_encoder_model.py deleted file mode 100644 index f49aa59a..00000000 --- a/lotus/models/cross_encoder_model.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -from sentence_transformers import CrossEncoder - -from lotus.models.reranker import Reranker - - -class CrossEncoderModel(Reranker): - """CrossEncoder reranker model. 
- - Args: - model (str): The name of the reranker model to use. - device (str): What device to keep the model on. - """ - - def __init__( - self, - model: str = "mixedbread-ai/mxbai-rerank-large-v1", - device: str | None = None, - batch_size: int = 32, - ): - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device: str = device - self.batch_size: int = batch_size - self.model = CrossEncoder(model, device=device) - - def __call__(self, query: str, docs: list[str], k: int) -> list[int]: - results = self.model.rank(query, docs, top_k=k, batch_size=self.batch_size) - return [int(result["corpus_id"]) for result in results] diff --git a/lotus/models/cross_encoder_reranker.py b/lotus/models/cross_encoder_reranker.py new file mode 100644 index 00000000..65827ce2 --- /dev/null +++ b/lotus/models/cross_encoder_reranker.py @@ -0,0 +1,28 @@ +from sentence_transformers import CrossEncoder + +from lotus.models.reranker import Reranker +from lotus.types import RerankerOutput + + +class CrossEncoderReranker(Reranker): + """CrossEncoder reranker model. + + Args: + model (str): The name of the reranker model to use. + device (str): What device to keep the model on. + max_batch_size (int): The maximum batch size to use for the model. 
+ """ + + def __init__( + self, + model: str = "mixedbread-ai/mxbai-rerank-large-v1", + device: str | None = None, + max_batch_size: int = 64, + ): + self.max_batch_size: int = max_batch_size + self.model = CrossEncoder(model, device=device) # type: ignore # CrossEncoder has wrong type stubs + + def __call__(self, query: str, docs: list[str], K: int) -> RerankerOutput: + results = self.model.rank(query, docs, top_k=K, batch_size=self.max_batch_size) + indices = [int(result["corpus_id"]) for result in results] + return RerankerOutput(indices=indices) diff --git a/lotus/models/e5_model.py b/lotus/models/e5_model.py deleted file mode 100644 index d29c7ddf..00000000 --- a/lotus/models/e5_model.py +++ /dev/null @@ -1,141 +0,0 @@ -import os -import pickle -from typing import Any - -import numpy as np -import torch -import torch.nn.functional as F -from numpy.typing import NDArray -from tqdm import tqdm -from transformers import AutoModel, AutoTokenizer - -from lotus.models.rm import RM - - -class E5Model(RM): - """E5 retriever model""" - - def __init__(self, model: str = "intfloat/e5-base-v2", device: str | None = None, **kwargs: dict[str, Any]) -> None: - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device = device - self.tokenizer = AutoTokenizer.from_pretrained(model) - self.model = AutoModel.from_pretrained(model).to(self.device) - self.faiss_index = None - self.index_dir: str | None = None - self.docs: list[str] | None = None - self.kwargs: dict[str, Any] = {"normalize": True, "index_type": "Flat", **kwargs} - self.batch_size: int = 100 - self.vecs: NDArray[np.float64] | None = None - - import faiss - - self.faiss = faiss - - def average_pool(self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - """Perform average pooling over the last hidden state. - - Args: - last_hidden_states: Hidden states from the model's last layer - attention_mask: Attention mask. 
- - Returns: - Average pool over the last hidden state. - """ - - last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) - return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] - - def embed(self, docs: list[str], **kwargs: dict[str, Any]) -> NDArray[np.float64]: - """Run the embedding model. - - Args: - docs: A list of documents to embed. - - Returns: - Embeddings of the documents. - """ - - kwargs = {**self.kwargs, **dict(kwargs)} - - batch_size = kwargs.get("batch_size", self.batch_size) - assert isinstance(batch_size, int), "batch_size must be an integer" - - # Calculating the embedding dimension - total_docs = len(docs) - first_batch = self.tokenizer(docs[:1], return_tensors="pt", padding=True, truncation=True).to(self.device) - embed_dim = self.model(**first_batch).last_hidden_state.size(-1) - - # Pre-allocate a tensor for all embeddings - embeddings = torch.empty((total_docs, embed_dim), device=self.device) - # Processing batches - with torch.inference_mode(): # Slightly faster than torch.no_grad() for inference - for i, batch_start in enumerate(tqdm(range(0, total_docs, batch_size))): - batch = docs[batch_start : batch_start + batch_size] - batch_dict = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(self.device) - outputs = self.model(**batch_dict) - batch_embeddings = self.average_pool(outputs.last_hidden_state, batch_dict["attention_mask"]) - embeddings[batch_start : batch_start + batch_size] = batch_embeddings - if kwargs["normalize"]: - embeddings = F.normalize(embeddings, p=2, dim=1) - - return embeddings.numpy(force=True) - - def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: - # Make index directory - os.makedirs(index_dir, exist_ok=True) - - # Get document embeddings - kwargs = {**self.kwargs, **kwargs} - embeddings = self.embed(docs, **kwargs) - d = embeddings.shape[1] - index = self.faiss.index_factory(d, kwargs["index_type"], 
self.faiss.METRIC_INNER_PRODUCT) - index.add(embeddings) - - # Store index and documents - self.faiss.write_index(index, f"{index_dir}/index") - with open(f"{index_dir}/docs", "wb") as fp: - pickle.dump(docs, fp) - with open(f"{index_dir}/vecs", "wb") as fp: - pickle.dump(embeddings, fp) - self.faiss_index = index - self.docs = docs - self.index_dir = index_dir - self.vecs = embeddings - - def load_index(self, index_dir: str) -> None: - self.index_dir = index_dir - self.faiss_index = self.faiss.read_index(f"{index_dir}/index") - with open(f"{index_dir}/docs", "rb") as fp: - self.docs = pickle.load(fp) - with open(f"{index_dir}/vecs", "rb") as fp: - self.vecs = pickle.load(fp) - - @classmethod - def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.float64]: - with open(f"{index_dir}/vecs", "rb") as fp: - vecs: NDArray[np.float64] = pickle.load(fp) - - return vecs[ids] - - def __call__( - self, - queries: str | list[str] | NDArray[np.float64], - k: int, - **kwargs: dict[str, Any], - ) -> tuple[list[list[float]], list[list[int]]]: - if isinstance(queries, str): - queries = [queries] - - if isinstance(queries[0], str): - str_queries: list[str] = [str(q) for q in queries] - embedded_queries = self.embed(str_queries, **kwargs) - else: - embedded_queries = np.asarray(queries, dtype=np.float32) - - if self.faiss_index is None: - raise ValueError("Index not loaded") - - distances, indicies = self.faiss_index.search(embedded_queries, k) - - return distances, indicies diff --git a/lotus/models/faiss_rm.py b/lotus/models/faiss_rm.py new file mode 100644 index 00000000..205129df --- /dev/null +++ b/lotus/models/faiss_rm.py @@ -0,0 +1,62 @@ +import os +import pickle +from abc import abstractmethod +from typing import Any + +import faiss +import numpy as np +from numpy.typing import NDArray + +from lotus.models.rm import RM +from lotus.types import RMOutput + + +class FaissRM(RM): + def __init__(self, factory_string: str = "Flat", 
metric=faiss.METRIC_INNER_PRODUCT): + super().__init__() + self.factory_string = factory_string + self.metric = metric + self.index_dir: str | None = None + self.faiss_index: faiss.Index | None = None + self.vecs: NDArray[np.float64] | None = None + + def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: + vecs = self._embed(docs) + self.faiss_index = faiss.index_factory(vecs.shape[1], self.factory_string, self.metric) + self.faiss_index.add(vecs) + self.index_dir = index_dir + + os.makedirs(index_dir, exist_ok=True) + with open(f"{index_dir}/vecs", "wb") as fp: + pickle.dump(vecs, fp) + faiss.write_index(self.faiss_index, f"{index_dir}/index") + + def load_index(self, index_dir: str) -> None: + self.index_dir = index_dir + self.faiss_index = faiss.read_index(f"{index_dir}/index") + with open(f"{index_dir}/vecs", "rb") as fp: + self.vecs = pickle.load(fp) + + def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: + with open(f"{index_dir}/vecs", "rb") as fp: + vecs: NDArray[np.float64] = pickle.load(fp) + return vecs[ids] + + def __call__(self, queries: str | list[str] | NDArray[np.float64], K: int, **kwargs: dict[str, Any]) -> RMOutput: + if isinstance(queries, str): + queries = [queries] + + if isinstance(queries[0], str): + embedded_queries = self._embed([str(q) for q in queries]) + else: + embedded_queries = np.asarray(queries, dtype=np.float32) + + if self.faiss_index is None: + raise ValueError("Index not loaded") + + distances, indices = self.faiss_index.search(embedded_queries, K) + return RMOutput(distances=distances, indices=indices) + + @abstractmethod + def _embed(self, docs: list[str]) -> NDArray[np.float64]: + pass diff --git a/lotus/models/litellm_rm.py b/lotus/models/litellm_rm.py new file mode 100644 index 00000000..cadb4cf5 --- /dev/null +++ b/lotus/models/litellm_rm.py @@ -0,0 +1,29 @@ +import faiss +import numpy as np +from litellm import embedding +from litellm.types.utils import 
EmbeddingResponse +from numpy.typing import NDArray + +from lotus.models.faiss_rm import FaissRM + + +class LiteLLMRM(FaissRM): + def __init__( + self, + model: str = "text-embedding-3-small", + max_batch_size: int = 64, + factory_string: str = "Flat", + metric=faiss.METRIC_INNER_PRODUCT, + ): + super().__init__(factory_string, metric) + self.model: str = model + self.max_batch_size: int = max_batch_size + + def _embed(self, docs: list[str]) -> NDArray[np.float64]: + all_embeddings = [] + for i in range(0, len(docs), self.max_batch_size): + batch = docs[i : i + self.max_batch_size] + response: EmbeddingResponse = embedding(model=self.model, input=batch) + embeddings = np.array([d["embedding"] for d in response.data]) + all_embeddings.append(embeddings) + return np.vstack(all_embeddings) diff --git a/lotus/models/reranker.py b/lotus/models/reranker.py index 4e2f54ee..a7fd5996 100644 --- a/lotus/models/reranker.py +++ b/lotus/models/reranker.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod +from lotus.types import RerankerOutput + class Reranker(ABC): """Abstract class for reranker models.""" @@ -8,15 +10,15 @@ def __init__(self) -> None: pass @abstractmethod - def __call__(self, query: str, docs: list[str], k: int) -> list[int]: + def __call__(self, query: str, docs: list[str], K: int) -> RerankerOutput: """Invoke the reranker. Args: query (str): The query to use for reranking. docs (list[str]): A list of documents to rerank. - k (int): The number of documents to keep after reranking. + K (int): The number of documents to keep after reranking. Returns: - list[int]: The indicies of the reranked documents. + RerankerOutput: The indicies of the reranked documents. 
""" pass diff --git a/lotus/models/rm.py b/lotus/models/rm.py index e7cb8ba7..330d7cd5 100644 --- a/lotus/models/rm.py +++ b/lotus/models/rm.py @@ -4,12 +4,14 @@ import numpy as np from numpy.typing import NDArray +from lotus.types import RMOutput + class RM(ABC): """Abstract class for retriever models.""" def __init__(self) -> None: - pass + self.index_dir: str | None = None @abstractmethod def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: @@ -31,7 +33,7 @@ def load_index(self, index_dir: str) -> None: pass @abstractmethod - def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.float64]: + def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: """Get the vectors from the index. Args: @@ -48,17 +50,17 @@ def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.fl def __call__( self, queries: str | list[str] | NDArray[np.float64], - k: int, + K: int, **kwargs: dict[str, Any], - ) -> tuple[list[list[float]], list[list[int]]]: + ) -> RMOutput: """Run top-k search on the index. Args: queries (str | list[str] | NDArray[np.float64]): Either a query or a list of queries or a 2D FP32 array. - k (int): The k to use for top-k search. + K (int): The k to use for top-k search. **kwargs (dict[str, Any]): Additional keyword arguments. Returns: - tuple[list[list[float]], list[list[int]]]: A tuple of (distances, indices) of the top-k vectors + RMOutput: An RMOutput object containing the distances and indices of the top-k vectors. 
""" pass diff --git a/lotus/models/sentence_transformers_rm.py b/lotus/models/sentence_transformers_rm.py new file mode 100644 index 00000000..bbcd36f9 --- /dev/null +++ b/lotus/models/sentence_transformers_rm.py @@ -0,0 +1,36 @@ +import faiss +import numpy as np +import torch +from numpy.typing import NDArray +from sentence_transformers import SentenceTransformer + +from lotus.models.faiss_rm import FaissRM + + +class SentenceTransformersRM(FaissRM): + def __init__( + self, + model: str = "intfloat/e5-base-v2", + max_batch_size: int = 64, + normalize_embeddings: bool = True, + device: str | None = None, + factory_string: str = "Flat", + metric=faiss.METRIC_INNER_PRODUCT, + ): + super().__init__(factory_string, metric) + self.model: str = model + self.max_batch_size: int = max_batch_size + self.normalize_embeddings: bool = normalize_embeddings + self.transformer: SentenceTransformer = SentenceTransformer(model, device=device) + + def _embed(self, docs: list[str]) -> NDArray[np.float64]: + all_embeddings = [] + for i in range(0, len(docs), self.max_batch_size): + batch = docs[i : i + self.max_batch_size] + torch_embeddings = self.transformer.encode( + batch, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings + ) + assert isinstance(torch_embeddings, torch.Tensor) + cpu_embeddings = torch_embeddings.cpu().numpy() + all_embeddings.append(cpu_embeddings) + return np.vstack(all_embeddings) diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py index 49da6a57..d9feb20f 100644 --- a/lotus/sem_ops/sem_search.py +++ b/lotus/sem_ops/sem_search.py @@ -3,6 +3,7 @@ import pandas as pd import lotus +from lotus.types import RerankerOutput, RMOutput @pd.api.extensions.register_dataframe_accessor("sem_search") @@ -55,9 +56,9 @@ def __call__( search_K = K while True: - scores, doc_idxs = rm(query, search_K) - doc_idxs = doc_idxs[0] - scores = scores[0] + rm_output: RMOutput = rm(query, search_K) + doc_idxs = rm_output.indices[0] + scores = 
rm_output.distances[0] assert len(doc_idxs) == len(scores) postfiltered_doc_idxs = [] @@ -83,7 +84,8 @@ def __call__( if n_rerank is not None: docs = new_df[col_name].tolist() - reranked_idxs = lotus.settings.reranker(query, docs, n_rerank) + reranked_output: RerankerOutput = lotus.settings.reranker(query, docs, n_rerank) + reranked_idxs = reranked_output.indices new_df = new_df.iloc[reranked_idxs] return new_df diff --git a/lotus/sem_ops/sem_sim_join.py b/lotus/sem_ops/sem_sim_join.py index d0094f74..04be885f 100644 --- a/lotus/sem_ops/sem_sim_join.py +++ b/lotus/sem_ops/sem_sim_join.py @@ -3,6 +3,8 @@ import pandas as pd import lotus +from lotus.models import RM +from lotus.types import RMOutput @pd.api.extensions.register_dataframe_accessor("sem_sim_join") @@ -46,8 +48,11 @@ def __call__( raise ValueError("Other Series must have a name") other = pd.DataFrame({other.name: other}) - # get rmodel and index rm = lotus.settings.rm + if not isinstance(rm, RM): + raise ValueError( + "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()" + ) # load query embeddings from index if they exist if left_on in self._obj.attrs.get("index_dirs", []): @@ -71,7 +76,9 @@ def __call__( rm.load_index(col_index_dir) assert rm.index_dir == col_index_dir - distances, indices = rm(queries, K) + rm_output: RMOutput = rm(queries, K) + distances = rm_output.distances + indices = rm_output.indices other_index_set = set(other.index) join_results = [] diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index 43190e9a..1db8b514 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -159,7 +159,7 @@ def llm_naive_sort( def llm_quicksort( docs: list[str], user_instruction: str, - k: int, + K: int, embedding: bool = False, strategy: str | None = None, cascade_threshold: float | None = None, @@ -170,7 +170,7 @@ def llm_quicksort( Args: docs (list[str]): The list of documents to sort. 
user_instruction (str): The user instruction for sorting. - k (int): The number of documents to return. + K (int): The number of documents to return. embedding (bool): Whether to use embedding optimization. cascade_threshold (float | None): The confidence threshold for cascading to a larger model. @@ -187,14 +187,14 @@ def llm_quicksort( stats["total_small_calls"] = 0 stats["total_large_calls"] = 0 - def partition(indexes: list[int], low: int, high: int, k: int) -> int: + def partition(indexes: list[int], low: int, high: int, K: int) -> int: nonlocal stats i = low - 1 if embedding: # With embedding optimization - if k <= high - low: - pivot_value = heapq.nsmallest(k, indexes[low : high + 1])[-1] + if K <= high - low: + pivot_value = heapq.nsmallest(K, indexes[low : high + 1])[-1] else: pivot_value = heapq.nsmallest(int((high - low + 1) / 2), indexes[low : high + 1])[-1] pivot_index = indexes.index(pivot_value) @@ -231,21 +231,21 @@ def partition(indexes: list[int], low: int, high: int, k: int) -> int: indexes[i + 1], indexes[high] = indexes[high], indexes[i + 1] return i + 1 - def quicksort_recursive(indexes: list[int], low: int, high: int, k: int) -> None: + def quicksort_recursive(indexes: list[int], low: int, high: int, K: int) -> None: if high <= low: return if low < high: - pi = partition(indexes, low, high, k) + pi = partition(indexes, low, high, K) left_size = pi - low - if left_size + 1 >= k: - quicksort_recursive(indexes, low, pi - 1, k) + if left_size + 1 >= K: + quicksort_recursive(indexes, low, pi - 1, K) else: quicksort_recursive(indexes, low, pi - 1, left_size) - quicksort_recursive(indexes, pi + 1, high, k - left_size - 1) + quicksort_recursive(indexes, pi + 1, high, K - left_size - 1) indexes = list(range(len(docs))) - quicksort_recursive(indexes, 0, len(indexes) - 1, k) + quicksort_recursive(indexes, 0, len(indexes) - 1, K) return SemanticTopKOutput(indexes=indexes, stats=stats) @@ -273,7 +273,7 @@ def __lt__(self, other: "HeapDoc") -> bool: def 
llm_heapsort( docs: list[str], user_instruction: str, - k: int, + K: int, strategy: str | None = None, ) -> SemanticTopKOutput: """ @@ -282,7 +282,7 @@ def llm_heapsort( Args: docs (list[str]): The list of documents to sort. user_instruction (str): The user instruction for sorting. - k (int): The number of documents to return. + K (int): The number of documents to return. Returns: SemanticTopKOutput: The indexes of the top k documents and stats. @@ -292,7 +292,7 @@ def llm_heapsort( HeapDoc.strategy = strategy N = len(docs) heap = [HeapDoc(docs[idx], user_instruction, idx) for idx in range(N)] - heap = heapq.nsmallest(k, heap) + heap = heapq.nsmallest(K, heap) indexes = [heapq.heappop(heap).idx for _ in range(len(heap))] stats = {"total_tokens": HeapDoc.total_tokens, "total_llm_calls": HeapDoc.num_calls} diff --git a/lotus/types.py b/lotus/types.py index 54949024..8ae92af2 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -4,6 +4,9 @@ from pydantic import BaseModel +################################################################################ +# Mixins +################################################################################ class StatsMixin(BaseModel): stats: dict[str, Any] | None = None @@ -13,6 +16,36 @@ class LogprobsMixin(BaseModel): logprobs: list[list[ChatCompletionTokenLogprob]] | None = None +################################################################################ +# LM related +################################################################################ +class LMOutput(LogprobsMixin): + outputs: list[str] + + +class LMStats(BaseModel): + class TotalUsage(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + total_cost: float = 0.0 + api_calls: int = 0 + + total_usage: TotalUsage = TotalUsage() + + +class LogprobsForCascade(BaseModel): + tokens: list[list[str]] + confidences: list[list[float]] + + +class LogprobsForFilterCascade(LogprobsForCascade): + true_probs: list[float] + + 
+################################################################################ +# Semantic operation outputs +################################################################################ class SemanticMapPostprocessOutput(BaseModel): raw_outputs: list[str] outputs: list[str] @@ -58,25 +91,16 @@ class SemanticTopKOutput(StatsMixin): indexes: list[int] -class LMOutput(LogprobsMixin): - outputs: list[str] - - -class LogprobsForCascade(BaseModel): - tokens: list[list[str]] - confidences: list[list[float]] - - -class LogprobsForFilterCascade(LogprobsForCascade): - true_probs: list[float] +################################################################################ +# RM related +################################################################################ +class RMOutput(BaseModel): + distances: list[list[float]] + indices: list[list[int]] -class LMStats(BaseModel): - class TotalUsage(BaseModel): - prompt_tokens: int = 0 - completion_tokens: int = 0 - total_tokens: int = 0 - total_cost: float = 0.0 - api_calls: int = 0 - - total_usage: TotalUsage = TotalUsage() +################################################################################ +# Reranker related +################################################################################ +class RerankerOutput(BaseModel): + indices: list[int]