Skip to content

Commit

Permalink
Added tests for COT and few shot COT for semantic filter
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruviyer committed Jan 18, 2025
1 parent 13d0161 commit 151ccd7
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
63 changes: 63 additions & 0 deletions .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,69 @@ def test_sem_extract(setup_models, model):
), f"Number of Championships '{row['Number of Championships']}' not found in '{row['Number of Championships_quote']}'"


################################################################################
# CoT tests
################################################################################
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1"))
def test_filter_operation_cot(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm)

# Test filter operation on an easy dataframe
data = {
"Text": [
"I had two apples, then I gave away one",
"My friend gave me an apple",
"I gave away both of my apples",
"I gave away my apple, then a friend gave me his apple, then I threw my apple away",
]
}
df = pd.DataFrame(data)
user_instruction = "{Text} I have at least one apple"
filtered_df = df.sem_filter(user_instruction, strategy="cot")
expected_df = pd.DataFrame({"Text": ["I had two apples, then I gave away one", "My friend gave me an apple"]})
assert filtered_df.equals(expected_df)


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1"))
def test_filter_operation_cot_fewshot(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm)

# Test filter operation on an easy dataframe
data = {
"Sequence": [
"Five, Four, Three",
"A, B, C",
"Pond, Lake, Ocean",
]
}
df = pd.DataFrame(data)
examples = {
"Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "villiage, town, city"],
"Answer": [True, True, True],
"Reasoning": [
"1, 2, 3 is an increasing sequence of numbers",
"penny, nickel, dime, quarter is an increasing sequence of coins",
"villiage, town, city is an increasing sequence of settlements",
],
}
examples_df = pd.DataFrame(examples)

user_instruction = "{Sequence} is increasing"
filtered_df = df.sem_filter(user_instruction, strategy="cot", examples=examples_df)
expected_df = pd.DataFrame(
{
"Sequence": [
"A, B, C",
"Pond, Lake, Ocean",
]
},
index=[1, 2],
)
assert filtered_df.equals(expected_df)


################################################################################
# Cascade tests
################################################################################
Expand Down
1 change: 0 additions & 1 deletion lotus/sem_ops/sem_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ def __call__(
examples_answers = examples["Answer"].tolist()

if strategy == "cot":
return_explanations = True
cot_reasoning = examples["Reasoning"].tolist()

pos_cascade_threshold, neg_cascade_threshold = None, None
Expand Down

0 comments on commit 151ccd7

Please sign in to comment.