From 13d01618ec3f0e3c55730e64af225f24899c5304 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Mon, 13 Jan 2025 18:03:54 -0800 Subject: [PATCH] ruff format and removed excesss changes to mnimize PR --- examples/op_examples/filter.py | 1 - examples/op_examples/filter_cascade.py | 1 - lotus/sem_ops/sem_filter.py | 10 ++++++++-- lotus/templates/task_instructions.py | 8 ++++---- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index f20aa593..a1acc00d 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,7 +6,6 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) - data = { "Course Name": [ "Probability and Random Processes", diff --git a/examples/op_examples/filter_cascade.py b/examples/op_examples/filter_cascade.py index a1b94f4d..104c8410 100644 --- a/examples/op_examples/filter_cascade.py +++ b/examples/op_examples/filter_cascade.py @@ -8,7 +8,6 @@ gpt_4o = LM("gpt-4o") lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) - data = { "Course Name": [ "Probability and Random Processes", diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index dfb3f05f..7a8cf4b0 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -27,7 +27,7 @@ def sem_filter( safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Filtering", - additional_cot_instructions: str = "" + additional_cot_instructions: str = "", ) -> SemanticFilterOutput: """ Filters a list of documents based on a given user instruction using a language model. @@ -49,7 +49,13 @@ def sem_filter( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.filter_formatter( - doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy, reasoning_instructions=additional_cot_instructions + doc, + user_instruction, + examples_multimodal_data, + examples_answers, + cot_reasoning, + strategy, + reasoning_instructions=additional_cot_instructions, ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index fbef1ea2..a71acd8c 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -108,10 +108,10 @@ def filter_formatter( assert examples_answer is not None assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list) assert len(examples_multimodal_data) == len(examples_answer) - + if cot_reasoning: - # users don't have to provide cot reasoning examples - # but if they do, the number of examples must match + # users don't have to provide cot reasoning examples + # but if they do, the number of examples must match assert isinstance(cot_reasoning, list) assert len(examples_multimodal_data) == len(examples_answer) == len(cot_reasoning) @@ -128,7 +128,7 @@ def filter_formatter( content = cot_formatter("Reasoning omitted", str(ex_ans)) else: content = answer_only_formatter(str(ex_ans)) - + messages.extend( [ user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"),