Skip to content

Commit

Permalink
added a test to validate that examples of reasoning are not always ne…
Browse files Browse the repository at this point in the history
…eded with CoT
  • Loading branch information
dhruviyer committed Jan 18, 2025
1 parent 151ccd7 commit cfaf28e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
32 changes: 32 additions & 0 deletions .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,38 @@ def test_filter_operation_cot_fewshot(setup_models, model):
)
assert filtered_df.equals(expected_df)

@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1"))
def test_filter_operation_cot_fewshot_no_reasoning(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm)

# Test filter operation on an easy dataframe
data = {
"Sequence": [
"Five, Four, Three",
"A, B, C",
"Pond, Lake, Ocean",
]
}
df = pd.DataFrame(data)
examples = {
"Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "villiage, town, city"],
"Answer": [True, True, True],
}
examples_df = pd.DataFrame(examples)

user_instruction = "{Sequence} is increasing"
filtered_df = df.sem_filter(user_instruction, strategy="cot", examples=examples_df)
expected_df = pd.DataFrame(
{
"Sequence": [
"A, B, C",
"Pond, Lake, Ocean",
]
},
index=[1, 2],
)
assert filtered_df.equals(expected_df)

################################################################################
# Cascade tests
Expand Down
6 changes: 3 additions & 3 deletions lotus/sem_ops/sem_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def __call__(
examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li)
examples_answers = examples["Answer"].tolist()

if strategy == "cot":
if strategy == "cot" and "Reasoning" in examples.columns:
cot_reasoning = examples["Reasoning"].tolist()

pos_cascade_threshold, neg_cascade_threshold = None, None
Expand All @@ -224,8 +224,8 @@ def __call__(
helper_examples_multimodal_data = task_instructions.df2multimodal_info(helper_examples, col_li)
helper_examples_answers = helper_examples["Answer"].tolist()

if helper_strategy == "cot":
helper_cot_reasoning = examples["Reasoning"].tolist()
if helper_strategy == "cot" and "Reasoning" in helper_examples.columns:
helper_cot_reasoning = helper_examples["Reasoning"].tolist()

if cascade_args and lotus.settings.helper_lm:
if helper_strategy == "cot":
Expand Down

0 comments on commit cfaf28e

Please sign in to comment.