Skip to content

Commit

Permalink
Refactor RM and Reranker (#28)
Browse files Browse the repository at this point in the history
Refactors RM and Reranker classes. Also changes every parameter of `k`
to `K` to be more consistent. No tests for `ColBERTv2RM` yet since its
not so easy to just exchange it for other models.
  • Loading branch information
sidjha1 authored Nov 6, 2024
1 parent 243070d commit 0fa7d01
Show file tree
Hide file tree
Showing 23 changed files with 391 additions and 313 deletions.
152 changes: 101 additions & 51 deletions .github/tests/rm_tests.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,55 @@
import os

import pandas as pd
import pytest

import lotus
from lotus.models import CrossEncoderModel, E5Model
from lotus.models import CrossEncoderReranker, LiteLLMRM, SentenceTransformersRM

################################################################################
# Setup
################################################################################
# Set logger level to DEBUG
lotus.logger.setLevel("DEBUG")

# Environment flags to enable/disable tests
ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true"
ENABLE_LOCAL_TESTS = os.getenv("ENABLE_LOCAL_TESTS", "false").lower() == "true"

# TODO: Add colbertv2 tests
MODEL_NAME_TO_ENABLED = {
"intfloat/e5-small-v2": ENABLE_LOCAL_TESTS,
"mixedbread-ai/mxbai-rerank-xsmall-v1": ENABLE_LOCAL_TESTS,
"text-embedding-3-small": ENABLE_OPENAI_TESTS,
}
ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled])

MODEL_NAME_TO_CLS = {
"intfloat/e5-small-v2": SentenceTransformersRM,
"mixedbread-ai/mxbai-rerank-xsmall-v1": CrossEncoderReranker,
"text-embedding-3-small": LiteLLMRM,
}


def get_enabled(*candidate_models: str) -> list[str]:
return [model for model in candidate_models if model in ENABLED_MODEL_NAMES]

@pytest.fixture

@pytest.fixture(scope="session")
def setup_models():
# Set up embedder and reranker model
rm = E5Model(model="intfloat/e5-small-v2")
reranker = CrossEncoderModel(model="mixedbread-ai/mxbai-rerank-xsmall-v1")
return rm, reranker
models = {}

for model_name in ENABLED_MODEL_NAMES:
models[model_name] = MODEL_NAME_TO_CLS[model_name](model=model_name)
return models


def test_cluster_by(setup_models):
rm, _ = setup_models
################################################################################
# RM Only Tests
################################################################################
@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small"))
def test_cluster_by(setup_models, model):
rm = setup_models[model]
lotus.settings.configure(rm=rm)

data = {
Expand All @@ -44,8 +76,9 @@ def test_cluster_by(setup_models):
assert probability_group == {"Probability and Random Processes", "Optimization Methods in Engineering"}, groups


def test_search_rm_only(setup_models):
rm, _ = setup_models
@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small"))
def test_search_rm_only(setup_models, model):
rm = setup_models[model]
lotus.settings.configure(rm=rm)

data = {
Expand All @@ -62,43 +95,35 @@ def test_search_rm_only(setup_models):
assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"]


def test_search_reranker_only(setup_models):
_, reranker = setup_models
lotus.settings.configure(reranker=reranker)
@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small"))
def test_sim_join(setup_models, model):
rm = setup_models[model]
lotus.settings.configure(rm=rm)

data = {
data1 = {
"Course Name": [
"Probability and Random Processes",
"Cooking",
"Food Sciences",
"Optimization Methods in Engineering",
"History of the Atlantic World",
"Riemannian Geometry",
]
}
df = pd.DataFrame(data)
df = df.sem_search("Course Name", "Optimization", n_rerank=2)
assert df["Course Name"].tolist() == ["Optimization Methods in Engineering", "Probability and Random Processes"]

data2 = {"Skill": ["Math", "History"]}

def test_search(setup_models):
rm, reranker = setup_models
lotus.settings.configure(rm=rm, reranker=reranker)

data = {
"Course Name": [
"Probability and Random Processes",
"Cooking",
"Food Sciences",
"Optimization Methods in Engineering",
]
}
df = pd.DataFrame(data)
df = df.sem_index("Course Name", "index_dir")
df = df.sem_search("Course Name", "Optimization", K=2, n_rerank=1)
assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"]
df1 = pd.DataFrame(data1)
df2 = pd.DataFrame(data2).sem_index("Skill", "index_dir")
joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1)
joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"]))
expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")}
assert joined_pairs == expected_pairs, joined_pairs


# TODO: threshold is hardcoded for intfloat/e5-small-v2
@pytest.mark.skipif(
"intfloat/e5-small-v2" not in ENABLED_MODEL_NAMES,
reason="Skipping test because intfloat/e5-small-v2 is not enabled",
)
def test_dedup(setup_models):
rm, _ = setup_models
rm = setup_models["intfloat/e5-small-v2"]
lotus.settings.configure(rm=rm)
data = {
"Text": [
Expand All @@ -117,22 +142,47 @@ def test_dedup(setup_models):
assert "Probability" in kept[1], kept


def test_sim_join(setup_models):
rm, _ = setup_models
lotus.settings.configure(rm=rm)
################################################################################
# Reranker Only Tests
################################################################################
@pytest.mark.parametrize("model", get_enabled("mixedbread-ai/mxbai-rerank-xsmall-v1"))
def test_search_reranker_only(setup_models, model):
reranker = setup_models[model]
lotus.settings.configure(reranker=reranker)

data1 = {
data = {
"Course Name": [
"History of the Atlantic World",
"Riemannian Geometry",
"Probability and Random Processes",
"Cooking",
"Food Sciences",
"Optimization Methods in Engineering",
]
}
df = pd.DataFrame(data)
df = df.sem_search("Course Name", "Optimization", n_rerank=2)
assert df["Course Name"].tolist() == ["Optimization Methods in Engineering", "Probability and Random Processes"]

data2 = {"Skill": ["Math", "History"]}

df1 = pd.DataFrame(data1)
df2 = pd.DataFrame(data2).sem_index("Skill", "index_dir")
joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1)
joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"]))
expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")}
assert joined_pairs == expected_pairs, joined_pairs
################################################################################
# Combined Tests
################################################################################
# TODO: Figure out how to parameterize pairs of models
@pytest.mark.skipif(not ENABLE_LOCAL_TESTS, reason="Skipping test because local tests are not enabled")
def test_search(setup_models):
models = setup_models
rm = models["intfloat/e5-small-v2"]
reranker = models["mixedbread-ai/mxbai-rerank-xsmall-v1"]
lotus.settings.configure(rm=rm, reranker=reranker)

data = {
"Course Name": [
"Probability and Random Processes",
"Cooking",
"Food Sciences",
"Optimization Methods in Engineering",
]
}
df = pd.DataFrame(data)
df = df.sem_index("Course Name", "index_dir")
df = df.sem_search("Course Name", "Optimization", K=2, n_rerank=1)
assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"]
9 changes: 8 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,12 @@ jobs:
pip install -e .
pip install pytest
- name: Set OpenAI API Key
run: echo "OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> $GITHUB_ENV

- name: Run RM tests
run: pytest .github/tests/rm_tests.py
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ENABLE_OPENAI_TESTS: true
ENABLE_LOCAL_TESTS: true
run: pytest .github/tests/rm_tests.py
4 changes: 2 additions & 2 deletions docs/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ This can be achieved by applying a semantic filter followed by a semantic aggreg
import pandas as pd
import lotus
from lotus.models import E5Model, LM
from lotus.models import SentenceTransformersRM, LM
# Configure models for LOTUS
lm = LM()
rm = E5Model()
rm = SentenceTransformersRM()
lotus.settings.configure(lm=lm, rm=rm)
Expand Down
4 changes: 2 additions & 2 deletions examples/op_examples/agg.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pandas as pd

import lotus
from lotus.models import LM, E5Model
from lotus.models import LM, SentenceTransformersRM

lm = LM()
rm = E5Model()
rm = SentenceTransformersRM()

lotus.settings.configure(lm=lm, rm=rm)
data = {
Expand Down
4 changes: 2 additions & 2 deletions examples/op_examples/cluster.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pandas as pd

import lotus
from lotus.models import LM, E5Model
from lotus.models import LM, SentenceTransformersRM

lm = LM()
rm = E5Model()
rm = SentenceTransformersRM()

lotus.settings.configure(lm=lm, rm=rm)
data = {
Expand Down
4 changes: 2 additions & 2 deletions examples/op_examples/dedup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pandas as pd

import lotus
from lotus.models import E5Model
from lotus.models import SentenceTransformersRM

rm = E5Model()
rm = SentenceTransformersRM()

lotus.settings.configure(rm=rm)
data = {
Expand Down
4 changes: 2 additions & 2 deletions examples/op_examples/partition.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pandas as pd

import lotus
from lotus.models import LM, E5Model
from lotus.models import LM, SentenceTransformersRM

lm = LM(max_tokens=2048)
rm = E5Model()
rm = SentenceTransformersRM()

lotus.settings.configure(lm=lm, rm=rm)
data = {
Expand Down
6 changes: 3 additions & 3 deletions examples/op_examples/search.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pandas as pd

import lotus
from lotus.models import LM, CrossEncoderModel, E5Model
from lotus.models import LM, CrossEncoderReranker, SentenceTransformersRM

lm = LM()
rm = E5Model()
reranker = CrossEncoderModel()
rm = SentenceTransformersRM()
reranker = CrossEncoderReranker()

lotus.settings.configure(lm=lm, rm=rm, reranker=reranker)
data = {
Expand Down
5 changes: 3 additions & 2 deletions examples/op_examples/sim_join.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pandas as pd

import lotus
from lotus.models import LM, E5Model
from lotus.models import LM, LiteLLMRM

lm = LM()
rm = E5Model()
# rm = SentenceTransformersRM()
rm = LiteLLMRM()

lotus.settings.configure(lm=lm, rm=rm)
data = {
Expand Down
14 changes: 8 additions & 6 deletions lotus/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from lotus.models.colbertv2_model import ColBERTv2Model
from lotus.models.cross_encoder_model import CrossEncoderModel
from lotus.models.e5_model import E5Model
from lotus.models.cross_encoder_reranker import CrossEncoderReranker
from lotus.models.lm import LM
from lotus.models.reranker import Reranker
from lotus.models.rm import RM
from lotus.models.litellm_rm import LiteLLMRM
from lotus.models.sentence_transformers_rm import SentenceTransformersRM
from lotus.models.colbertv2_rm import ColBERTv2RM

__all__ = [
"E5Model",
"ColBERTv2Model",
"CrossEncoderModel",
"CrossEncoderReranker",
"LM",
"RM",
"Reranker",
"LiteLLMRM",
"SentenceTransformersRM",
"ColBERTv2RM",
]
Loading

0 comments on commit 0fa7d01

Please sign in to comment.