
Commit

Reformat
sidjha1 committed Nov 3, 2024
1 parent 77271fc commit 1d8aa5d
Showing 2 changed files with 9 additions and 2 deletions.
9 changes: 8 additions & 1 deletion .github/tests/lm_tests.py
@@ -1,4 +1,5 @@
 import os
+
 import pandas as pd
 import pytest
 from tokenizers import Tokenizer
@@ -19,10 +20,11 @@
 MODEL_NAME_TO_ENABLED = {
     "gpt-4o-mini": ENABLE_OPENAI_TESTS,
     "gpt-4o": ENABLE_OPENAI_TESTS,
-    "ollama/llama3.2": ENABLE_OLLAMA_TESTS
+    "ollama/llama3.2": ENABLE_OLLAMA_TESTS,
 }
 ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled])
 
+
 def get_enabled(*candidate_models: tuple) -> list[str]:
     return [model for model in candidate_models if model in ENABLED_MODEL_NAMES]

@@ -64,6 +66,7 @@ def test_filter_operation(setup_models, model):
     expected_df = pd.DataFrame({"Text": ["I am really excited to go to class today!"]})
     assert filtered_df.equals(expected_df)
 
+
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
 def test_top_k(setup_models, model):
     lm = setup_models[model]
@@ -105,6 +108,7 @@ def test_join(setup_models, model):
     expected_pairs = set([("UC Berkeley", "Public School"), ("Stanford", "Private School")])
     assert joined_pairs == expected_pairs
 
+
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2"))
 def test_map_fewshot(setup_models, model):
     lm = setup_models[model]
@@ -135,6 +139,7 @@ def test_agg_then_map(setup_models, model):
     cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output")
     assert cleaned_df["final_output"].values[0].lower().strip(".,!?\"'") == "john"
 
+
 ################################################################################
 # Cascade tests
 ################################################################################
@@ -207,6 +212,7 @@ def test_filter_cascade(setup_models):
     assert "I am very sad" not in filtered_df["Text"].values
     assert stats["filters_resolved_by_helper_model"] > 0, stats
 
+
 @pytest.mark.skipif(not ENABLE_OPENAI_TESTS, reason="Skipping test because OpenAI tests are not enabled")
 def test_join_cascade(setup_models):
     models = setup_models
@@ -234,6 +240,7 @@ def test_join_cascade(setup_models):
     assert stats["filters_resolved_by_large_model"] == 4, stats
     assert stats["filters_resolved_by_helper_model"] == 0, stats
 
+
 ################################################################################
 # Token counting tests
 ################################################################################
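For readers skimming this diff, the hunks above all touch the same test-gating pattern: MODEL_NAME_TO_ENABLED maps each model name to an enable flag, and get_enabled() filters pytest parametrizations down to the models whose backend is turned on for the current run. Below is a minimal, self-contained sketch of that pattern, reconstructed from the hunks rather than copied from the full file; the environment-variable parsing and the test body are assumptions for illustration only.

import os

import pytest

# Assumed: the enable flags come from environment variables
# (their real definitions are not shown in this diff).
ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true"
ENABLE_OLLAMA_TESTS = os.getenv("ENABLE_OLLAMA_TESTS", "false").lower() == "true"

MODEL_NAME_TO_ENABLED = {
    "gpt-4o-mini": ENABLE_OPENAI_TESTS,
    "gpt-4o": ENABLE_OPENAI_TESTS,
    "ollama/llama3.2": ENABLE_OLLAMA_TESTS,
}
ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled])


def get_enabled(*candidate_models: tuple) -> list[str]:
    # Keep only the candidates whose backing service is enabled for this run.
    return [model for model in candidate_models if model in ENABLED_MODEL_NAMES]


# Parametrizing over get_enabled(...) yields zero parameters when none of the
# candidate models are enabled, so pytest skips the test instead of failing it.
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2"))
def test_example(model):  # hypothetical test, for illustration only
    assert model in ENABLED_MODEL_NAMES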
2 changes: 1 addition & 1 deletion lotus/models/lm.py
@@ -1,7 +1,7 @@
 from typing import Any
 
-import numpy as np
 import litellm
+import numpy as np
 from litellm import batch_completion, completion_cost
 from litellm.types.utils import ChatCompletionTokenLogprob, Choices, ModelResponse
 from litellm.utils import token_counter
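The net effect of this hunk is import ordering only. The resulting import block in lotus/models/lm.py, reconstructed from the hunk above, is shown below; the grouping (standard library first, then third-party packages in alphabetical order) matches common isort/Ruff conventions, though the repository's formatter configuration is not visible in this diff.

from typing import Any

import litellm
import numpy as np
from litellm import batch_completion, completion_cost
from litellm.types.utils import ChatCompletionTokenLogprob, Choices, ModelResponse
from litellm.utils import token_counter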
