From 151ccd755bd3b998c7ab5c6b89e5257acef37074 Mon Sep 17 00:00:00 2001
From: dhruviyer
Date: Sat, 18 Jan 2025 11:31:10 -0800
Subject: [PATCH] Added tests for COT and few shot COT for semantic filter

---
NOTE(review): the line structure and indentation of this patch were
reconstructed from a whitespace-collapsed copy. The indentation of the
context lines in both hunks is assumed, not verified; confirm against the
tree or apply with `git apply --ignore-whitespace`. The few-shot example
typo "villiage" was corrected to "village", so the post-image blob id on
the lm_tests.py `index` line no longer matches the added content.

 .github/tests/lm_tests.py   | 63 +++++++++++++++++++++++++++++++++++++
 lotus/sem_ops/sem_filter.py |  1 -
 2 files changed, 63 insertions(+), 1 deletion(-)

diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py
index 3a18bf8..d418726 100644
--- a/.github/tests/lm_tests.py
+++ b/.github/tests/lm_tests.py
@@ -210,6 +210,69 @@ def test_sem_extract(setup_models, model):
         ), f"Number of Championships '{row['Number of Championships']}' not found in '{row['Number of Championships_quote']}'"
 
 
+################################################################################
+# CoT tests
+################################################################################
+@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1"))
+def test_filter_operation_cot(setup_models, model):
+    lm = setup_models[model]
+    lotus.settings.configure(lm=lm)
+
+    # Zero-shot CoT filter: only rows 0 and 1 end with the speaker still holding an apple
+    data = {
+        "Text": [
+            "I had two apples, then I gave away one",
+            "My friend gave me an apple",
+            "I gave away both of my apples",
+            "I gave away my apple, then a friend gave me his apple, then I threw my apple away",
+        ]
+    }
+    df = pd.DataFrame(data)
+    user_instruction = "{Text} I have at least one apple"
+    filtered_df = df.sem_filter(user_instruction, strategy="cot")
+    expected_df = pd.DataFrame({"Text": ["I had two apples, then I gave away one", "My friend gave me an apple"]})
+    assert filtered_df.equals(expected_df)
+
+
+@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1"))
+def test_filter_operation_cot_fewshot(setup_models, model):
+    lm = setup_models[model]
+    lotus.settings.configure(lm=lm)
+
+    # Few-shot CoT filter: "Five, Four, Three" is decreasing, the other two rows are increasing
+    data = {
+        "Sequence": [
+            "Five, Four, Three",
+            "A, B, C",
+            "Pond, Lake, Ocean",
+        ]
+    }
+    df = pd.DataFrame(data)
+    examples = {
+        "Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "village, town, city"],
+        "Answer": [True, True, True],
+        "Reasoning": [
+            "1, 2, 3 is an increasing sequence of numbers",
+            "penny, nickel, dime, quarter is an increasing sequence of coins",
+            "village, town, city is an increasing sequence of settlements",
+        ],
+    }
+    examples_df = pd.DataFrame(examples)
+
+    user_instruction = "{Sequence} is increasing"
+    filtered_df = df.sem_filter(user_instruction, strategy="cot", examples=examples_df)
+    expected_df = pd.DataFrame(
+        {
+            "Sequence": [
+                "A, B, C",
+                "Pond, Lake, Ocean",
+            ]
+        },
+        index=[1, 2],  # surviving rows are expected to keep their original labels
+    )
+    assert filtered_df.equals(expected_df)
+
+
 ################################################################################
 # Cascade tests
 ################################################################################
diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py
index 7a8cf4b..9ffcc64 100644
--- a/lotus/sem_ops/sem_filter.py
+++ b/lotus/sem_ops/sem_filter.py
@@ -211,7 +211,6 @@ def __call__(
             examples_answers = examples["Answer"].tolist()
 
             if strategy == "cot":
-                return_explanations = True
                 cot_reasoning = examples["Reasoning"].tolist()
 
         pos_cascade_threshold, neg_cascade_threshold = None, None