From 103e42a989c1762595b210e61d5337e6256f0fa0 Mon Sep 17 00:00:00 2001 From: Sid Jha <45739834+sidjha1@users.noreply.github.com> Date: Tue, 5 Nov 2024 18:37:40 -0800 Subject: [PATCH] LiteLLM Integration + Type Fixing + Usage Tracking + Testing (#26) - [X] Remove `OpenAIModel` and just have a `LM` class that uses `LiteLLM` - [X] Add `mypy` and `pre-commit` and get `mypy` to pass (also add `mypy` to CI). - [X] Add usage tracking with `lm.print_total_usage()` - [X] Add `ollama` tests with `llama3.2:3b` - [X] Verify that code runs with `vLLM` - [X] Merge in `RM` and `Reranker` refactoring from #28 Notes: - I had to do a bit of random/weird plumbing to get `mypy` to pass since our types were pretty messed up --- .github/tests/lm_tests.py | 228 ++++++++---- .github/tests/rm_tests.py | 152 +++++--- .github/workflows/tests.yml | 109 +++++- .gitignore | 4 +- .pre-commit-config.yaml | 21 ++ CONTRIBUTING.md | 27 +- README.md | 4 +- docs/conf.py | 2 +- docs/quickstart.rst | 10 +- docs/requirements-docs.txt | 2 +- examples/op_examples/agg.py | 6 +- examples/op_examples/cluster.py | 6 +- examples/op_examples/dedup.py | 4 +- examples/op_examples/filter.py | 4 +- examples/op_examples/filter_cascade.py | 6 +- examples/op_examples/join.py | 4 +- examples/op_examples/map.py | 4 +- examples/op_examples/map_fewshot.py | 4 +- examples/op_examples/partition.py | 6 +- examples/op_examples/search.py | 8 +- examples/op_examples/sim_join.py | 7 +- examples/op_examples/top_k.py | 4 +- examples/provider_examples/oai.py | 20 -- examples/provider_examples/ollama.py | 25 -- examples/provider_examples/vllm.py | 24 -- lotus/__init__.py | 3 +- lotus/models/__init__.py | 16 +- .../{colbertv2_model.py => colbertv2_rm.py} | 51 ++- lotus/models/cross_encoder_model.py | 29 -- lotus/models/cross_encoder_reranker.py | 28 ++ lotus/models/e5_model.py | 152 -------- lotus/models/faiss_rm.py | 62 ++++ lotus/models/litellm_rm.py | 29 ++ lotus/models/lm.py | 218 ++++++++---- lotus/models/openai_model.py | 325 ------------------ lotus/models/reranker.py | 10 +- lotus/models/rm.py | 25 +- lotus/models/sentence_transformers_rm.py | 36 ++ lotus/sem_ops/cascade_utils.py | 73 ++-- lotus/sem_ops/sem_agg.py | 12 +- lotus/sem_ops/sem_extract.py | 12 +- lotus/sem_ops/sem_filter.py | 67 ++-- lotus/sem_ops/sem_join.py | 26 +- lotus/sem_ops/sem_map.py | 11 +- lotus/sem_ops/sem_search.py | 10 +- lotus/sem_ops/sem_sim_join.py | 11 +- lotus/sem_ops/sem_topk.py | 44 ++- lotus/settings.py | 4 +- lotus/types.py | 54 ++- mypy.ini | 6 + pyproject.toml | 4 +- requirements-dev.txt | 5 +- requirements.txt | 2 +- 53 files changed, 1021 insertions(+), 995 deletions(-) create mode 100644 .pre-commit-config.yaml delete mode 100644 examples/provider_examples/oai.py delete mode 100644 examples/provider_examples/ollama.py delete mode 100644 examples/provider_examples/vllm.py rename lotus/models/{colbertv2_model.py => colbertv2_rm.py} (55%) delete mode 100644 lotus/models/cross_encoder_model.py create mode 100644 lotus/models/cross_encoder_reranker.py delete mode 100644 lotus/models/e5_model.py create mode 100644 lotus/models/faiss_rm.py create mode 100644 lotus/models/litellm_rm.py delete mode 100644 lotus/models/openai_model.py create mode 100644 lotus/models/sentence_transformers_rm.py create mode 100644 mypy.ini diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index af23d1c7..ae68c109 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -1,24 +1,61 @@ +import os + import pandas as pd import pytest +from tokenizers import Tokenizer import lotus -from lotus.models import OpenAIModel +from lotus.models import LM +################################################################################ +# Setup +################################################################################ # Set logger level to DEBUG lotus.logger.setLevel("DEBUG") +# Environment flags to enable/disable tests +ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true" +ENABLE_OLLAMA_TESTS = os.getenv("ENABLE_OLLAMA_TESTS", "false").lower() == "true" + +MODEL_NAME_TO_ENABLED = { + "gpt-4o-mini": ENABLE_OPENAI_TESTS, + "gpt-4o": ENABLE_OPENAI_TESTS, + "ollama/llama3.2": ENABLE_OLLAMA_TESTS, +} +ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled]) + -@pytest.fixture +def get_enabled(*candidate_models: str) -> list[str]: + return [model for model in candidate_models if model in ENABLED_MODEL_NAMES] + + +@pytest.fixture(scope="session") def setup_models(): - # Setup GPT models - gpt_4o_mini = OpenAIModel(model="gpt-4o-mini") - gpt_4o = OpenAIModel(model="gpt-4o") - return gpt_4o_mini, gpt_4o + models = {} + + for model_path in ENABLED_MODEL_NAMES: + models[model_path] = LM(model=model_path) + return models -def test_filter_operation(setup_models): - gpt_4o_mini, _ = setup_models - lotus.settings.configure(lm=gpt_4o_mini) + +@pytest.fixture(autouse=True) +def print_usage_after_each_test(setup_models): + yield # this runs the test + models = setup_models + for model_name, model in models.items(): + print(f"\nUsage stats for {model_name} after test:") + model.print_total_usage() + model.reset_stats() + + +################################################################################ +# Standard tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2")) +def test_filter_operation(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) # Test filter operation on an easy dataframe data = {"Text": ["I am really excited to go to class today!", "I am very sad"]} @@ -30,9 +67,86 @@ def test_filter_operation(setup_models): assert filtered_df.equals(expected_df) +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_top_k(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + data = { + "Text": [ + "Lionel Messi is a good soccer player", + "Michael Jordan is a good basketball player", + "Steph Curry is a good basketball player", + "Tom Brady is a good football player", + ] + } + df = pd.DataFrame(data) + user_instruction = "Which {Text} is most related to basketball?" + top_2_expected = set(["Michael Jordan is a good basketball player", "Steph Curry is a good basketball player"]) + + strategies = ["quick", "heap", "naive"] + for strategy in strategies: + sorted_df = df.sem_topk(user_instruction, K=2, strategy=strategy) + + top_2_actual = set(sorted_df["Text"].values) + assert top_2_expected == top_2_actual + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2")) +def test_join(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + data1 = {"School": ["UC Berkeley", "Stanford"]} + data2 = {"School Type": ["Public School", "Private School"]} + + df1 = pd.DataFrame(data1) + df2 = pd.DataFrame(data2) + join_instruction = "{School} is a {School Type}" + joined_df = df1.sem_join(df2, join_instruction) + joined_pairs = set(zip(joined_df["School"], joined_df["School Type"])) + expected_pairs = set([("UC Berkeley", "Public School"), ("Stanford", "Private School")]) + assert joined_pairs == expected_pairs + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2")) +def test_map_fewshot(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + data = {"School": ["UC Berkeley", "Carnegie Mellon"]} + df = pd.DataFrame(data) + examples = {"School": ["Stanford", "MIT"], "Answer": ["CA", "MA"]} + examples_df = pd.DataFrame(examples) + user_instruction = "What state is {School} in? Respond only with the two-letter abbreviation." + df = df.sem_map(user_instruction, examples=examples_df, suffix="State") + + pairs = set(zip(df["School"], df["State"])) + expected_pairs = set([("UC Berkeley", "CA"), ("Carnegie Mellon", "PA")]) + assert pairs == expected_pairs + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_agg_then_map(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + data = {"Text": ["My name is John", "My name is Jane", "My name is John"]} + df = pd.DataFrame(data) + agg_instruction = "What is the most common name in {Text}?" + agg_df = df.sem_agg(agg_instruction, suffix="draft_output") + map_instruction = "{draft_output} is a draft answer to the question 'What is the most common name?'. Clean up the draft answer so that there is just a single name. Your answer MUST be on word" + cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output") + assert cleaned_df["final_output"].values[0].lower().strip(".,!?\"'") == "john" + + +################################################################################ +# Cascade tests +################################################################################ +@pytest.mark.skipif(not ENABLE_OPENAI_TESTS, reason="Skipping test because OpenAI tests are not enabled") def test_filter_cascade(setup_models): - gpt_4o_mini, gpt_4o = setup_models - lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) + models = setup_models + lotus.settings.configure(lm=models["gpt-4o"], helper_lm=models["gpt-4o-mini"]) data = { "Text": [ @@ -57,7 +171,6 @@ def test_filter_cascade(setup_models): "Everything is going as planned, couldn't be happier.", "Feeling super motivated and ready to take on challenges!", "I appreciate all the small things that bring me joy.", - # Negative examples "I am very sad.", "Today has been really tough; I feel exhausted.", @@ -100,46 +213,10 @@ def test_filter_cascade(setup_models): assert stats["filters_resolved_by_helper_model"] > 0, stats -def test_top_k(setup_models): - gpt_4o_mini, _ = setup_models - lotus.settings.configure(lm=gpt_4o_mini) - - data = { - "Text": [ - "Lionel Messi is a good soccer player", - "Michael Jordan is a good basketball player", - "Steph Curry is a good basketball player", - "Tom Brady is a good football player", - ] - } - df = pd.DataFrame(data) - user_instruction = "Which {Text} is most related to basketball?" - sorted_df = df.sem_topk(user_instruction, K=2) - - top_2_expected = set(["Michael Jordan is a good basketball player", "Steph Curry is a good basketball player"]) - top_2_actual = set(sorted_df["Text"].values) - assert top_2_expected == top_2_actual - - -def test_join(setup_models): - gpt_4o_mini, _ = setup_models - lotus.settings.configure(lm=gpt_4o_mini) - - data1 = {"School": ["UC Berkeley", "Stanford"]} - data2 = {"School Type": ["Public School", "Private School"]} - - df1 = pd.DataFrame(data1) - df2 = pd.DataFrame(data2) - join_instruction = "{School} is a {School Type}" - joined_df = df1.sem_join(df2, join_instruction) - joined_pairs = set(zip(joined_df["School"], joined_df["School Type"])) - expected_pairs = set([("UC Berkeley", "Public School"), ("Stanford", "Private School")]) - assert joined_pairs == expected_pairs - - +@pytest.mark.skipif(not ENABLE_OPENAI_TESTS, reason="Skipping test because OpenAI tests are not enabled") def test_join_cascade(setup_models): - gpt_4o_mini, gpt_4o = setup_models - lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) + models = setup_models + lotus.settings.configure(lm=models["gpt-4o"], helper_lm=models["gpt-4o-mini"]) data1 = {"School": ["UC Berkeley", "Stanford"]} data2 = {"School Type": ["Public School", "Private School"]} @@ -164,17 +241,38 @@ def test_join_cascade(setup_models): assert stats["filters_resolved_by_helper_model"] == 0, stats -def test_map_fewshot(setup_models): - gpt_4o_mini, _ = setup_models - lotus.settings.configure(lm=gpt_4o_mini) - - data = {"School": ["UC Berkeley", "Carnegie Mellon"]} - df = pd.DataFrame(data) - examples = {"School": ["Stanford", "MIT"], "Answer": ["CA", "MA"]} - examples_df = pd.DataFrame(examples) - user_instruction = "What state is {School} in? Respond only with the two-letter abbreviation." - df = df.sem_map(user_instruction, examples=examples_df, suffix="State") - - pairs = set(zip(df["School"], df["State"])) - expected_pairs = set([("UC Berkeley", "CA"), ("Carnegie Mellon", "PA")]) - assert pairs == expected_pairs +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_format_logprobs_for_filter_cascade(setup_models, model): + lm = setup_models[model] + messages = [ + [{"role": "user", "content": "True or False: The sky is blue?"}], + ] + response = lm(messages, logprobs=True) + formatted_logprobs = lm.format_logprobs_for_filter_cascade(response.logprobs) + true_probs = formatted_logprobs.true_probs + assert len(true_probs) == 1 + + # Very safe (in practice its ~1) + assert true_probs[0] > 0.8 + assert len(formatted_logprobs.tokens) == len(formatted_logprobs.confidences) + + +################################################################################ +# Token counting tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2")) +def test_count_tokens(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + tokens = lm.count_tokens("Hello, world!") + assert lm.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens + assert tokens < 100 + + +def test_custom_tokenizer(): + custom_tokenizer = Tokenizer.from_pretrained("gpt2") + custom_lm = LM(model="doesn't matter", tokenizer=custom_tokenizer) + tokens = custom_lm.count_tokens("Hello, world!") + assert custom_lm.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens + assert tokens < 100 diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index 3940944a..2c00e116 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -1,23 +1,55 @@ +import os + import pandas as pd import pytest import lotus -from lotus.models import CrossEncoderModel, E5Model +from lotus.models import CrossEncoderReranker, LiteLLMRM, SentenceTransformersRM +################################################################################ +# Setup +################################################################################ # Set logger level to DEBUG lotus.logger.setLevel("DEBUG") +# Environment flags to enable/disable tests +ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true" +ENABLE_LOCAL_TESTS = os.getenv("ENABLE_LOCAL_TESTS", "false").lower() == "true" + +# TODO: Add colbertv2 tests +MODEL_NAME_TO_ENABLED = { + "intfloat/e5-small-v2": ENABLE_LOCAL_TESTS, + "mixedbread-ai/mxbai-rerank-xsmall-v1": ENABLE_LOCAL_TESTS, + "text-embedding-3-small": ENABLE_OPENAI_TESTS, +} +ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled]) + +MODEL_NAME_TO_CLS = { + "intfloat/e5-small-v2": SentenceTransformersRM, + "mixedbread-ai/mxbai-rerank-xsmall-v1": CrossEncoderReranker, + "text-embedding-3-small": LiteLLMRM, +} + + +def get_enabled(*candidate_models: str) -> list[str]: + return [model for model in candidate_models if model in ENABLED_MODEL_NAMES] -@pytest.fixture + +@pytest.fixture(scope="session") def setup_models(): - # Set up embedder and reranker model - rm = E5Model(model="intfloat/e5-small-v2") - reranker = CrossEncoderModel(model="mixedbread-ai/mxbai-rerank-xsmall-v1") - return rm, reranker + models = {} + + for model_name in ENABLED_MODEL_NAMES: + models[model_name] = MODEL_NAME_TO_CLS[model_name](model=model_name) + return models -def test_cluster_by(setup_models): - rm, _ = setup_models +################################################################################ +# RM Only Tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_cluster_by(setup_models, model): + rm = setup_models[model] lotus.settings.configure(rm=rm) data = { @@ -44,8 +76,9 @@ def test_cluster_by(setup_models): assert probability_group == {"Probability and Random Processes", "Optimization Methods in Engineering"}, groups -def test_search_rm_only(setup_models): - rm, _ = setup_models +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_search_rm_only(setup_models, model): + rm = setup_models[model] lotus.settings.configure(rm=rm) data = { @@ -62,43 +95,35 @@ def test_search_rm_only(setup_models): assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] -def test_search_reranker_only(setup_models): - _, reranker = setup_models - lotus.settings.configure(reranker=reranker) +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_sim_join(setup_models, model): + rm = setup_models[model] + lotus.settings.configure(rm=rm) - data = { + data1 = { "Course Name": [ - "Probability and Random Processes", - "Cooking", - "Food Sciences", - "Optimization Methods in Engineering", + "History of the Atlantic World", + "Riemannian Geometry", ] } - df = pd.DataFrame(data) - df = df.sem_search("Course Name", "Optimization", n_rerank=2) - assert df["Course Name"].tolist() == ["Optimization Methods in Engineering", "Probability and Random Processes"] + data2 = {"Skill": ["Math", "History"]} -def test_search(setup_models): - rm, reranker = setup_models - lotus.settings.configure(rm=rm, reranker=reranker) - - data = { - "Course Name": [ - "Probability and Random Processes", - "Cooking", - "Food Sciences", - "Optimization Methods in Engineering", - ] - } - df = pd.DataFrame(data) - df = df.sem_index("Course Name", "index_dir") - df = df.sem_search("Course Name", "Optimization", K=2, n_rerank=1) - assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] + df1 = pd.DataFrame(data1) + df2 = pd.DataFrame(data2).sem_index("Skill", "index_dir") + joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1) + joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"])) + expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")} + assert joined_pairs == expected_pairs, joined_pairs +# TODO: threshold is hardcoded for intfloat/e5-small-v2 +@pytest.mark.skipif( + "intfloat/e5-small-v2" not in ENABLED_MODEL_NAMES, + reason="Skipping test because intfloat/e5-small-v2 is not enabled", +) def test_dedup(setup_models): - rm, _ = setup_models + rm = setup_models["intfloat/e5-small-v2"] lotus.settings.configure(rm=rm) data = { "Text": [ @@ -117,22 +142,47 @@ def test_dedup(setup_models): assert "Probability" in kept[1], kept -def test_sim_join(setup_models): - rm, _ = setup_models - lotus.settings.configure(rm=rm) +################################################################################ +# Reranker Only Tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("mixedbread-ai/mxbai-rerank-xsmall-v1")) +def test_search_reranker_only(setup_models, model): + reranker = setup_models[model] + lotus.settings.configure(reranker=reranker) - data1 = { + data = { "Course Name": [ - "History of the Atlantic World", - "Riemannian Geometry", + "Probability and Random Processes", + "Cooking", + "Food Sciences", + "Optimization Methods in Engineering", ] } + df = pd.DataFrame(data) + df = df.sem_search("Course Name", "Optimization", n_rerank=2) + assert df["Course Name"].tolist() == ["Optimization Methods in Engineering", "Probability and Random Processes"] - data2 = {"Skill": ["Math", "History"]} - df1 = pd.DataFrame(data1) - df2 = pd.DataFrame(data2).sem_index("Skill", "index_dir") - joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1) - joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"])) - expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")} - assert joined_pairs == expected_pairs, joined_pairs +################################################################################ +# Combined Tests +################################################################################ +# TODO: Figure out how to parameterize pairs of models +@pytest.mark.skipif(not ENABLE_LOCAL_TESTS, reason="Skipping test because local tests are not enabled") +def test_search(setup_models): + models = setup_models + rm = models["intfloat/e5-small-v2"] + reranker = models["mixedbread-ai/mxbai-rerank-xsmall-v1"] + lotus.settings.configure(rm=rm, reranker=reranker) + + data = { + "Course Name": [ + "Probability and Random Processes", + "Cooking", + "Food Sciences", + "Optimization Methods in Engineering", + ] + } + df = pd.DataFrame(data) + df = df.sem_index("Course Name", "index_dir") + df = df.sem_search("Course Name", "Optimization", K=2, n_rerank=1) + assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 35b6d58d..07a9f3ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -26,13 +26,36 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff==0.5.2 + pip install ruff==0.7.2 - name: Run ruff run: ruff check . - test: - name: Python Tests + mypy: + name: Type Check + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install mypy==1.13.0 + pip install -r requirements.txt + pip install -e . + + - name: Run mypy + run: mypy lotus/ + + openai_lm_test: + name: OpenAI Language Model Tests runs-on: ubuntu-latest timeout-minutes: 5 @@ -55,9 +78,83 @@ jobs: - name: Set OpenAI API Key run: echo "OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> $GITHUB_ENV - - name: Run Python tests + - name: Run LM tests env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + ENABLE_OPENAI_TESTS: true + run: pytest .github/tests/lm_tests.py + + ollama_lm_test: + name: Ollama Language Model Tests + runs-on: ubuntu-latest + timeout-minutes: 10 + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies run: | - pytest .github/tests/lm_tests.py - pytest .github/tests/rm_tests.py \ No newline at end of file + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -e . + pip install pytest + + - name: Start Ollama container + run: | + docker pull ollama/ollama:latest + docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama + # Wait for Ollama server to be ready + timeout=30 + while ! curl -s http://localhost:11434/ >/dev/null; do + if [ $timeout -le 0 ]; then + echo "Timed out waiting for Ollama server" + exit 1 + fi + echo "Waiting for Ollama server to be ready..." + sleep 1 + timeout=$((timeout - 1)) + done + docker exec $(docker ps -q) ollama run llama3.2 + + - name: Run LM tests + env: + ENABLE_OLLAMA_TESTS: true + run: pytest .github/tests/lm_tests.py + + + rm_test: + name: Retrieval Model Tests + runs-on: ubuntu-latest + timeout-minutes: 5 + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -e . + pip install pytest + + - name: Set OpenAI API Key + run: echo "OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> $GITHUB_ENV + + - name: Run RM tests + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + ENABLE_OPENAI_TESTS: true + ENABLE_LOCAL_TESTS: true + run: pytest .github/tests/rm_tests.py diff --git a/.gitignore b/.gitignore index f2118286..1a45a8d6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,6 @@ __pycache__/ *.log dist/ docs/_build -.ruff_cache \ No newline at end of file +.ruff_cache +.mypy_cache +.pytest_cache \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..0964e87f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,21 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.7.2 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.13.0 + hooks: + - id: mypy + args: ["--config-file", "mypy.ini"] + additional_dependencies: + - types-setuptools + - litellm>=1.51.0 + - numpy>=1.25.0 + - pandas>=2.0.0 + - sentence-transformers>=3.0.1 + - tiktoken>=0.7.0 + - tqdm>=4.66.4 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 02cac94f..0a08df16 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,27 +9,32 @@ conda activate lotus git clone git@github.com:stanford-futuredata/lotus.git pip install -e . pip install -r requirements-dev.txt +pre-commit install ``` ## Dev Flow After making your changes, please make a PR to get your changes merged upstream. -## Running vLLM Models -To use vLLM for model serving, you just need to make an OpenAI compatible vLLM server. Then, the `OpenAIModel` class can be used to point to the server. See an example below. +## Running Models +To run a model, you can use the `LM` class in `lotus.models.LM`. We use the `litellm` library to interface with the model. +This allows you to use any model provider that is supported by `litellm`. -Create the server +Here's an example of creating an `LM` object for `gpt-4o` ``` -python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-70B-Instruct --port 8000 --tensor-parallel-size 8 +from lotus.models import LM +lm = LM(model="gpt-4o") ``` -In LOTUS, you should instantiate your model as follows +Here's an example of creating an `LM` object to use `llama3.2` on Ollama ``` -from lotus.models import OpenAIModel -lm = OpenAIModel( - model="meta-llama/Meta-Llama-3.1-70B-Instruct", - api_base="http://localhost:8000/v1", - provider="vllm", -) +from lotus.models import LM +lm = LM(model="ollama/llama3.2") +``` + +Here's an example of creating an `LM` object to use `Meta-Llama-3-8B-Instruct` on vLLM +``` +from lotus.models import LM +lm = LM(model='hosted_vllm/meta-llama/Meta-Llama-3-8B-Instruct', api_base='http://localhost:8000/v1') ``` ## Helpful Examples diff --git a/README.md b/README.md index acfd61e4..cdd56c62 100644 --- a/README.md +++ b/README.md @@ -45,10 +45,10 @@ If you're already familiar with Pandas, getting started will be a breeze! Below ```python import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM # configure the LM, and remember to export your API key -lm = OpenAIModel() +lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) # create dataframes with course names and skills diff --git a/docs/conf.py b/docs/conf.py index 3b4508c1..a9bf8b8f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,7 +14,7 @@ project = "LOTUS" copyright = "2024, Liana Patel, Siddharth Jha, Carlos Guestrin, Matei Zaharia" author = "Liana Patel, Siddharth Jha, Carlos Guestrin, Matei Zaharia" -release = "0.2.2" +release = "0.3.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/quickstart.rst b/docs/quickstart.rst index fe7e99d2..b2fcd059 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -50,11 +50,11 @@ This can be achieved by applying a semantic filter followed by a semantic aggreg import pandas as pd import lotus - from lotus.models import E5Model, OpenAIModel + from lotus.models import SentenceTransformersRM, LM # Configure models for LOTUS - lm = OpenAIModel(max_tokens=512) - rm = E5Model() + lm = LM(model="gpt-4o-mini") + rm = SentenceTransformersRM(model="intfloat/e5-base-v2") lotus.settings.configure(lm=lm, rm=rm) @@ -90,7 +90,7 @@ If we wanted the challenge of taking courses with a high workload, we can also u .. code-block:: python - top_2_hardest = df.sem_topk("What {Description} indicates the highest workload?", 2) + top_2_hardest = df.sem_topk("What {Description} indicates the highest workload?", K=2) LOTUS's semantic join operator can be used to join two dataframes based on a predicate. Suppose we had a second dataframe containing skills we wanted to get better at (SQL and Chip Design in our case). @@ -113,7 +113,7 @@ Let's create a semantic index on the course description column and then search f # Create a semantic index on the description column and save it to the index_dir directory df = df.sem_index("Description", "index_dir") - top_conv_df = df.sem_search("Description", "Convolutional Neural Network", 1) + top_conv_df = df.sem_search("Description", "Convolutional Neural Network", K=1) Another useful operator is the semantic map operator. Let's see how it can be used to get some next topics to explore for each class. Additionally, let's provide some examples to the model that can be used for demonstrations. diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 101058ec..404a5c89 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -3,8 +3,8 @@ sphinx-rtd-theme==2.0.0 backoff==2.2.1 faiss-cpu==1.8.0.post1 +litellm==1.51.0 numpy==1.26.4 -openai==1.35.13 pandas==2.2.2 sentence-transformers==3.0.1 tiktoken==0.7.0 diff --git a/examples/op_examples/agg.py b/examples/op_examples/agg.py index add1711e..206e3cc8 100644 --- a/examples/op_examples/agg.py +++ b/examples/op_examples/agg.py @@ -1,10 +1,10 @@ import pandas as pd import lotus -from lotus.models import E5Model, OpenAIModel +from lotus.models import LM, SentenceTransformersRM -lm = OpenAIModel() -rm = E5Model() +lm = LM(model="gpt-4o-mini") +rm = SentenceTransformersRM(model="intfloat/e5-base-v2") lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/cluster.py b/examples/op_examples/cluster.py index 7bcc307b..e117b249 100644 --- a/examples/op_examples/cluster.py +++ b/examples/op_examples/cluster.py @@ -1,10 +1,10 @@ import pandas as pd import lotus -from lotus.models import E5Model, OpenAIModel +from lotus.models import LM, SentenceTransformersRM -lm = OpenAIModel() -rm = E5Model() +lm = LM(model="gpt-4o-mini") +rm = SentenceTransformersRM(model="intfloat/e5-base-v2") lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/dedup.py b/examples/op_examples/dedup.py index 5d21087f..1494df95 100644 --- a/examples/op_examples/dedup.py +++ b/examples/op_examples/dedup.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import E5Model +from lotus.models import SentenceTransformersRM -rm = E5Model() +rm = SentenceTransformersRM(model="intfloat/e5-base-v2") lotus.settings.configure(rm=rm) data = { diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index b89f74f2..a1acc00d 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/filter_cascade.py b/examples/op_examples/filter_cascade.py index 5af900b2..583fd78b 100644 --- a/examples/op_examples/filter_cascade.py +++ b/examples/op_examples/filter_cascade.py @@ -1,10 +1,10 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -gpt_35_turbo = OpenAIModel("gpt-3.5-turbo") -gpt_4o = OpenAIModel("gpt-4o") +gpt_35_turbo = LM("gpt-3.5-turbo") +gpt_4o = LM("gpt-4o") lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_35_turbo) data = { diff --git a/examples/op_examples/join.py b/examples/op_examples/join.py index 2c850497..7291c575 100644 --- a/examples/op_examples/join.py +++ b/examples/op_examples/join.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/map.py b/examples/op_examples/map.py index 6323899d..4fb163f2 100644 --- a/examples/op_examples/map.py +++ b/examples/op_examples/map.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/map_fewshot.py b/examples/op_examples/map_fewshot.py index fea45dc8..365f7c9a 100644 --- a/examples/op_examples/map_fewshot.py +++ b/examples/op_examples/map_fewshot.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/partition.py b/examples/op_examples/partition.py index ca42d171..932b170b 100644 --- a/examples/op_examples/partition.py +++ b/examples/op_examples/partition.py @@ -1,10 +1,10 @@ import pandas as pd import lotus -from lotus.models import E5Model, OpenAIModel +from lotus.models import LM, SentenceTransformersRM -lm = OpenAIModel(max_tokens=2048) -rm = E5Model() +lm = LM(max_tokens=2048) +rm = SentenceTransformersRM(model="intfloat/e5-base-v2") lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/search.py b/examples/op_examples/search.py index b7ebf67d..c9382aae 100644 --- a/examples/op_examples/search.py +++ b/examples/op_examples/search.py @@ -1,11 +1,11 @@ import pandas as pd import lotus -from lotus.models import CrossEncoderModel, E5Model, OpenAIModel +from lotus.models import LM, CrossEncoderReranker, SentenceTransformersRM -lm = OpenAIModel() -rm = E5Model() -reranker = CrossEncoderModel() +lm = LM(model="gpt-4o-mini") +rm = SentenceTransformersRM(model="intfloat/e5-base-v2") +reranker = CrossEncoderReranker(model="mixedbread-ai/mxbai-rerank-large-v1") lotus.settings.configure(lm=lm, rm=rm, reranker=reranker) data = { diff --git a/examples/op_examples/sim_join.py b/examples/op_examples/sim_join.py index 7d3981ed..efc97427 100644 --- a/examples/op_examples/sim_join.py +++ b/examples/op_examples/sim_join.py @@ -1,10 +1,11 @@ import pandas as pd import lotus -from lotus.models import E5Model, OpenAIModel +from lotus.models import LM, LiteLLMRM -lm = OpenAIModel() -rm = E5Model() +lm = LM(model="gpt-4o-mini") +# rm = SentenceTransformersRM(model="intfloat/e5-base-v2") +rm = LiteLLMRM() lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/top_k.py b/examples/op_examples/top_k.py index 2930e305..8654ea18 100644 --- a/examples/op_examples/top_k.py +++ b/examples/op_examples/top_k.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) data = { diff --git a/examples/provider_examples/oai.py b/examples/provider_examples/oai.py deleted file mode 100644 index b89f74f2..00000000 --- a/examples/provider_examples/oai.py +++ /dev/null @@ -1,20 +0,0 @@ -import pandas as pd - -import lotus -from lotus.models import OpenAIModel - -lm = OpenAIModel() - -lotus.settings.configure(lm=lm) -data = { - "Course Name": [ - "Probability and Random Processes", - "Optimization Methods in Engineering", - "Digital Design and Integrated Circuits", - "Computer Security", - ] -} -df = pd.DataFrame(data) -user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction) -print(df) diff --git a/examples/provider_examples/ollama.py b/examples/provider_examples/ollama.py deleted file mode 100644 index 727add7d..00000000 --- a/examples/provider_examples/ollama.py +++ /dev/null @@ -1,25 +0,0 @@ -import pandas as pd - -import lotus -from lotus.models import OpenAIModel - -lm = OpenAIModel( - api_base="http://localhost:11434/v1", - model="llama3.2", - hf_name="meta-llama/Llama-3.2-3B-Instruct", - provider="ollama", -) - -lotus.settings.configure(lm=lm) -data = { - "Course Name": [ - "Probability and Random Processes", - "Optimization Methods in Engineering", - "Digital Design and Integrated Circuits", - "Computer Security", - ] -} -df = pd.DataFrame(data) -user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction) -print(df) diff --git a/examples/provider_examples/vllm.py b/examples/provider_examples/vllm.py deleted file mode 100644 index 76a46884..00000000 --- a/examples/provider_examples/vllm.py +++ /dev/null @@ -1,24 +0,0 @@ -import pandas as pd - -import lotus -from lotus.models import OpenAIModel - -lm = OpenAIModel( - model="meta-llama/Meta-Llama-3.1-70B-Instruct", - api_base="http://localhost:8000/v1", - provider="vllm", -) - -lotus.settings.configure(lm=lm) -data = { - "Course Name": [ - "Probability and Random Processes", - "Optimization Methods in Engineering", - "Digital Design and Integrated Circuits", - "Computer Security", - ] -} -df = pd.DataFrame(data) -user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction) -print(df) diff --git a/lotus/__init__.py b/lotus/__init__.py index 190d9d52..58f4575e 100644 --- a/lotus/__init__.py +++ b/lotus/__init__.py @@ -19,7 +19,8 @@ sem_dedup, sem_topk, ) -from lotus.settings import settings +from lotus.settings import settings # type: ignore[attr-defined] + logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) logger = logging.getLogger(__name__) diff --git a/lotus/models/__init__.py b/lotus/models/__init__.py index 194d7259..f88f1dd4 100644 --- a/lotus/models/__init__.py +++ b/lotus/models/__init__.py @@ -1,17 +1,17 @@ -from lotus.models.colbertv2_model import ColBERTv2Model -from lotus.models.cross_encoder_model import CrossEncoderModel -from lotus.models.e5_model import E5Model +from lotus.models.cross_encoder_reranker import CrossEncoderReranker from lotus.models.lm import LM -from lotus.models.openai_model import OpenAIModel from lotus.models.reranker import Reranker from lotus.models.rm import RM +from lotus.models.litellm_rm import LiteLLMRM +from lotus.models.sentence_transformers_rm import SentenceTransformersRM +from lotus.models.colbertv2_rm import ColBERTv2RM __all__ = [ - "OpenAIModel", - "E5Model", - "ColBERTv2Model", - "CrossEncoderModel", + "CrossEncoderReranker", "LM", "RM", "Reranker", + "LiteLLMRM", + "SentenceTransformersRM", + "ColBERTv2RM", ] diff --git a/lotus/models/colbertv2_model.py b/lotus/models/colbertv2_rm.py similarity index 55% rename from lotus/models/colbertv2_model.py rename to lotus/models/colbertv2_rm.py index 2c407a2c..018af594 100644 --- a/lotus/models/colbertv2_model.py +++ b/lotus/models/colbertv2_rm.py @@ -1,33 +1,32 @@ import pickle from typing import Any +import numpy as np +from numpy.typing import NDArray + from lotus.models.rm import RM +from lotus.types import RMOutput +try: + from colbert import Indexer, Searcher + from colbert.infra import ColBERTConfig, Run, RunConfig +except ImportError: + pass -class ColBERTv2Model(RM): - """ColBERTv2 Model""" - def __init__(self, **kwargs): +class ColBERTv2RM(RM): + def __init__(self) -> None: self.docs: list[str] | None = None - self.kwargs: dict[str, Any] = {"doc_maxlen": 300, "nbits": 2, **kwargs} + self.kwargs: dict[str, Any] = {"doc_maxlen": 300, "nbits": 2} self.index_dir: str | None = None - from colbert import Indexer, Searcher - from colbert.infra import ColBERTConfig, Run, RunConfig - - self.Indexer = Indexer - self.Searcher = Searcher - self.ColBERTConfig = ColBERTConfig - self.Run = Run - self.RunConfig = RunConfig - def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: kwargs = {**self.kwargs, **kwargs} checkpoint = "colbert-ir/colbertv2.0" - with self.Run().context(self.RunConfig(nranks=1, experiment="lotus")): - config = self.ColBERTConfig(doc_maxlen=kwargs["doc_maxlen"], nbits=kwargs["nbits"], kmeans_niters=4) - indexer = self.Indexer(checkpoint=checkpoint, config=config) + with Run().context(RunConfig(nranks=1, experiment="lotus")): + config = ColBERTConfig(doc_maxlen=kwargs["doc_maxlen"], nbits=kwargs["nbits"], kmeans_niters=4) + indexer = Indexer(checkpoint=checkpoint, config=config) indexer.index(name=f"{index_dir}/index", collection=docs, overwrite=True) with open(f"experiments/lotus/indexes/{index_dir}/index/docs", "wb") as fp: @@ -41,26 +40,26 @@ def load_index(self, index_dir: str) -> None: with open(f"experiments/lotus/indexes/{index_dir}/index/docs", "rb") as fp: self.docs = pickle.load(fp) - def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list: - raise NotImplementedError("This method is not implemented for ColBERTv2Model") + def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: + raise NotImplementedError("This method is not implemented for ColBERTv2RM") def __call__( self, - queries: str | list[str] | list[list[float]], - k: int, + queries: str | list[str] | NDArray[np.float64], + K: int, **kwargs: dict[str, Any], - ) -> tuple[list[float], list[int]]: + ) -> RMOutput: if isinstance(queries, str): queries = [queries] - with self.Run().context(self.RunConfig(experiment="lotus")): - searcher = self.Searcher(index=f"{self.index_dir}/index", collection=self.docs) + with Run().context(RunConfig(experiment="lotus")): + searcher = Searcher(index=f"{self.index_dir}/index", collection=self.docs) # make queries a dict with keys as query ids - queries = {i: q for i, q in enumerate(queries)} - all_results = searcher.search_all(queries, k=k).todict() + queries_dict = {i: q for i, q in enumerate(queries)} + all_results = searcher.search_all(queries_dict, k=K).todict() indices = [[result[0] for result in all_results[qid]] for qid in all_results.keys()] distances = [[result[2] for result in all_results[qid]] for qid in all_results.keys()] - return distances, indices + return RMOutput(distances=distances, indices=indices) diff --git a/lotus/models/cross_encoder_model.py b/lotus/models/cross_encoder_model.py deleted file mode 100644 index 1f4c9512..00000000 --- a/lotus/models/cross_encoder_model.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -from sentence_transformers import CrossEncoder - -from lotus.models.reranker import Reranker - - -class CrossEncoderModel(Reranker): - """CrossEncoder reranker model. - - Args: - model (str): The name of the reranker model to use. - device (str): What device to keep the model on. - """ - - def __init__( - self, - model: str = "mixedbread-ai/mxbai-rerank-large-v1", - device: str | None = None, - **kwargs, - ): - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device = device - self.model = CrossEncoder(model, device=device, **kwargs) - - def __call__(self, query: str, docs: list[str], k: int) -> list[int]: - results = self.model.rank(query, docs, top_k=k) - results = [result["corpus_id"] for result in results] - return results diff --git a/lotus/models/cross_encoder_reranker.py b/lotus/models/cross_encoder_reranker.py new file mode 100644 index 00000000..65827ce2 --- /dev/null +++ b/lotus/models/cross_encoder_reranker.py @@ -0,0 +1,28 @@ +from sentence_transformers import CrossEncoder + +from lotus.models.reranker import Reranker +from lotus.types import RerankerOutput + + +class CrossEncoderReranker(Reranker): + """CrossEncoder reranker model. + + Args: + model (str): The name of the reranker model to use. + device (str): What device to keep the model on. + max_batch_size (int): The maximum batch size to use for the model. + """ + + def __init__( + self, + model: str = "mixedbread-ai/mxbai-rerank-large-v1", + device: str | None = None, + max_batch_size: int = 64, + ): + self.max_batch_size: int = max_batch_size + self.model = CrossEncoder(model, device=device) # type: ignore # CrossEncoder has wrong type stubs + + def __call__(self, query: str, docs: list[str], K: int) -> RerankerOutput: + results = self.model.rank(query, docs, top_k=K, batch_size=self.max_batch_size) + indices = [int(result["corpus_id"]) for result in results] + return RerankerOutput(indices=indices) diff --git a/lotus/models/e5_model.py b/lotus/models/e5_model.py deleted file mode 100644 index 310a2428..00000000 --- a/lotus/models/e5_model.py +++ /dev/null @@ -1,152 +0,0 @@ -import os -import pickle -from typing import Any - -import numpy as np -import torch -import torch.nn.functional as F -from tqdm import tqdm -from transformers import AutoModel, AutoTokenizer - -from lotus.models.rm import RM - - -class E5Model(RM): - """E5 retriever model""" - - def __init__(self, model: str = "intfloat/e5-base-v2", device: str | None = None, **kwargs: dict[str, Any]) -> None: - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device = device - self.tokenizer = AutoTokenizer.from_pretrained(model) - self.model = AutoModel.from_pretrained(model).to(self.device) - self.faiss_index = None - self.index_dir: str | None = None - self.docs: list[str] | None = None - self.kwargs: dict[str, Any] = {"normalize": True, "index_type": "Flat", **kwargs} - self.batch_size: int = 100 - self.vecs: np.ndarray[Any, np.dtype[np.float32]] | None = None - - import faiss - - self.faiss = faiss - - def average_pool(self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - """Perform average pooling over the last hidden state. - - Args: - last_hidden_states: Hidden states from the model's last layer - attention_mask: Attention mask. - - Returns: - Average pool over the last hidden state. - """ - - last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) - return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] - - def embed(self, docs: list[str], **kwargs: dict[str, Any]) -> np.ndarray[Any, np.dtype[np.float32]]: - """Run the embedding model. - - Args: - docs: A list of documents to embed. - - Returns: - Embeddings of the documents. - """ - - kwargs = {**self.kwargs, **dict(kwargs)} - - batch_size = kwargs.get("batch_size", self.batch_size) - assert isinstance(batch_size, int), "batch_size must be an integer" - - # Calculating the embedding dimension - total_docs = len(docs) - first_batch = self.tokenizer(docs[:1], return_tensors="pt", padding=True, truncation=True).to(self.device) - embed_dim = self.model(**first_batch).last_hidden_state.size(-1) - - # Pre-allocate a tensor for all embeddings - embeddings = torch.empty((total_docs, embed_dim), device=self.device) - # Processing batches - with torch.inference_mode(): # Slightly faster than torch.no_grad() for inference - for i, batch_start in enumerate(tqdm(range(0, total_docs, batch_size))): - batch = docs[batch_start : batch_start + batch_size] - batch_dict = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(self.device) - outputs = self.model(**batch_dict) - batch_embeddings = self.average_pool(outputs.last_hidden_state, batch_dict["attention_mask"]) - embeddings[batch_start : batch_start + batch_size] = batch_embeddings - if kwargs["normalize"]: - embeddings = F.normalize(embeddings, p=2, dim=1) - - return embeddings.numpy(force=True) - - def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: - # Make index directory - os.makedirs(index_dir, exist_ok=True) - - # Get document embeddings - kwargs = {**self.kwargs, **kwargs} - embeddings = self.embed(docs, **kwargs) - d = embeddings.shape[1] - index = self.faiss.index_factory(d, kwargs["index_type"], self.faiss.METRIC_INNER_PRODUCT) - index.add(embeddings) - - # Store index and documents - self.faiss.write_index(index, f"{index_dir}/index") - with open(f"{index_dir}/docs", "wb") as fp: - pickle.dump(docs, fp) - with open(f"{index_dir}/vecs", "wb") as fp: - pickle.dump(embeddings, fp) - self.faiss_index = index - self.docs = docs - self.index_dir = index_dir - self.vecs = embeddings - - def load_index(self, index_dir: str) -> None: - self.index_dir = index_dir - self.faiss_index = self.faiss.read_index(f"{index_dir}/index") - with open(f"{index_dir}/docs", "rb") as fp: - self.docs = pickle.load(fp) - with open(f"{index_dir}/vecs", "rb") as fp: - self.vecs = pickle.load(fp) - - @classmethod - def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list[np.ndarray[Any, np.dtype[np.float32]]]: - with open(f"{index_dir}/vecs", "rb") as fp: - vecs: np.ndarray[Any, np.dtype[np.float32]] = pickle.load(fp) - - return vecs[ids] - - def load_vecs(self, index_dir: str, ids: list[int]) -> list: - """loads vectors to the rm and returns them - Args: - index_dir (str): Directory of the index. - ids (list[int]): The ids of the vectors to retrieve - - Returns: - The vectors matching the specified ids. - """ - - if self.vecs is None: - with open(f"{index_dir}/vecs", "rb") as fp: - self.vecs = pickle.load(fp) - - return self.vecs[ids] - - def __call__( - self, - queries: str | list[str] | list[list[float]], - k: int, - **kwargs: dict[str, Any], - ) -> tuple[list[float], list[int]]: - if isinstance(queries, str): - queries = [queries] - - if isinstance(queries[0], str): - embedded_queries = self.embed(queries, **kwargs) - else: - embedded_queries = queries - - distances, indicies = self.faiss_index.search(embedded_queries, k) - - return distances, indicies diff --git a/lotus/models/faiss_rm.py b/lotus/models/faiss_rm.py new file mode 100644 index 00000000..205129df --- /dev/null +++ b/lotus/models/faiss_rm.py @@ -0,0 +1,62 @@ +import os +import pickle +from abc import abstractmethod +from typing import Any + +import faiss +import numpy as np +from numpy.typing import NDArray + +from lotus.models.rm import RM +from lotus.types import RMOutput + + +class FaissRM(RM): + def __init__(self, factory_string: str = "Flat", metric=faiss.METRIC_INNER_PRODUCT): + super().__init__() + self.factory_string = factory_string + self.metric = metric + self.index_dir: str | None = None + self.faiss_index: faiss.Index | None = None + self.vecs: NDArray[np.float64] | None = None + + def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: + vecs = self._embed(docs) + self.faiss_index = faiss.index_factory(vecs.shape[1], self.factory_string, self.metric) + self.faiss_index.add(vecs) + self.index_dir = index_dir + + os.makedirs(index_dir, exist_ok=True) + with open(f"{index_dir}/vecs", "wb") as fp: + pickle.dump(vecs, fp) + faiss.write_index(self.faiss_index, f"{index_dir}/index") + + def load_index(self, index_dir: str) -> None: + self.index_dir = index_dir + self.faiss_index = faiss.read_index(f"{index_dir}/index") + with open(f"{index_dir}/vecs", "rb") as fp: + self.vecs = pickle.load(fp) + + def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: + with open(f"{index_dir}/vecs", "rb") as fp: + vecs: NDArray[np.float64] = pickle.load(fp) + return vecs[ids] + + def __call__(self, queries: str | list[str] | NDArray[np.float64], K: int, **kwargs: dict[str, Any]) -> RMOutput: + if isinstance(queries, str): + queries = [queries] + + if isinstance(queries[0], str): + embedded_queries = self._embed([str(q) for q in queries]) + else: + embedded_queries = np.asarray(queries, dtype=np.float32) + + if self.faiss_index is None: + raise ValueError("Index not loaded") + + distances, indices = self.faiss_index.search(embedded_queries, K) + return RMOutput(distances=distances, indices=indices) + + @abstractmethod + def _embed(self, docs: list[str]) -> NDArray[np.float64]: + pass diff --git a/lotus/models/litellm_rm.py b/lotus/models/litellm_rm.py new file mode 100644 index 00000000..cadb4cf5 --- /dev/null +++ b/lotus/models/litellm_rm.py @@ -0,0 +1,29 @@ +import faiss +import numpy as np +from litellm import embedding +from litellm.types.utils import EmbeddingResponse +from numpy.typing import NDArray + +from lotus.models.faiss_rm import FaissRM + + +class LiteLLMRM(FaissRM): + def __init__( + self, + model: str = "text-embedding-3-small", + max_batch_size: int = 64, + factory_string: str = "Flat", + metric=faiss.METRIC_INNER_PRODUCT, + ): + super().__init__(factory_string, metric) + self.model: str = model + self.max_batch_size: int = max_batch_size + + def _embed(self, docs: list[str]) -> NDArray[np.float64]: + all_embeddings = [] + for i in range(0, len(docs), self.max_batch_size): + batch = docs[i : i + self.max_batch_size] + response: EmbeddingResponse = embedding(model=self.model, input=batch) + embeddings = np.array([d["embedding"] for d in response.data]) + all_embeddings.append(embeddings) + return np.vstack(all_embeddings) diff --git a/lotus/models/lm.py b/lotus/models/lm.py index d39aea22..ee03a444 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -1,62 +1,162 @@ -from abc import ABC, abstractmethod from typing import Any +import litellm +import numpy as np +from litellm import batch_completion, completion_cost +from litellm.types.utils import ChatCompletionTokenLogprob, Choices, ModelResponse +from litellm.utils import token_counter +from openai import OpenAIError +from tokenizers import Tokenizer -class LM(ABC): - """Abstract class for language models.""" - - def _init__(self): - pass - - @abstractmethod - def count_tokens(self, prompt: str | list) -> int: - """ - Counts the number of tokens in the given prompt. - - Args: - prompt (str | list): The prompt to count tokens for. This can be a string or a list of messages. - - Returns: - int: The number of tokens in the prompt. - """ - pass - - def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]: - """ - Formats the logprobs for the cascade. - - Args: - logprobs (list): The logprobs to format. - - Returns: - tuple[list[list[str]], list[list[float]]]: A tuple containing the tokens and their corresponding confidences. - """ - pass - - @abstractmethod - def __call__( - self, messages_batch: list | list[list], **kwargs: dict[str, Any] - ) -> list[str] | tuple[list[str], list[dict[str, Any]]]: - """Invoke the LLM. - - Args: - messages_batch (list | list[list]): Either one prompt or a list of prompts in message format. - kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters. - - Returns: - list[str] | tuple[list[str], list[dict[str, Any]]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments, - then a list of logprobs is also returned. - """ - pass - - @property - @abstractmethod - def max_ctx_len(self) -> int: - """The maximum context length of the LLM.""" - pass - - @property - @abstractmethod - def max_tokens(self) -> int: - """The maximum number of tokens that can be generated by the LLM.""" - pass +import lotus +from lotus.types import LMOutput, LMStats, LogprobsForCascade, LogprobsForFilterCascade + + +class LM: + def __init__( + self, + model: str = "gpt-4o-mini", + temperature: float = 0.0, + max_ctx_len: int = 128000, + max_tokens: int = 512, + max_batch_size: int = 64, + tokenizer: Tokenizer | None = None, + **kwargs: dict[str, Any], + ): + self.model = model + self.max_ctx_len = max_ctx_len + self.max_tokens = max_tokens + self.max_batch_size = max_batch_size + self.tokenizer = tokenizer + self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) + + self.stats: LMStats = LMStats() + + def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any]) -> LMOutput: + all_kwargs = {**self.kwargs, **kwargs} + + # Set top_logprobs if logprobs requested + if all_kwargs.get("logprobs", False): + all_kwargs["top_logprobs"] = all_kwargs.get("top_logprobs", 10) + + all_responses: list[ModelResponse] = [] + for i in range(0, len(messages), self.max_batch_size): + batch = messages[i : i + self.max_batch_size] + responses: list[ModelResponse] = batch_completion( + self.model, + batch, + drop_params=True, + **all_kwargs, # type: ignore + ) + all_responses.extend(responses) + + # throw errors, if any + for resp in all_responses: + if isinstance(resp, OpenAIError): + raise resp + + outputs = [self._get_top_choice(resp) for resp in all_responses] + logprobs = ( + [self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None + ) + + for resp in all_responses: + self._update_stats(resp) + + return LMOutput(outputs=outputs, logprobs=logprobs) + + def _update_stats(self, response: ModelResponse): + if not hasattr(response, "usage"): + return + + self.stats.total_usage.prompt_tokens += response.usage.prompt_tokens + self.stats.total_usage.completion_tokens += response.usage.completion_tokens + self.stats.total_usage.total_tokens += response.usage.total_tokens + + try: + self.stats.total_usage.total_cost += completion_cost(completion_response=response) + except litellm.exceptions.NotFoundError as e: + # Sometimes the model's pricing information is not available + lotus.logger.debug(f"Error updating completion cost: {e}") + + def _get_top_choice(self, response: ModelResponse) -> str: + choice = response.choices[0] + assert isinstance(choice, Choices) + if choice.message.content is None: + raise ValueError(f"No content in response: {response}") + return choice.message.content + + def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompletionTokenLogprob]: + choice = response.choices[0] + assert isinstance(choice, Choices) + logprobs = choice.logprobs["content"] + return [ChatCompletionTokenLogprob(**logprob) for logprob in logprobs] + + def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade: + all_tokens = [] + all_confidences = [] + for resp_logprobs in logprobs: + tokens = [logprob.token for logprob in resp_logprobs] + confidences = [np.exp(logprob.logprob) for logprob in resp_logprobs] + all_tokens.append(tokens) + all_confidences.append(confidences) + return LogprobsForCascade(tokens=all_tokens, confidences=all_confidences) + + def format_logprobs_for_filter_cascade( + self, logprobs: list[list[ChatCompletionTokenLogprob]] + ) -> LogprobsForFilterCascade: + # Get base cascade format first + base_cascade = self.format_logprobs_for_cascade(logprobs) + all_true_probs = [] + + def get_normalized_true_prob(token_probs: dict[str, float]) -> float | None: + if "True" in token_probs and "False" in token_probs: + true_prob = token_probs["True"] + false_prob = token_probs["False"] + return true_prob / (true_prob + false_prob) + return None + + # Get true probabilities for filter cascade + for resp_idx, response_logprobs in enumerate(logprobs): + true_prob = None + for logprob in response_logprobs: + token_probs = {top.token: np.exp(top.logprob) for top in logprob.top_logprobs} + true_prob = get_normalized_true_prob(token_probs) + if true_prob is not None: + break + + # Default to 1 if "True" in tokens, 0 if not + if true_prob is None: + true_prob = 1 if "True" in base_cascade.tokens[resp_idx] else 0 + + all_true_probs.append(true_prob) + + return LogprobsForFilterCascade( + tokens=base_cascade.tokens, confidences=base_cascade.confidences, true_probs=all_true_probs + ) + + def count_tokens(self, messages: list[dict[str, str]] | str) -> int: + """Count tokens in messages using either custom tokenizer or model's default tokenizer""" + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + custom_tokenizer: dict[str, Any] | None = None + if self.tokenizer: + custom_tokenizer = dict(type="huggingface_tokenizer", tokenizer=self.tokenizer) + + return token_counter( + custom_tokenizer=custom_tokenizer, + model=self.model, + messages=messages, + ) + + def print_total_usage(self): + print(f"Total cost: ${self.stats.total_usage.total_cost:.6f}") + print(f"Total prompt tokens: {self.stats.total_usage.prompt_tokens}") + print(f"Total completion tokens: {self.stats.total_usage.completion_tokens}") + print(f"Total tokens: {self.stats.total_usage.total_tokens}") + + def reset_stats(self): + self.stats = LMStats( + total_usage=LMStats.TotalUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0, total_cost=0.0) + ) diff --git a/lotus/models/openai_model.py b/lotus/models/openai_model.py deleted file mode 100644 index 57fb20eb..00000000 --- a/lotus/models/openai_model.py +++ /dev/null @@ -1,325 +0,0 @@ -import os -import threading -from typing import Any - -import backoff -import numpy as np -import openai -import tiktoken -from openai import OpenAI -from transformers import AutoTokenizer - -import lotus -from lotus.models.lm import LM - -ERRORS = (openai.RateLimitError, openai.APIError) - - -def backoff_hdlr(details): - """Handler from https://pypi.org/project/backoff/""" - print( - "Backing off {wait:0.1f} seconds after {tries} tries " - "calling function {target} with kwargs " - "{kwargs}".format(**details), - ) - - -class OpenAIModel(LM): - """Wrapper around OpenAI, Databricks, and vLLM OpenAI server - - Args: - model (str): The name of the model to use. - api_key (str | None): An API key (e.g. from OpenAI or Databricks). - api_base (str | None): The endpoint of the server. - provider (str): Either openai, dbrx, or vllm. - max_batch_size (int): The maximum batch size for the model. - max_ctx_len (int): The maximum context length for the model. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters. - """ - - def __init__( - self, - model: str = "gpt-4o-mini", - hf_name: str | None = None, - api_key: str | None = None, - api_base: str | None = None, - provider: str = "openai", - max_batch_size: int = 64, - max_ctx_len: int = 4096, - **kwargs: dict[str, Any], - ): - super().__init__() - self.provider = provider - self.use_chat = provider in ["openai", "dbrx", "ollama"] - self.max_batch_size = max_batch_size - self.hf_name = hf_name if hf_name is not None else model - self.__dict__["max_ctx_len"] = max_ctx_len - - self.kwargs = { - "model": model, - "temperature": 0.0, - "max_tokens": 512, - "top_p": 1, - "n": 1, - **kwargs, - } - - api_key = api_key or os.environ.get("OPENAI_API_KEY", "None") - self.client = OpenAI(api_key=api_key if api_key else "None", base_url=api_base) - - # TODO: Refactor this - if self.provider == "openai": - self.tokenizer = tiktoken.encoding_for_model(model) - else: - self.tokenizer = AutoTokenizer.from_pretrained(self.hf_name) - - def handle_chat_request( - self, messages: list, **kwargs: dict[str, Any] - ) -> list | tuple[list[list[str]], list[list[float]]]: - """Handle single chat request to OpenAI server. - - Args: - messages_batch (list): A prompt in message format. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - list | tuple[list[list[str]], list[list[float]]]: A list of outputs for each prompt in the batch (just one in this case). If logprobs is specified in the keyword arguments, - then a list of logprobs is also returned. - """ - if kwargs.get("logprobs", False): - kwargs["top_logprobs"] = 10 - - kwargs = {**self.kwargs, **kwargs} - kwargs["messages"] = messages - response = self.chat_request(**kwargs) - - choices = response["choices"] - completions = [c["message"]["content"] for c in choices] - - if kwargs.get("logprobs", False): - logprobs = [c["logprobs"] for c in choices] - return completions, logprobs - - return completions - - def handle_completion_request( - self, messages: list, **kwargs: dict[str, Any] - ) -> list | tuple[list[list[str]], list[list[float]]]: - """Handle a potentially batched completions request to OpenAI server. - - Args: - messages_batch (list): A list of prompts in message format. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - list | tuple[list[list[str]], list[list[float]]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments, - then a list of logprobs is also returned. - """ - if not isinstance(messages[0], list): - prompt = [self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)] - else: - prompt = [ - self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) - for message in messages - ] - - kwargs = {**self.kwargs, **kwargs} - kwargs["prompt"] = prompt - if kwargs.get("logprobs", False): - kwargs["logprobs"] = 10 - response = self.completion_request(**kwargs) - - choices = response["choices"] - completions = [c["text"] for c in choices] - - if kwargs.get("logprobs", False): - logprobs = [c["logprobs"] for c in choices] - return completions, logprobs - - return completions - - @backoff.on_exception( - backoff.expo, - ERRORS, - max_time=1000, - on_backoff=backoff_hdlr, - ) - def request(self, messages: list, **kwargs: dict[str, Any]) -> list | tuple[list[list[str]], list[list[float]]]: - """Handle single request to OpenAI server. Decides whether chat or completion endpoint is necessary. - - Args: - messages_batch (list): A prompt in message format. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - A list of text outputs for each prompt in the batch (just one in this case). - If logprobs is specified in the keyword arguments, hen a list of logprobs is also returned (also of size one). - """ - if self.use_chat: - return self.handle_chat_request(messages, **kwargs) - else: - return self.handle_completion_request(messages, **kwargs) - - def batched_chat_request( - self, messages_batch: list, **kwargs: dict[str, Any] - ) -> list | tuple[list[list[str]], list[list[float]]]: - """Handle batched chat request to OpenAI server. - - Args: - messages_batch (list): Either one prompt or a list of prompts in message format. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters. - - Returns: - list | tuple[list[list[str]], list[list[float]]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments, - then a list of logprobs is also returned. - """ - - batch_size = len(messages_batch) - text_ret = [None] * batch_size - logprobs_ret = [None] * batch_size - threads = [] - - def thread_function(idx, messages, kwargs): - text = self(messages, **kwargs) - if kwargs.get("logprobs", False): - text, logprobs = text - logprobs_ret[idx] = logprobs[0] - text_ret[idx] = text[0] - - for idx, messages in enumerate(messages_batch): - thread = threading.Thread(target=thread_function, args=(idx, messages, kwargs)) - threads.append(thread) - thread.start() - - for thread in threads: - thread.join() - - if kwargs.get("logprobs", False): - return text_ret, logprobs_ret - - return text_ret - - def __call__( - self, messages_batch: list | list[list], **kwargs: dict[str, Any] - ) -> list[str] | tuple[list[str], list[dict[str, Any]]]: - lotus.logger.debug(f"OpenAIModel.__call__ messages_batch: {messages_batch}") - lotus.logger.debug(f"OpenAIModel.__call__ kwargs: {kwargs}") - # Bakes max batch size into model call. # TODO: Figure out less hacky way to do this. - if isinstance(messages_batch[0], list) and len(messages_batch) > self.max_batch_size: - text_ret = [] - logprobs_ret = [] - for i in range(0, len(messages_batch), self.max_batch_size): - res = self(messages_batch[i : i + self.max_batch_size], **kwargs) - if kwargs.get("logprobs", False): - text, logprobs = res - logprobs_ret.extend(logprobs) - else: - text = res - text_ret.extend(text) - - if kwargs.get("logprobs", False): - return text_ret, logprobs_ret - return text_ret - - if self.use_chat and isinstance(messages_batch[0], list): - return self.batched_chat_request(messages_batch, **kwargs) - - return self.request(messages_batch, **kwargs) - - def count_tokens(self, prompt: str | list) -> int: - if isinstance(prompt, str): - if self.provider != "openai": - return len(self.tokenizer(prompt)["input_ids"]) - - return len(self.tokenizer.encode(prompt)) - else: - if self.provider != "openai": - return len(self.tokenizer.apply_chat_template(prompt, tokenize=True, add_generation_prompt=True)) - - return sum(len(self.tokenizer.encode(message["content"])) for message in prompt) - - def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]: - all_tokens = [] - all_confidences = [] - for idx in range(len(logprobs)): - if self.provider == "vllm": - tokens = logprobs[idx]["tokens"] - confidences = np.exp(logprobs[idx]["token_logprobs"]) - elif self.provider == "openai": - content = logprobs[idx]["content"] - tokens = [content[t_idx]["token"] for t_idx in range(len(content))] - confidences = np.exp([content[t_idx]["logprob"] for t_idx in range(len(content))]) - all_tokens.append(tokens) - all_confidences.append(confidences) - - return all_tokens, all_confidences - - def format_logprobs_for_filter_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]: - all_tokens = [] - all_confidences = [] - all_true_probs = [] - for idx in range(len(logprobs)): - if self.provider == "vllm": - tokens = logprobs[idx]["tokens"] - confidences = np.exp(logprobs[idx]["token_logprobs"]) - top_logprobs = logprobs[idx]["top_logprobs"][0] - if 'True' in top_logprobs and 'False' in top_logprobs: - true_prob = np.exp(top_logprobs['True']) - false_prob = np.exp(top_logprobs['False']) - all_true_probs.append(true_prob / (true_prob + false_prob)) - else: - all_true_probs.append(1 if 'True' in top_logprobs else 0) - - elif self.provider == "openai": - content = logprobs[idx]["content"] - tokens = [content[t_idx]["token"] for t_idx in range(len(content))] - confidences = np.exp([content[t_idx]["logprob"] for t_idx in range(len(content))]) - top_logprobs = {x["token"]:x["logprob"] for x in content[0]["top_logprobs"]} - - true_prob, false_prob = 0, 0 - if top_logprobs and 'True' in top_logprobs and 'False' in top_logprobs: - true_prob = np.exp(top_logprobs['True']) - false_prob = np.exp(top_logprobs['False']) - all_true_probs.append(true_prob / (true_prob + false_prob)) - else: - all_true_probs.append(1 if 'True' in top_logprobs else 0) - - all_tokens.append(tokens) - all_confidences.append(confidences) - - return all_tokens, all_confidences, all_true_probs - - def chat_request(self, **kwargs: dict[str, Any]) -> dict[str, Any]: - """Send chat request to OpenAI server. - - Args: - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - dict: OpenAI chat completion response. - """ - return self.client.chat.completions.create(**kwargs).model_dump() - - def completion_request(self, **kwargs: dict[str, Any]) -> dict[str, Any]: - """Send completion request to OpenAI server. - - Args: - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - dict: OpenAI completion response. - """ - return self.client.completions.create(**kwargs).model_dump() - - @property - def max_tokens(self) -> int: - return self.kwargs["max_tokens"] - - @property - def max_ctx_len(self) -> int: - return self.__dict__["max_ctx_len"] diff --git a/lotus/models/reranker.py b/lotus/models/reranker.py index 736656f4..a7fd5996 100644 --- a/lotus/models/reranker.py +++ b/lotus/models/reranker.py @@ -1,22 +1,24 @@ from abc import ABC, abstractmethod +from lotus.types import RerankerOutput + class Reranker(ABC): """Abstract class for reranker models.""" - def _init__(self): + def __init__(self) -> None: pass @abstractmethod - def __call__(self, query: str, docs: list[str], k: int) -> list[int]: + def __call__(self, query: str, docs: list[str], K: int) -> RerankerOutput: """Invoke the reranker. Args: query (str): The query to use for reranking. docs (list[str]): A list of documents to rerank. - k (int): The number of documents to keep after reranking. + K (int): The number of documents to keep after reranking. Returns: - list[int]: The indicies of the reranked documents. + RerankerOutput: The indicies of the reranked documents. """ pass diff --git a/lotus/models/rm.py b/lotus/models/rm.py index ed7b70e2..330d7cd5 100644 --- a/lotus/models/rm.py +++ b/lotus/models/rm.py @@ -1,12 +1,17 @@ from abc import ABC, abstractmethod from typing import Any +import numpy as np +from numpy.typing import NDArray + +from lotus.types import RMOutput + class RM(ABC): """Abstract class for retriever models.""" - def _init__(self): - pass + def __init__(self) -> None: + self.index_dir: str | None = None @abstractmethod def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: @@ -28,7 +33,7 @@ def load_index(self, index_dir: str) -> None: pass @abstractmethod - def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list: + def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: """Get the vectors from the index. Args: @@ -36,7 +41,7 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list: ids (list[int]): The ids of the vectors to retrieve Returns: - list: The vectors matching the specified ids. + NDArray[np.float64]: The vectors matching the specified ids. """ pass @@ -44,18 +49,18 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list: @abstractmethod def __call__( self, - queries: str | list[str] | list[list[float]], - k: int, + queries: str | list[str] | NDArray[np.float64], + K: int, **kwargs: dict[str, Any], - ) -> tuple[list[float], list[int]]: + ) -> RMOutput: """Run top-k search on the index. Args: - queries (str | list[str] | list[list[float]]): Either a query or a list of queries or a 2D FP32 array. - k (int): The k to use for top-k search. + queries (str | list[str] | NDArray[np.float64]): Either a query or a list of queries or a 2D FP32 array. + K (int): The k to use for top-k search. **kwargs (dict[str, Any]): Additional keyword arguments. Returns: - tuple[list[float], list[int]]: A tuple of (distances, indices) of the top-k vectors + RMOutput: An RMOutput object containing the distances and indices of the top-k vectors. """ pass diff --git a/lotus/models/sentence_transformers_rm.py b/lotus/models/sentence_transformers_rm.py new file mode 100644 index 00000000..bbcd36f9 --- /dev/null +++ b/lotus/models/sentence_transformers_rm.py @@ -0,0 +1,36 @@ +import faiss +import numpy as np +import torch +from numpy.typing import NDArray +from sentence_transformers import SentenceTransformer + +from lotus.models.faiss_rm import FaissRM + + +class SentenceTransformersRM(FaissRM): + def __init__( + self, + model: str = "intfloat/e5-base-v2", + max_batch_size: int = 64, + normalize_embeddings: bool = True, + device: str | None = None, + factory_string: str = "Flat", + metric=faiss.METRIC_INNER_PRODUCT, + ): + super().__init__(factory_string, metric) + self.model: str = model + self.max_batch_size: int = max_batch_size + self.normalize_embeddings: bool = normalize_embeddings + self.transformer: SentenceTransformer = SentenceTransformer(model, device=device) + + def _embed(self, docs: list[str]) -> NDArray[np.float64]: + all_embeddings = [] + for i in range(0, len(docs), self.max_batch_size): + batch = docs[i : i + self.max_batch_size] + torch_embeddings = self.transformer.encode( + batch, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings + ) + assert isinstance(torch_embeddings, torch.Tensor) + cpu_embeddings = torch_embeddings.cpu().numpy() + all_embeddings.append(cpu_embeddings) + return np.vstack(all_embeddings) diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index 0e8eabdf..088ba9be 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -1,4 +1,5 @@ import numpy as np +from numpy.typing import NDArray import lotus @@ -6,53 +7,57 @@ def importance_sampling( proxy_scores: list[float], sample_percentage: float, -) -> tuple[list[int], list[float]]: +) -> tuple[NDArray[np.int64], NDArray[np.float64]]: """Uses importance sampling and returns the list of indices from which to learn cascade thresholds.""" w = np.sqrt(proxy_scores) is_weight = lotus.settings.cascade_is_weight w = is_weight * w / np.sum(w) + (1 - is_weight) * np.ones((len(proxy_scores))) / len(proxy_scores) indices = np.arange(len(proxy_scores)) - sample_size = (int) (sample_percentage * len(proxy_scores)) + sample_size = (int)(sample_percentage * len(proxy_scores)) sample_indices = np.random.choice(indices, sample_size, p=w) - correction_factors = (1/len(proxy_scores)) / w + correction_factors = (1 / len(proxy_scores)) / w return sample_indices, correction_factors + def calibrate_llm_logprobs(true_probs: list[float]) -> list[float]: """Transforms true probabilities to calibrate LLM proxies.""" num_quantiles = lotus.settings.cascade_num_calibration_quantiles quantile_values = np.percentile(true_probs, np.linspace(0, 100, num_quantiles + 1)) - true_probs = ((np.digitize(true_probs, quantile_values) - 1) / num_quantiles) - true_probs = np.clip(true_probs, 0, 1) + true_probs = (np.digitize(true_probs, quantile_values) - 1) / num_quantiles + true_probs = list(np.clip(true_probs, 0, 1)) return true_probs + def learn_cascade_thresholds( proxy_scores: list[float], - oracle_outputs: list[float], - sample_correction_factors: list[float], + oracle_outputs: list[bool], + sample_correction_factors: NDArray[np.float64], recall_target: float, precision_target: float, - delta: float + delta: float, ) -> tuple[tuple[float, float], int]: - """Learns cascade thresholds given targets and proxy scores, + """Learns cascade thresholds given targets and proxy scores, oracle outputs over the sample, and correction factors for the sample.""" - def UB(mean, std_dev, s, delta): - return mean + (std_dev / (s ** 0.5)) * ((2 * np.log(1 / delta)) ** 0.5) + def UB(mean: float, std_dev: float, s: int, delta: float) -> float: + return float(mean + (std_dev / (s**0.5)) * ((2 * np.log(1 / delta)) ** 0.5)) - def LB(mean, std_dev, s, delta): - return mean - (std_dev / (s ** 0.5)) * ((2 * np.log(1 / delta)) ** 0.5) + def LB(mean: float, std_dev: float, s: int, delta: float) -> float: + return float(mean - (std_dev / (s**0.5)) * ((2 * np.log(1 / delta)) ** 0.5)) - def recall(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: + def recall(pos_threshold: float, neg_threshold: float, sorted_pairs: list[tuple[float, bool, float]]) -> float: helper_accepted = [x for x in sorted_pairs if x[0] >= pos_threshold or x[0] <= neg_threshold] sent_to_oracle = [x for x in sorted_pairs if x[0] < pos_threshold and x[0] > neg_threshold] total_correct = sum(pair[1] * pair[2] for pair in sorted_pairs) - recall = (sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + sum(x[1] * x[2] for x in sent_to_oracle)) / total_correct + recall = ( + sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + sum(x[1] * x[2] for x in sent_to_oracle) + ) / total_correct return recall - def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: + def precision(pos_threshold: float, neg_threshold: float, sorted_pairs: list[tuple[float, bool, float]]) -> float: helper_accepted = [x for x in sorted_pairs if x[0] >= pos_threshold or x[0] <= neg_threshold] sent_to_oracle = [x for x in sorted_pairs if pos_threshold > x[0] > neg_threshold] oracle_positive = sum(x[1] for x in sent_to_oracle) @@ -66,46 +71,48 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: sorted_pairs = sorted(paired_data, key=lambda x: x[0], reverse=True) sample_size = len(sorted_pairs) - best_combination = (1,0) # initial tau_+, tau_- + best_combination = (1.0, 0.0) # initial tau_+, tau_- # Find tau_negative based on recall - tau_neg_0 = max(x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= recall_target) + tau_neg_0 = max( + x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= recall_target + ) best_combination = (best_combination[0], tau_neg_0) # Do a statistical correction to get a new target recall Z1 = [int(x[1]) * x[2] for x in sorted_pairs if x[0] >= best_combination[1]] Z2 = [int(x[1]) * x[2] for x in sorted_pairs if x[0] < best_combination[1]] - mean_z1 = np.mean(Z1) if Z1 else 0 - std_z1 = np.std(Z1) if Z1 else 0 - mean_z2 = np.mean(Z2) if Z2 else 0 - std_z2 = np.std(Z2) if Z2 else 0 + mean_z1 = float(np.mean(Z1)) if Z1 else 0.0 + std_z1 = float(np.std(Z1)) if Z1 else 0.0 + mean_z2 = float(np.mean(Z2)) if Z2 else 0.0 + std_z2 = float(np.std(Z2)) if Z2 else 0.0 - corrected_recall_target = UB(mean_z1, std_z1, sample_size, delta/2)/(UB(mean_z1, std_z1, sample_size, delta/2) + LB(mean_z2, std_z2, sample_size, delta/2)) + corrected_recall_target = UB(mean_z1, std_z1, sample_size, delta / 2) / ( + UB(mean_z1, std_z1, sample_size, delta / 2) + LB(mean_z2, std_z2, sample_size, delta / 2) + ) corrected_recall_target = min(1, corrected_recall_target) - tau_neg_prime = max(x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= corrected_recall_target) + tau_neg_prime = max( + x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= corrected_recall_target + ) best_combination = (best_combination[0], tau_neg_prime) # Do a statistical correction to get a target satisfying precision - candidate_thresholds = [1] + candidate_thresholds: list[float] = [1.0] for pair in sorted_pairs: possible_threshold = pair[0] Z = [int(x[1]) for x in sorted_pairs if x[0] >= possible_threshold] - mean_z = np.mean(Z) if Z else 0 - std_z = np.std(Z) if Z else 0 - p_l = LB(mean_z, std_z, len(Z), delta/len(sorted_pairs)) + mean_z = float(np.mean(Z)) if Z else 0.0 + std_z = float(np.std(Z)) if Z else 0.0 + p_l = LB(mean_z, std_z, len(Z), delta / len(sorted_pairs)) if p_l > precision_target: candidate_thresholds.append(possible_threshold) best_combination = (max(best_combination[1], min(candidate_thresholds)), best_combination[1]) oracle_calls = sum(1 for x in proxy_scores if best_combination[0] > x > best_combination[1]) - no_correction_sorted_pairs = [tup[:2] + (1,) for tup in sorted_pairs] + no_correction_sorted_pairs = [tup[:2] + (1.0,) for tup in sorted_pairs] lotus.logger.info(f"Sample recall: {recall(best_combination[0], best_combination[1], no_correction_sorted_pairs)}") lotus.logger.info(f"Sample precision: {precision(best_combination[0], best_combination[1], sorted_pairs)}") return best_combination, oracle_calls - -def calibrate_sem_sim_join(true_score: list[float]) -> list[float]: - true_score = np.clip(true_score, 0, 1) - return true_score \ No newline at end of file diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index 6d77f8fe..6fd9e8b0 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -2,9 +2,9 @@ import pandas as pd -import lotus +import lotus.models from lotus.templates import task_instructions -from lotus.types import SemanticAggOutput +from lotus.types import LMOutput, SemanticAggOutput def sem_agg( @@ -108,13 +108,9 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str: lotus.logger.debug(f"Prompt added to batch: {prompt}") batch.append([{"role": "user", "content": prompt}]) new_partition_ids.append(cur_partition_id) - result = model(batch) + lm_output: LMOutput = model(batch) - # TODO: this is a weird hack for model typing - if isinstance(result, tuple): - summaries, _ = result - else: - summaries = result + summaries = lm_output.outputs partition_ids = new_partition_ids new_partition_ids = [] diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py index 6336b196..82e82a98 100644 --- a/lotus/sem_ops/sem_extract.py +++ b/lotus/sem_ops/sem_extract.py @@ -4,7 +4,7 @@ import lotus from lotus.templates import task_instructions -from lotus.types import SemanticExtractOutput, SemanticExtractPostprocessOutput +from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput from .postprocessors import extract_postprocess @@ -36,15 +36,11 @@ def sem_extract( inputs.append(prompt) # call model - raw_outputs = model(inputs) - if isinstance(raw_outputs, tuple): - raw_outputs, _ = raw_outputs - else: - assert isinstance(raw_outputs, list) + lm_output: LMOutput = model(inputs) # post process results - postprocess_output = postprocessor(raw_outputs) - lotus.logger.debug(f"raw_outputs: {raw_outputs}") + postprocess_output = postprocessor(lm_output.outputs) + lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"quotes: {postprocess_output.quotes}") diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index aafd6aa6..00be6789 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -1,10 +1,12 @@ from typing import Any +import numpy as np import pandas as pd +from numpy.typing import NDArray import lotus from lotus.templates import task_instructions -from lotus.types import SemanticFilterOutput +from lotus.types import LMOutput, LogprobsForFilterCascade, SemanticFilterOutput from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds from .postprocessors import filter_postprocess @@ -45,23 +47,20 @@ def sem_filter( lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) kwargs: dict[str, Any] = {"logprobs": logprobs} - res = model(inputs, **kwargs) - if logprobs: - assert isinstance(res, tuple) - raw_outputs, raw_logprobs = res - else: - assert isinstance(res, list) - raw_outputs = res - - postprocess_output = filter_postprocess(raw_outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"]) + lm_output: LMOutput = model(inputs, **kwargs) + + postprocess_output = filter_postprocess( + lm_output.outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"] + ) lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") - return SemanticFilterOutput(**postprocess_output.model_dump(), logprobs=raw_logprobs if logprobs else None) + return SemanticFilterOutput(**postprocess_output.model_dump(), logprobs=lm_output.logprobs if logprobs else None) + def learn_filter_cascade_thresholds( - sample_df_txt: str, + sample_df_txt: list[str], lm: lotus.models.LM, formatted_usr_instr: str, default: bool, @@ -69,14 +68,14 @@ def learn_filter_cascade_thresholds( precision_target: float, delta: float, helper_true_probs: list[float], - sample_correction_factors: list[float], - examples_df_txt: str | None = None, - examples_answers: str | None = None, - cot_reasoning: list | None = None, + sample_correction_factors: NDArray[np.float64], + examples_df_txt: list[str] | None = None, + examples_answers: list[bool] | None = None, + cot_reasoning: list[str] | None = None, strategy: str | None = None, ) -> tuple[float, float]: - """Automatically learns the cascade thresholds for a cascade - filter given a sample of data and doing a search across threshold + """Automatically learns the cascade thresholds for a cascade + filter given a sample of data and doing a search across threshold to see what threshold gives the best accuracy.""" try: @@ -97,7 +96,7 @@ def learn_filter_cascade_thresholds( sample_correction_factors=sample_correction_factors, recall_target=recall_target, precision_target=precision_target, - delta=delta + delta=delta, ) lotus.logger.info(f"Learned cascade thresholds: {best_combination}") @@ -105,7 +104,8 @@ def learn_filter_cascade_thresholds( except Exception as e: lotus.logger.error(f"Error while learning filter cascade thresholds: {e}") - return None + raise e + @pd.api.extensions.register_dataframe_accessor("sem_filter") class SemFilterDataframe: @@ -198,14 +198,16 @@ def __call__( if helper_strategy == "cot": helper_cot_reasoning = examples["Reasoning"].tolist() - + if learn_cascade_threshold_sample_percentage and lotus.settings.helper_lm: if helper_strategy == "cot": lotus.logger.error("CoT not supported for helper models in cascades.") raise Exception if recall_target is None or precision_target is None or failure_probability is None: - lotus.logger.error("Recall target, precision target, and confidence need to be specified for learned thresholds.") + lotus.logger.error( + "Recall target, precision target, and confidence need to be specified for learned thresholds." + ) raise Exception # Run small LM and get logits @@ -221,11 +223,14 @@ def __call__( strategy=helper_strategy, ) helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs - _, _, helper_true_probs = lotus.settings.helper_lm.format_logprobs_for_filter_cascade(helper_logprobs) - - helper_true_probs = calibrate_llm_logprobs(helper_true_probs) + formatted_helper_logprobs: LogprobsForFilterCascade = ( + lotus.settings.helper_lm.format_logprobs_for_filter_cascade(helper_logprobs) + ) + helper_true_probs = calibrate_llm_logprobs(formatted_helper_logprobs.true_probs) - sample_indices, correction_factors = importance_sampling(helper_true_probs, learn_cascade_threshold_sample_percentage) + sample_indices, correction_factors = importance_sampling( + helper_true_probs, learn_cascade_threshold_sample_percentage + ) sample_df = self._obj.loc[sample_indices] sample_df_txt = task_instructions.df2text(sample_df, col_li) sample_helper_true_probs = [helper_true_probs[i] for i in sample_indices] @@ -238,7 +243,7 @@ def __call__( default=default, recall_target=recall_target, precision_target=precision_target, - delta=failure_probability/2, + delta=failure_probability / 2, helper_true_probs=sample_helper_true_probs, sample_correction_factors=sample_correction_factors, examples_df_txt=examples_df_txt, @@ -261,7 +266,13 @@ def __call__( true_prob = helper_true_probs[idx_i] if true_prob >= pos_cascade_threshold or true_prob <= neg_cascade_threshold: high_conf_idxs.add(idx_i) - helper_outputs[idx_i] = True if true_prob >= pos_cascade_threshold else False if true_prob <= neg_cascade_threshold else helper_outputs[idx_i] + helper_outputs[idx_i] = ( + True + if true_prob >= pos_cascade_threshold + else False + if true_prob <= neg_cascade_threshold + else helper_outputs[idx_i] + ) lotus.logger.info(f"Num routed to smaller model: {len(high_conf_idxs)}") stats["num_routed_to_helper_model"] = len(high_conf_idxs) diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py index 05f6bacc..0d01cc55 100644 --- a/lotus/sem_ops/sem_join.py +++ b/lotus/sem_ops/sem_join.py @@ -1,6 +1,5 @@ from typing import Any -import numpy as np import pandas as pd import lotus @@ -136,24 +135,17 @@ def sem_join_cascade( assert helper_logprobs is not None high_conf_idxs = set() - for idx_i in range(len(helper_outputs)): - tokens: list[str] - confidences: np.ndarray[Any, np.dtype[np.float64]] - # Get the logprobs - if lotus.settings.helper_lm.provider == "vllm": - tokens = helper_logprobs[idx_i]["tokens"] - confidences = np.exp(helper_logprobs[idx_i]["token_logprobs"]) - elif lotus.settings.helper_lm.provider == "openai": - content: list[dict[str, Any]] = helper_logprobs[idx_i]["content"] - tokens = [content[t_idx]["token"] for t_idx in range(len(content))] - confidences = np.exp([content[t_idx]["logprob"] for t_idx in range(len(content))]) + # Get the logprobs in a standardized format + formatted_logprobs = lotus.settings.helper_lm.format_logprobs_for_cascade(helper_logprobs) + tokens, confidences = formatted_logprobs.tokens, formatted_logprobs.confidences + for doc_idx in range(len(helper_outputs)): # Find where true/false is said and look at confidence - for idx_j in range(len(tokens) - 1, -1, -1): - if tokens[idx_j].strip(" \n").lower() in ["true", "false"]: - conf = confidences[idx_j] - if conf >= cascade_threshold: - high_conf_idxs.add(idx_i) + for token_idx in range(len(tokens[doc_idx]) - 1, -1, -1): + if tokens[doc_idx][token_idx].strip(" \n").lower() in ["true", "false"]: + confidence = confidences[doc_idx][token_idx] + if confidence >= cascade_threshold: + high_conf_idxs.add(doc_idx) # Send low confidence samples to large LM low_conf_idxs = sorted([i for i in range(len(helper_outputs)) if i not in high_conf_idxs]) diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index 7ee84e62..9074c094 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -4,7 +4,7 @@ import lotus from lotus.templates import task_instructions -from lotus.types import SemanticMapOutput, SemanticMapPostprocessOutput +from lotus.types import LMOutput, SemanticMapOutput, SemanticMapPostprocessOutput from .postprocessors import map_postprocess @@ -45,14 +45,11 @@ def sem_map( inputs.append(prompt) # call model - raw_outputs = model(inputs) - assert isinstance(raw_outputs, list) and all( - isinstance(item, str) for item in raw_outputs - ), "Model must return a list of strings" + lm_output: LMOutput = model(inputs) # post process results - postprocess_output = postprocessor(raw_outputs, strategy in ["cot", "zs-cot"]) - lotus.logger.debug(f"raw_outputs: {raw_outputs}") + postprocess_output = postprocessor(lm_output.outputs, strategy in ["cot", "zs-cot"]) + lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py index 49da6a57..d9feb20f 100644 --- a/lotus/sem_ops/sem_search.py +++ b/lotus/sem_ops/sem_search.py @@ -3,6 +3,7 @@ import pandas as pd import lotus +from lotus.types import RerankerOutput, RMOutput @pd.api.extensions.register_dataframe_accessor("sem_search") @@ -55,9 +56,9 @@ def __call__( search_K = K while True: - scores, doc_idxs = rm(query, search_K) - doc_idxs = doc_idxs[0] - scores = scores[0] + rm_output: RMOutput = rm(query, search_K) + doc_idxs = rm_output.indices[0] + scores = rm_output.distances[0] assert len(doc_idxs) == len(scores) postfiltered_doc_idxs = [] @@ -83,7 +84,8 @@ def __call__( if n_rerank is not None: docs = new_df[col_name].tolist() - reranked_idxs = lotus.settings.reranker(query, docs, n_rerank) + reranked_output: RerankerOutput = lotus.settings.reranker(query, docs, n_rerank) + reranked_idxs = reranked_output.indices new_df = new_df.iloc[reranked_idxs] return new_df diff --git a/lotus/sem_ops/sem_sim_join.py b/lotus/sem_ops/sem_sim_join.py index d0094f74..04be885f 100644 --- a/lotus/sem_ops/sem_sim_join.py +++ b/lotus/sem_ops/sem_sim_join.py @@ -3,6 +3,8 @@ import pandas as pd import lotus +from lotus.models import RM +from lotus.types import RMOutput @pd.api.extensions.register_dataframe_accessor("sem_sim_join") @@ -46,8 +48,11 @@ def __call__( raise ValueError("Other Series must have a name") other = pd.DataFrame({other.name: other}) - # get rmodel and index rm = lotus.settings.rm + if not isinstance(rm, RM): + raise ValueError( + "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()" + ) # load query embeddings from index if they exist if left_on in self._obj.attrs.get("index_dirs", []): @@ -71,7 +76,9 @@ def __call__( rm.load_index(col_index_dir) assert rm.index_dir == col_index_dir - distances, indices = rm(queries, K) + rm_output: RMOutput = rm(queries, K) + distances = rm_output.distances + indices = rm_output.indices other_index_set = set(other.index) join_results = [] diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index a2b98a44..1db8b514 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -7,7 +7,7 @@ import lotus from lotus.templates import task_instructions -from lotus.types import SemanticTopKOutput +from lotus.types import LMOutput, SemanticTopKOutput def get_match_prompt_binary( @@ -59,14 +59,12 @@ def compare_batch_binary( pairs: list[tuple[str, str]], user_instruction: str, strategy: str | None = None ) -> tuple[list[bool], int]: match_prompts = [] - results = [] tokens = 0 for doc1, doc2 in pairs: match_prompts.append(get_match_prompt_binary(doc1, doc2, user_instruction, strategy=strategy)) tokens += lotus.settings.lm.count_tokens(match_prompts[-1]) - - results = lotus.settings.lm(match_prompts) - results = list(map(parse_ans_binary, results)) + lm_results: LMOutput = lotus.settings.lm(match_prompts) + results: list[bool] = list(map(parse_ans_binary, lm_results.outputs)) return results, tokens @@ -109,8 +107,8 @@ def compare_batch_binary_cascade( large_match_prompts.append(match_prompts[i]) large_tokens += lotus.settings.lm.count_tokens(large_match_prompts[-1]) - results = lotus.settings.lm(large_match_prompts) - for idx, res in enumerate(results): + large_lm_results: LMOutput = lotus.settings.lm(large_match_prompts) + for idx, res in enumerate(large_lm_results.outputs): new_idx = low_conf_idxs[idx] parsed_res = parse_ans_binary(res) parsed_results[new_idx] = parsed_res @@ -161,7 +159,7 @@ def llm_naive_sort( def llm_quicksort( docs: list[str], user_instruction: str, - k: int, + K: int, embedding: bool = False, strategy: str | None = None, cascade_threshold: float | None = None, @@ -172,7 +170,7 @@ def llm_quicksort( Args: docs (list[str]): The list of documents to sort. user_instruction (str): The user instruction for sorting. - k (int): The number of documents to return. + K (int): The number of documents to return. embedding (bool): Whether to use embedding optimization. cascade_threshold (float | None): The confidence threshold for cascading to a larger model. @@ -189,14 +187,14 @@ def llm_quicksort( stats["total_small_calls"] = 0 stats["total_large_calls"] = 0 - def partition(indexes: list[int], low: int, high: int, k: int) -> int: + def partition(indexes: list[int], low: int, high: int, K: int) -> int: nonlocal stats i = low - 1 if embedding: # With embedding optimization - if k <= high - low: - pivot_value = heapq.nsmallest(k, indexes[low : high + 1])[-1] + if K <= high - low: + pivot_value = heapq.nsmallest(K, indexes[low : high + 1])[-1] else: pivot_value = heapq.nsmallest(int((high - low + 1) / 2), indexes[low : high + 1])[-1] pivot_index = indexes.index(pivot_value) @@ -233,21 +231,21 @@ def partition(indexes: list[int], low: int, high: int, k: int) -> int: indexes[i + 1], indexes[high] = indexes[high], indexes[i + 1] return i + 1 - def quicksort_recursive(indexes: list[int], low: int, high: int, k: int) -> None: + def quicksort_recursive(indexes: list[int], low: int, high: int, K: int) -> None: if high <= low: return if low < high: - pi = partition(indexes, low, high, k) + pi = partition(indexes, low, high, K) left_size = pi - low - if left_size + 1 >= k: - quicksort_recursive(indexes, low, pi - 1, k) + if left_size + 1 >= K: + quicksort_recursive(indexes, low, pi - 1, K) else: quicksort_recursive(indexes, low, pi - 1, left_size) - quicksort_recursive(indexes, pi + 1, high, k - left_size - 1) + quicksort_recursive(indexes, pi + 1, high, K - left_size - 1) indexes = list(range(len(docs))) - quicksort_recursive(indexes, 0, len(indexes) - 1, k) + quicksort_recursive(indexes, 0, len(indexes) - 1, K) return SemanticTopKOutput(indexes=indexes, stats=stats) @@ -268,14 +266,14 @@ def __lt__(self, other: "HeapDoc") -> bool: prompt = get_match_prompt_binary(self.doc, other.doc, self.user_instruction, strategy=self.strategy) HeapDoc.num_calls += 1 HeapDoc.total_tokens += lotus.settings.lm.count_tokens(prompt) - result = lotus.settings.lm(prompt) - return parse_ans_binary(result[0]) + result: LMOutput = lotus.settings.lm([prompt]) + return parse_ans_binary(result.outputs[0]) def llm_heapsort( docs: list[str], user_instruction: str, - k: int, + K: int, strategy: str | None = None, ) -> SemanticTopKOutput: """ @@ -284,7 +282,7 @@ def llm_heapsort( Args: docs (list[str]): The list of documents to sort. user_instruction (str): The user instruction for sorting. - k (int): The number of documents to return. + K (int): The number of documents to return. Returns: SemanticTopKOutput: The indexes of the top k documents and stats. @@ -294,7 +292,7 @@ def llm_heapsort( HeapDoc.strategy = strategy N = len(docs) heap = [HeapDoc(docs[idx], user_instruction, idx) for idx in range(N)] - heap = heapq.nsmallest(k, heap) + heap = heapq.nsmallest(K, heap) indexes = [heapq.heappop(heap).idx for _ in range(len(heap))] stats = {"total_tokens": HeapDoc.total_tokens, "total_llm_calls": HeapDoc.num_calls} diff --git a/lotus/settings.py b/lotus/settings.py index ee8acba4..a928880c 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -1,3 +1,5 @@ +# type: ignore + import copy import threading from contextlib import contextmanager @@ -113,4 +115,4 @@ def __repr__(self) -> str: # set defaults settings = Settings() -settings.configure(cascade_is_weight=0.5, cascade_num_calibration_quantiles=50) \ No newline at end of file +settings.configure(cascade_is_weight=0.5, cascade_num_calibration_quantiles=50) diff --git a/lotus/types.py b/lotus/types.py index a33339d5..28cbcfe9 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -1,18 +1,51 @@ from typing import Any +from litellm.types.utils import ChatCompletionTokenLogprob from pydantic import BaseModel +################################################################################ +# Mixins +################################################################################ class StatsMixin(BaseModel): stats: dict[str, Any] | None = None -# TODO: Figure out better logprobs type class LogprobsMixin(BaseModel): - logprobs: list[dict[str, Any]] | None = None + # for each response, we have a list of tokens, and for each token, we have a ChatCompletionTokenLogprob + logprobs: list[list[ChatCompletionTokenLogprob]] | None = None -class SemanticMapPostprocessOutput(StatsMixin, LogprobsMixin): +################################################################################ +# LM related +################################################################################ +class LMOutput(LogprobsMixin): + outputs: list[str] + + +class LMStats(BaseModel): + class TotalUsage(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + total_cost: float = 0.0 + + total_usage: TotalUsage = TotalUsage() + + +class LogprobsForCascade(BaseModel): + tokens: list[list[str]] + confidences: list[list[float]] + + +class LogprobsForFilterCascade(LogprobsForCascade): + true_probs: list[float] + + +################################################################################ +# Semantic operation outputs +################################################################################ +class SemanticMapPostprocessOutput(BaseModel): raw_outputs: list[str] outputs: list[str] explanations: list[str | None] @@ -55,3 +88,18 @@ class SemanticJoinOutput(StatsMixin): class SemanticTopKOutput(StatsMixin): indexes: list[int] + + +################################################################################ +# RM related +################################################################################ +class RMOutput(BaseModel): + distances: list[list[float]] + indices: list[list[int]] + + +################################################################################ +# Reranker related +################################################################################ +class RerankerOutput(BaseModel): + indices: list[int] diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..0d73c326 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,6 @@ +[mypy] +python_version = 3.10 +ignore_missing_imports = True +strict_optional = True +show_error_codes = True +files = lotus/**/*.py diff --git a/pyproject.toml b/pyproject.toml index 1ce25c08..6a170ce0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "lotus-ai" -version = "0.2.2" +version = "0.3.0" description = "lotus" readme = "README.md" authors = [ @@ -24,8 +24,8 @@ classifiers = [ dependencies = [ "backoff>=2.2.1,<3.0.0", "faiss-cpu>=1.8.0.post1,<2.0.0", + "litellm>=1.51.0,<2.0.0", "numpy>=1.25.0,<2.0.0", - "openai>=1.35.13,<2.0.0", "pandas>=2.0.0,<3.0.0", "sentence-transformers>=3.0.1,<4.0.0", "tiktoken>=0.7.0,<1.0.0", diff --git a/requirements-dev.txt b/requirements-dev.txt index 883ede67..9fc528f5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,7 @@ -r requirements.txt # Additional development dependencies -ruff==0.5.2 +ruff==0.7.2 +mypy==1.13.0 +pytest==8.3.3 +pre-commit==4.0.1 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ba74caf9..655dde54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ backoff==2.2.1 faiss-cpu==1.8.0.post1 +litellm==1.51.0 numpy==1.26.4 -openai==1.35.13 pandas==2.2.2 sentence-transformers==3.0.1 tiktoken==0.7.0