Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into gharshit/multimodal
Browse files Browse the repository at this point in the history
  • Loading branch information
harshitgupta412 committed Nov 15, 2024
2 parents df709e0 + 2480012 commit 3675563
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
14 changes: 8 additions & 6 deletions .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
MODEL_NAME_TO_ENABLED = {
"gpt-4o-mini": ENABLE_OPENAI_TESTS,
"gpt-4o": ENABLE_OPENAI_TESTS,
"ollama/llama3.2": ENABLE_OLLAMA_TESTS,
"ollama/llama3.1": ENABLE_OLLAMA_TESTS,
}
ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled])

Expand Down Expand Up @@ -53,7 +53,7 @@ def print_usage_after_each_test(setup_models):
################################################################################
# Standard tests
################################################################################
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2"))
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1"))
def test_filter_operation(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm)
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_top_k(setup_models, model):
assert top_2_expected == top_2_actual


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2"))
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1"))
def test_join(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm)
Expand All @@ -110,7 +110,7 @@ def test_join(setup_models, model):
assert joined_pairs == expected_pairs


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2"))
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1"))
def test_map_fewshot(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm)
Expand All @@ -122,8 +122,10 @@ def test_map_fewshot(setup_models, model):
user_instruction = "What state is {School} in? Respond only with the two-letter abbreviation."
df = df.sem_map(user_instruction, examples=examples_df, suffix="State")

# clean up the state names to be more robust to free-form text
df["State"] = df["State"].str[-2:].str.lower()
pairs = set(zip(df["School"], df["State"]))
expected_pairs = set([("UC Berkeley", "CA"), ("Carnegie Mellon", "PA")])
expected_pairs = set([("UC Berkeley", "ca"), ("Carnegie Mellon", "pa")])
assert pairs == expected_pairs


Expand Down Expand Up @@ -285,7 +287,7 @@ def test_format_logprobs_for_filter_cascade(setup_models, model):
################################################################################
# Token counting tests
################################################################################
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2"))
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1"))
def test_count_tokens(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm)
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ jobs:
sleep 1
timeout=$((timeout - 1))
done
docker exec $(docker ps -q) ollama run llama3.2
docker exec $(docker ps -q) ollama run llama3.1
- name: Run LM tests
env:
Expand Down

0 comments on commit 3675563

Please sign in to comment.