From a77a580e3f0cc1140cf5dfd5b457c9f022d20f14 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 9 Jan 2025 19:35:46 -0800 Subject: [PATCH 01/18] cot and zs-cot support for semantic filter --- examples/op_examples/filter.py | 1 + examples/op_examples/filter_cascade.py | 1 + lotus/sem_ops/postprocessors.py | 88 ++++++++++----------- lotus/sem_ops/sem_filter.py | 4 +- lotus/templates/task_instructions.py | 101 ++++++++----------------- 5 files changed, 75 insertions(+), 120 deletions(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index a1acc00d..f20aa593 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,6 +6,7 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) + data = { "Course Name": [ "Probability and Random Processes", diff --git a/examples/op_examples/filter_cascade.py b/examples/op_examples/filter_cascade.py index 104c8410..a1b94f4d 100644 --- a/examples/op_examples/filter_cascade.py +++ b/examples/op_examples/filter_cascade.py @@ -8,6 +8,7 @@ gpt_4o = LM("gpt-4o") lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) + data = { "Course Name": [ "Probability and Random Processes", diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index d531099c..e11dbcf2 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -7,6 +7,33 @@ SemanticMapPostprocessOutput, ) +def cot_postprocessor(llm_answers: list[str]): + outputs: list[str | None] = [] + explanations: list[str | None] = [] + for llm_answer in llm_answers: + import xml.etree.ElementTree as ET + try: + root = ET.fromstring(f"{llm_answer}") + reasoning = root.find('Reasoning') + answer = root.find('Answer') + + if reasoning is None or answer is None: + raise ValueError("Failed to parse reasoning or answer") + + reasoning = reasoning.text.strip() if reasoning.text else None + answer = answer.text.strip() if answer.text else "" + + explanations.append(reasoning) + outputs.append(answer) + + lotus.logger.debug(f"{llm_answer}") + + except (ET.ParseError, ValueError): + lotus.logger.debug(f"\t Failed to parse reasoning and answer from: {llm_answer}") + explanations.append(None) + outputs.append("") + + return outputs, explanations def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput: """ @@ -79,49 +106,9 @@ def extract_postprocess(llm_answers: list[str]) -> SemanticExtractPostprocessOut return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data) - -def filter_postprocess_cot(llm_answers: list[str], default: bool) -> SemanticFilterPostprocessOutput: - """ - Postprocess the output of the filter operator with CoT reasoning. - - Args: - llm_answers (list[str]): The list of llm answers. - default (bool): The default value to use if we fail to parse the answer. 
- - Returns: - SemanticFilterPostprocessOutput - """ - outputs: list[bool] = [] - explanations: list[str | None] = [] - - for llm_answer in llm_answers: - reasoning_idx = llm_answer.find("Reasoning:\n") - if reasoning_idx == -1: - reasoning_idx = 0 - else: - reasoning_idx += len("Reasoning:\n") - - answer_idx = llm_answer.find("Answer:") - reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n") - answer = llm_answer[answer_idx + len("Answer:") :] - - explanations.append(reasoning) - - if "True" in answer: - outputs.append(True) - elif "False" in answer: - outputs.append(False) - else: - lotus.logger.info(f"\t Failed to parse: defaulting to {default}") - outputs.append(default) - - return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations) - - def filter_postprocess( llm_answers: list[str], default: bool = True, - cot_reasoning: bool = False, ) -> SemanticFilterPostprocessOutput: """ Postprocess the output of the filter operator. @@ -134,18 +121,21 @@ def filter_postprocess( Returns: SemanticFilterPostprocessOutput """ - if cot_reasoning: - return filter_postprocess_cot(llm_answers, default) + outputs, explanations = cot_postprocessor(llm_answers) + + def process_outputs(answer): + if answer is None: + lotus.logger.info(f"\t Failed to parse {answer}: defaulting to {default}") + return default - outputs: list[bool] = [] - explanations: list[str | None] = [None] * len(llm_answers) - for answer in llm_answers: if "True" in answer: - outputs.append(True) + return True elif "False" in answer: - outputs.append(False) + return False else: - lotus.logger.info(f"\t Failed to parse: defaulting to {default}") - outputs.append(default) + lotus.logger.info(f"\t Failed to parse {answer}: defaulting to {default}") + return default + + outputs = [process_outputs(answer) for answer in outputs] return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations) diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index d6253b8d..fba85c5e 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -47,7 +47,7 @@ def sem_filter( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.filter_formatter( - doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy + doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) @@ -63,7 +63,7 @@ def sem_filter( ) postprocess_output = filter_postprocess( - lm_output.outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"] + lm_output.outputs, default=default ) lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}") diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index fc30efd9..53393027 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -8,6 +8,16 @@ from lotus.types import SerializationFormat +def cot_formatter(reasoning, answer): + return f"""{reasoning}{answer}""" + +def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: str = "") -> str: + reasoning_instructions = f"Provide your reasoning here.{reasoning_instructions}" + answer_instructions = f"Provide your answer here. {answer_instructions}" + return f"""Let's think step by step. 
Use the following format to provide your answer: + {cot_formatter(reasoning_instructions, answer_instructions)} + """ + def context_formatter( multimodal_data: dict[str, Any] | str, ) -> tuple[str, list[dict[str, str]]]: @@ -54,79 +64,22 @@ def user_message_formatter( "content": content, } - -def filter_formatter_cot( - multimodal_data: dict[str, Any], - user_instruction: str, - examples_multimodal_data: list[dict[str, Any]], - examples_answer: list[bool], - cot_reasoning: list[str], -) -> list[dict[str, str]]: - sys_instruction = ( - "The user will provide a claim and some relevant context.\n" - "Your job is to determine whether the claim is true for the given context.\n" - 'First give your reasoning. Then you MUST end your output with "Answer: True or False"' - ) - messages = [ - {"role": "system", "content": sys_instruction}, - ] - - for idx in range(len(examples_multimodal_data)): - ex_multimodal_data = examples_multimodal_data[idx] - ex_ans = examples_answer[idx] - cot = cot_reasoning[idx] - messages.extend( - [ - user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), - { - "role": "assistant", - "content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}", - }, - ] - ) - - messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) - return messages - - -def filter_formatter_zs_cot( - multimodal_data: dict[str, Any], - user_instruction: str, -) -> list[dict[str, str]]: - sys_instruction = ( - "The user will provide a claim and some relevant context.\n" - "Your job is to determine whether the claim is true for the given context.\n" - 'First give your reasoning. Then you MUST end your output with "Answer: True or False"' - ) - messages = [ - {"role": "system", "content": sys_instruction}, - ] - - messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) - return messages - - def filter_formatter( multimodal_data: dict[str, Any], user_instruction: str, examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answer: list[bool] | None = None, - cot_reasoning: list[str] | None = None, - strategy: str | None = None, + cot_reasoning: list[str] | None = None ) -> list[dict[str, str]]: - if cot_reasoning: - assert examples_multimodal_data is not None and examples_answer is not None - return filter_formatter_cot( - multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning - ) - elif strategy == "zs-cot": - return filter_formatter_zs_cot(multimodal_data, user_instruction) - + sys_instruction = ( - "The user will provide a claim and some relevant context.\n" - "Your job is to determine whether the claim is true for the given context.\n" - 'You must answer with a single word, "True" or "False".' + f"""The user will provide a claim and some relevant context. + Your job is to determine whether the claim is true for the given context. 
+ + {cot_prompt_formatter(answer_instructions="The answer should be either True or False")} + """ ) + messages = [ {"role": "system", "content": sys_instruction}, ] @@ -134,13 +87,23 @@ def filter_formatter( if examples_multimodal_data: assert examples_answer is not None assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list) - for i in range(len(examples_multimodal_data)): - ex_multimodal_data = examples_multimodal_data[i] - ex_ans = examples_answer[i] + assert len(examples_multimodal_data) == len(examples_answer) + + if cot_reasoning: + assert isinstance(cot_reasoning, list) + assert len(examples_multimodal_data) == len(examples_answer) == len(cot_reasoning) + + for idx in range(len(examples_multimodal_data)): + ex_multimodal_data = examples_multimodal_data[idx] + ex_ans = examples_answer[idx] + cot = cot_reasoning[idx] if cot_reasoning else "" messages.extend( [ user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), - {"role": "assistant", "content": str(ex_ans)}, + { + "role": "assistant", + "content": f"""{cot_formatter(cot, ex_ans)}""", + }, ] ) From 2999d584ae64a013d0887896b27d3d349f739cdd Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 9 Jan 2025 23:24:08 -0800 Subject: [PATCH 02/18] made cot optional --- examples/op_examples/filter.py | 4 ++-- lotus/sem_ops/postprocessors.py | 24 +++++++++++-------- lotus/sem_ops/sem_filter.py | 2 +- lotus/templates/task_instructions.py | 35 ++++++++++++++++++++-------- 4 files changed, 42 insertions(+), 23 deletions(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index f20aa593..69c19a15 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,7 +6,7 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) - +lotus.logger.setLevel("DEBUG") data = { "Course Name": [ "Probability and Random Processes", @@ -17,5 +17,5 @@ } df = pd.DataFrame(data) user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction) +df = df.sem_filter(user_instruction, strategy="cot") print(df) diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index e11dbcf2..0e377cb8 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -14,21 +14,25 @@ def cot_postprocessor(llm_answers: list[str]): import xml.etree.ElementTree as ET try: root = ET.fromstring(f"{llm_answer}") - reasoning = root.find('Reasoning') - answer = root.find('Answer') + reasoning = root.find('.//Reasoning') # Use XPath to find nested tags + answer = root.find('.//Answer') # Use XPath to find nested tags - if reasoning is None or answer is None: - raise ValueError("Failed to parse reasoning or answer") - - reasoning = reasoning.text.strip() if reasoning.text else None - answer = answer.text.strip() if answer.text else "" + if answer is not None and answer.text: + answer = answer.text.strip() + else: + lotus.logger.error(f"\t Failed to parse answer from: {llm_answer}") + answer = "" + + if reasoning is not None and reasoning.text: + reasoning = reasoning.text.strip() + else: + lotus.logger.debug(f"\t Failed to parse reasoning from: {llm_answer}") + reasoning = None explanations.append(reasoning) outputs.append(answer) - - lotus.logger.debug(f"{llm_answer}") - except (ET.ParseError, ValueError): + except (ET.ParseError): lotus.logger.debug(f"\t Failed to parse reasoning and answer from: {llm_answer}") explanations.append(None) outputs.append("") diff --git a/lotus/sem_ops/sem_filter.py 
b/lotus/sem_ops/sem_filter.py index fba85c5e..e4ec5939 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -47,7 +47,7 @@ def sem_filter( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.filter_formatter( - doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning + doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index 53393027..94454299 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -7,7 +7,6 @@ from lotus.dtype_extensions import ImageDtype from lotus.types import SerializationFormat - def cot_formatter(reasoning, answer): return f"""{reasoning}{answer}""" @@ -17,6 +16,9 @@ def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: return f"""Let's think step by step. Use the following format to provide your answer: {cot_formatter(reasoning_instructions, answer_instructions)} """ +def non_cot_prompt_formatter(answer_instructions: str = "") -> str: + answer_instructions = f"Provide your answer here. {answer_instructions}" + return f"""{answer_instructions}""" def context_formatter( multimodal_data: dict[str, Any] | str, @@ -45,7 +47,6 @@ def context_formatter( raise ValueError("multimodal_data must be a dictionary or a string") return text, image_inputs - def user_message_formatter( multimodal_data: dict[str, Any] | str, user_instruction_with_tag: str | None = None, @@ -69,16 +70,30 @@ def filter_formatter( user_instruction: str, examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answer: list[bool] | None = None, - cot_reasoning: list[str] | None = None + cot_reasoning: list[str] | None = None, + strategy: str | None = None, + reasoning_instructions: str = "", ) -> list[dict[str, str]]: + answer_instructions="The answer should be either True or False" - sys_instruction = ( - f"""The user will provide a claim and some relevant context. - Your job is to determine whether the claim is true for the given context. + if strategy == "cot": + sys_instruction = ( + f"""The user will provide a claim and some relevant context. + Your job is to determine whether the claim is true for the given context. + + {cot_prompt_formatter( + reasoning_instructions=reasoning_instructions, + answer_instructions=answer_instructions)} + """ + ) + else: + sys_instruction = ( + f"""The user will provide a claim and some relevant context. + Your job is to determine whether the claim is true for the given context. 
- {cot_prompt_formatter(answer_instructions="The answer should be either True or False")} - """ - ) + {non_cot_prompt_formatter(answer_instructions=answer_instructions)} + """ + ) messages = [ {"role": "system", "content": sys_instruction}, @@ -96,7 +111,7 @@ def filter_formatter( for idx in range(len(examples_multimodal_data)): ex_multimodal_data = examples_multimodal_data[idx] ex_ans = examples_answer[idx] - cot = cot_reasoning[idx] if cot_reasoning else "" + cot = cot_reasoning[idx] if cot_reasoning else "Reasoning for this example has not been provided" messages.extend( [ user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), From 1ef7446a2b9a6fa4add6173a20df8eeb0257f812 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 9 Jan 2025 23:43:32 -0800 Subject: [PATCH 03/18] linting and formatting --- examples/op_examples/filter.py | 2 +- lotus/sem_ops/postprocessors.py | 22 +++++++----- lotus/sem_ops/sem_filter.py | 4 +-- lotus/templates/task_instructions.py | 51 ++++++++++++++++------------ 4 files changed, 44 insertions(+), 35 deletions(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index 69c19a15..ebc17088 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -17,5 +17,5 @@ } df = pd.DataFrame(data) user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction, strategy="cot") +df = df.sem_filter(user_instruction, strategy="") print(df) diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index 0e377cb8..8bec4180 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -7,18 +7,20 @@ SemanticMapPostprocessOutput, ) + def cot_postprocessor(llm_answers: list[str]): outputs: list[str | None] = [] explanations: list[str | None] = [] for llm_answer in llm_answers: import xml.etree.ElementTree as ET + try: root = ET.fromstring(f"{llm_answer}") - reasoning = root.find('.//Reasoning') # Use XPath to find nested tags - answer = root.find('.//Answer') # Use XPath to find nested tags + reasoning = root.find(".//Reasoning") # Use XPath to find nested tags + answer = root.find(".//Answer") # Use XPath to find nested tags if answer is not None and answer.text: - answer = answer.text.strip() + answer = answer.text.strip() else: lotus.logger.error(f"\t Failed to parse answer from: {llm_answer}") answer = "" @@ -26,19 +28,20 @@ def cot_postprocessor(llm_answers: list[str]): if reasoning is not None and reasoning.text: reasoning = reasoning.text.strip() else: - lotus.logger.debug(f"\t Failed to parse reasoning from: {llm_answer}") + lotus.logger.debug(f"\t Unable to extract reasoning from: {llm_answer}. Was CoT used?") reasoning = None explanations.append(reasoning) outputs.append(answer) - - except (ET.ParseError): - lotus.logger.debug(f"\t Failed to parse reasoning and answer from: {llm_answer}") + + except ET.ParseError: + lotus.logger.debug(f"\t XML error parsing: {llm_answer}") explanations.append(None) outputs.append("") - + return outputs, explanations + def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput: """ Postprocess the output of the map operator with CoT reasoning. 
@@ -110,6 +113,7 @@ def extract_postprocess(llm_answers: list[str]) -> SemanticExtractPostprocessOut return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data) + def filter_postprocess( llm_answers: list[str], default: bool = True, @@ -139,7 +143,7 @@ def process_outputs(answer): else: lotus.logger.info(f"\t Failed to parse {answer}: defaulting to {default}") return default - + outputs = [process_outputs(answer) for answer in outputs] return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations) diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index e4ec5939..ea8605cf 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -62,9 +62,7 @@ def sem_filter( inputs, show_progress_bar=show_progress_bar, progress_bar_desc=progress_bar_desc, **kwargs ) - postprocess_output = filter_postprocess( - lm_output.outputs, default=default - ) + postprocess_output = filter_postprocess(lm_output.outputs, default=default) lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index 94454299..c177a48d 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -7,18 +7,29 @@ from lotus.dtype_extensions import ImageDtype from lotus.types import SerializationFormat + def cot_formatter(reasoning, answer): return f"""{reasoning}{answer}""" + +def answer_only_formatter(answer): + return f"""{answer}""" + + def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: str = "") -> str: - reasoning_instructions = f"Provide your reasoning here.{reasoning_instructions}" + reasoning_instructions = f"Provide your reasoning here. {reasoning_instructions}" answer_instructions = f"Provide your answer here. {answer_instructions}" return f"""Let's think step by step. Use the following format to provide your answer: {cot_formatter(reasoning_instructions, answer_instructions)} """ + + def non_cot_prompt_formatter(answer_instructions: str = "") -> str: answer_instructions = f"Provide your answer here. {answer_instructions}" - return f"""{answer_instructions}""" + return f"""Use the following format to provide your answer: + {answer_only_formatter(answer_instructions)} + """ + def context_formatter( multimodal_data: dict[str, Any] | str, @@ -47,6 +58,7 @@ def context_formatter( raise ValueError("multimodal_data must be a dictionary or a string") return text, image_inputs + def user_message_formatter( multimodal_data: dict[str, Any] | str, user_instruction_with_tag: str | None = None, @@ -65,6 +77,7 @@ def user_message_formatter( "content": content, } + def filter_formatter( multimodal_data: dict[str, Any], user_instruction: str, @@ -74,26 +87,18 @@ def filter_formatter( strategy: str | None = None, reasoning_instructions: str = "", ) -> list[dict[str, str]]: - answer_instructions="The answer should be either True or False" - - if strategy == "cot": - sys_instruction = ( - f"""The user will provide a claim and some relevant context. - Your job is to determine whether the claim is true for the given context. 
+ answer_instructions = "The answer should be either True or False" - {cot_prompt_formatter( - reasoning_instructions=reasoning_instructions, - answer_instructions=answer_instructions)} - """ - ) - else: - sys_instruction = ( - f"""The user will provide a claim and some relevant context. - Your job is to determine whether the claim is true for the given context. + sys_instruction = """The user will provide a claim and some relevant context. + Your job is to determine whether the claim is true for the given context. + """ - {non_cot_prompt_formatter(answer_instructions=answer_instructions)} - """ + if strategy == "cot": + sys_instruction += cot_prompt_formatter( + reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions ) + else: + sys_instruction += non_cot_prompt_formatter(answer_instructions=answer_instructions) messages = [ {"role": "system", "content": sys_instruction}, @@ -107,17 +112,19 @@ def filter_formatter( if cot_reasoning: assert isinstance(cot_reasoning, list) assert len(examples_multimodal_data) == len(examples_answer) == len(cot_reasoning) - + for idx in range(len(examples_multimodal_data)): ex_multimodal_data = examples_multimodal_data[idx] ex_ans = examples_answer[idx] - cot = cot_reasoning[idx] if cot_reasoning else "Reasoning for this example has not been provided" + messages.extend( [ user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), { "role": "assistant", - "content": f"""{cot_formatter(cot, ex_ans)}""", + "content": cot_formatter(cot_reasoning[idx], str(ex_ans)) + if cot_reasoning + else answer_only_formatter(str(ex_ans)), }, ] ) From ced50e076ffd7d4cf02dfa4a21e4bab1643d65cd Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 9 Jan 2025 23:49:18 -0800 Subject: [PATCH 04/18] cleaning up for code review --- examples/op_examples/filter.py | 4 ++-- lotus/sem_ops/postprocessors.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index ebc17088..f20aa593 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,7 +6,7 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) -lotus.logger.setLevel("DEBUG") + data = { "Course Name": [ "Probability and Random Processes", @@ -17,5 +17,5 @@ } df = pd.DataFrame(data) user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction, strategy="") +df = df.sem_filter(user_instruction) print(df) diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index 8bec4180..70baedfd 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -16,8 +16,8 @@ def cot_postprocessor(llm_answers: list[str]): try: root = ET.fromstring(f"{llm_answer}") - reasoning = root.find(".//Reasoning") # Use XPath to find nested tags - answer = root.find(".//Answer") # Use XPath to find nested tags + reasoning = root.find(".//Reasoning") + answer = root.find(".//Answer") if answer is not None and answer.text: answer = answer.text.strip() From fb6bdf3bcf702c75d656c679336575570c6e7fc4 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Fri, 10 Jan 2025 00:14:14 -0800 Subject: [PATCH 05/18] exposed ability to add custom reasoning instructions and disaggregated providing examples and requiring the model to use CoT --- examples/op_examples/filter.py | 1 + lotus/sem_ops/sem_filter.py | 12 +++++++++++- lotus/templates/task_instructions.py | 20 +++++++++++++++----- 3 files changed, 27 insertions(+), 6 deletions(-) diff 
--git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index f20aa593..1580c7f4 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,6 +6,7 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) +lotus.logger.setLevel("DEBUG") data = { "Course Name": [ diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index ea8605cf..dfb3f05f 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -27,6 +27,7 @@ def sem_filter( safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Filtering", + additional_cot_instructions: str = "" ) -> SemanticFilterOutput: """ Filters a list of documents based on a given user instruction using a language model. @@ -40,6 +41,7 @@ def sem_filter( examples_answers (list[bool] | None): The answers for examples. Defaults to None. cot_reasoning (list[str] | None): The reasoning for CoT. Defaults to None. logprobs (bool): Whether to return log probabilities. Defaults to False. + additional_cot_instructions (str): Additional instructions for the CoT. Defaults to "". Returns: SemanticFilterOutput: The True/False outputs, raw outputs, and explanations, and log probabilities. @@ -47,7 +49,7 @@ def sem_filter( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.filter_formatter( - doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy + doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy, reasoning_instructions=additional_cot_instructions ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) @@ -85,6 +87,7 @@ def learn_filter_cascade_thresholds( examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, strategy: str | None = None, + additional_cot_instructions: str = "", ) -> tuple[float, float]: """Automatically learns the cascade thresholds for a cascade filter given a sample of data and doing a search across threshold @@ -102,6 +105,7 @@ def learn_filter_cascade_thresholds( strategy=strategy, safe_mode=False, progress_bar_desc="Running oracle for threshold learning", + additional_cot_instructions=additional_cot_instructions, ).outputs best_combination, _ = learn_cascade_thresholds( @@ -148,6 +152,7 @@ def __call__( return_stats: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Filtering", + additional_cot_instructions: str = "", ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: """ Applies semantic filter over a dataframe. @@ -166,6 +171,7 @@ def __call__( sampling_percentage (float): The percentage of the data to sample when cascading. Defaults to 0.1. failure_probability (float): The failure probability when cascading. Defaults to 0.2. return_stats (bool): Whether to return statistics. Defaults to False. + additional_cot_instructions (str): Additional instructions for the CoT. Defaults to "". Returns: pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: The filtered dataframe or a tuple containing the filtered dataframe and statistics. 
@@ -245,6 +251,7 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc="Running helper LM", + additional_cot_instructions=additional_cot_instructions, ) helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs assert helper_logprobs is not None @@ -271,6 +278,7 @@ def __call__( examples_answers=examples_answers, cot_reasoning=cot_reasoning, strategy=strategy, + additional_cot_instructions=additional_cot_instructions, ) stats["pos_cascade_threshold"] = pos_cascade_threshold @@ -325,6 +333,7 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc="Running predicate evals with oracle LM", + additional_cot_instructions=additional_cot_instructions, ) for idx, large_idx in enumerate(low_conf_idxs): @@ -348,6 +357,7 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc=progress_bar_desc, + additional_cot_instructions=additional_cot_instructions, ) outputs = output.outputs raw_outputs = output.raw_outputs diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index c177a48d..680da4ba 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -108,23 +108,33 @@ def filter_formatter( assert examples_answer is not None assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list) assert len(examples_multimodal_data) == len(examples_answer) - + if cot_reasoning: + # users don't have to provide cot reasoning examples + # but if they do, the number of examples must match assert isinstance(cot_reasoning, list) assert len(examples_multimodal_data) == len(examples_answer) == len(cot_reasoning) for idx in range(len(examples_multimodal_data)): ex_multimodal_data = examples_multimodal_data[idx] ex_ans = examples_answer[idx] - + content = "" + + # if cot reasoning is provided, use it. 
Otherwise, supply a default + # reasoning as filler if the user wants cot reasoning + if cot_reasoning: + content = cot_formatter(cot_reasoning[idx], str(ex_ans)) + elif strategy == "cot": + content = cot_formatter("Reasoning omitted", str(ex_ans)) + else: + content = answer_only_formatter(str(ex_ans)) + messages.extend( [ user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), { "role": "assistant", - "content": cot_formatter(cot_reasoning[idx], str(ex_ans)) - if cot_reasoning - else answer_only_formatter(str(ex_ans)), + "content": content, }, ] ) From 436882e96e6f07cfb0111121d825f17a87eedd18 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Fri, 10 Jan 2025 00:28:39 -0800 Subject: [PATCH 06/18] remove debug level in filter example --- examples/op_examples/filter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index 1580c7f4..f20aa593 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,7 +6,6 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) -lotus.logger.setLevel("DEBUG") data = { "Course Name": [ From e0640e9a00e6602b88298d5bbfe34b6acf680a11 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Fri, 10 Jan 2025 00:35:30 -0800 Subject: [PATCH 07/18] fix mypy errors --- lotus/sem_ops/postprocessors.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index 70baedfd..0472a010 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -20,19 +20,19 @@ def cot_postprocessor(llm_answers: list[str]): answer = root.find(".//Answer") if answer is not None and answer.text: - answer = answer.text.strip() + answer_str = answer.text.strip() else: lotus.logger.error(f"\t Failed to parse answer from: {llm_answer}") - answer = "" + answer_str = "" if reasoning is not None and reasoning.text: - reasoning = reasoning.text.strip() + reasoning_str= reasoning.text.strip() else: lotus.logger.debug(f"\t Unable to extract reasoning from: {llm_answer}. Was CoT used?") - reasoning = None + reasoning_str = None - explanations.append(reasoning) - outputs.append(answer) + explanations.append(reasoning_str) + outputs.append(answer_str) except ET.ParseError: lotus.logger.debug(f"\t XML error parsing: {llm_answer}") From 06857687d9f048795bf6700a0d1a80379f22b702 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Fri, 10 Jan 2025 17:53:09 -0800 Subject: [PATCH 08/18] prompt LLM to generate valid XML --- lotus/templates/task_instructions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index 680da4ba..9351866f 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -21,6 +21,8 @@ def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: answer_instructions = f"Provide your answer here. {answer_instructions}" return f"""Let's think step by step. Use the following format to provide your answer: {cot_formatter(reasoning_instructions, answer_instructions)} + + Your response must be valid XML format. """ @@ -28,7 +30,8 @@ def non_cot_prompt_formatter(answer_instructions: str = "") -> str: answer_instructions = f"Provide your answer here. 
{answer_instructions}" return f"""Use the following format to provide your answer: {answer_only_formatter(answer_instructions)} - """ + + Your response must be valid XML format.""" def context_formatter( From 2d7df7a97a38eec7ad41b4bde1f77b800566c453 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Mon, 13 Jan 2025 17:34:39 -0800 Subject: [PATCH 09/18] revert using XML for CoT --- lotus/sem_ops/postprocessors.py | 36 +++++++++------------------- lotus/templates/task_instructions.py | 15 +++++------- 2 files changed, 17 insertions(+), 34 deletions(-) diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index 0472a010..a361a166 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -12,32 +12,18 @@ def cot_postprocessor(llm_answers: list[str]): outputs: list[str | None] = [] explanations: list[str | None] = [] for llm_answer in llm_answers: - import xml.etree.ElementTree as ET + reasoning_idx = llm_answer.find("Reasoning:\n") + if reasoning_idx == -1: + reasoning_idx = 0 + else: + reasoning_idx += len("Reasoning:\n") - try: - root = ET.fromstring(f"{llm_answer}") - reasoning = root.find(".//Reasoning") - answer = root.find(".//Answer") - - if answer is not None and answer.text: - answer_str = answer.text.strip() - else: - lotus.logger.error(f"\t Failed to parse answer from: {llm_answer}") - answer_str = "" - - if reasoning is not None and reasoning.text: - reasoning_str= reasoning.text.strip() - else: - lotus.logger.debug(f"\t Unable to extract reasoning from: {llm_answer}. Was CoT used?") - reasoning_str = None - - explanations.append(reasoning_str) - outputs.append(answer_str) - - except ET.ParseError: - lotus.logger.debug(f"\t XML error parsing: {llm_answer}") - explanations.append(None) - outputs.append("") + answer_idx = llm_answer.find("Answer:") + reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n") + answer = llm_answer[answer_idx + len("Answer:") :] + + explanations.append(reasoning) + outputs.append(answer) return outputs, explanations diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index 9351866f..fbef1ea2 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -9,29 +9,26 @@ def cot_formatter(reasoning, answer): - return f"""{reasoning}{answer}""" + return f"""Reasoning:\n{reasoning}\n\nAnswer: {answer}""" def answer_only_formatter(answer): - return f"""{answer}""" + return f"""Answer: {answer}""" def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: str = "") -> str: - reasoning_instructions = f"Provide your reasoning here. {reasoning_instructions}" - answer_instructions = f"Provide your answer here. {answer_instructions}" + reasoning_instructions = f"" + answer_instructions = f"" return f"""Let's think step by step. Use the following format to provide your answer: {cot_formatter(reasoning_instructions, answer_instructions)} - - Your response must be valid XML format. """ def non_cot_prompt_formatter(answer_instructions: str = "") -> str: - answer_instructions = f"Provide your answer here. 
{answer_instructions}" + answer_instructions = f"" return f"""Use the following format to provide your answer: {answer_only_formatter(answer_instructions)} - - Your response must be valid XML format.""" + """ def context_formatter( From 13d01618ec3f0e3c55730e64af225f24899c5304 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Mon, 13 Jan 2025 18:03:54 -0800 Subject: [PATCH 10/18] ruff format and removed excesss changes to mnimize PR --- examples/op_examples/filter.py | 1 - examples/op_examples/filter_cascade.py | 1 - lotus/sem_ops/sem_filter.py | 10 ++++++++-- lotus/templates/task_instructions.py | 8 ++++---- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index f20aa593..a1acc00d 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,7 +6,6 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) - data = { "Course Name": [ "Probability and Random Processes", diff --git a/examples/op_examples/filter_cascade.py b/examples/op_examples/filter_cascade.py index a1b94f4d..104c8410 100644 --- a/examples/op_examples/filter_cascade.py +++ b/examples/op_examples/filter_cascade.py @@ -8,7 +8,6 @@ gpt_4o = LM("gpt-4o") lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) - data = { "Course Name": [ "Probability and Random Processes", diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index dfb3f05f..7a8cf4b0 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -27,7 +27,7 @@ def sem_filter( safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Filtering", - additional_cot_instructions: str = "" + additional_cot_instructions: str = "", ) -> SemanticFilterOutput: """ Filters a list of documents based on a given user instruction using a language model. 
@@ -49,7 +49,13 @@ def sem_filter( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.filter_formatter( - doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy, reasoning_instructions=additional_cot_instructions + doc, + user_instruction, + examples_multimodal_data, + examples_answers, + cot_reasoning, + strategy, + reasoning_instructions=additional_cot_instructions, ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index fbef1ea2..a71acd8c 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -108,10 +108,10 @@ def filter_formatter( assert examples_answer is not None assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list) assert len(examples_multimodal_data) == len(examples_answer) - + if cot_reasoning: - # users don't have to provide cot reasoning examples - # but if they do, the number of examples must match + # users don't have to provide cot reasoning examples + # but if they do, the number of examples must match assert isinstance(cot_reasoning, list) assert len(examples_multimodal_data) == len(examples_answer) == len(cot_reasoning) @@ -128,7 +128,7 @@ def filter_formatter( content = cot_formatter("Reasoning omitted", str(ex_ans)) else: content = answer_only_formatter(str(ex_ans)) - + messages.extend( [ user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), From 151ccd755bd3b998c7ab5c6b89e5257acef37074 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Sat, 18 Jan 2025 11:31:10 -0800 Subject: [PATCH 11/18] Added tests for COT and few shot COT for semantic filter --- .github/tests/lm_tests.py | 63 +++++++++++++++++++++++++++++++++++++ lotus/sem_ops/sem_filter.py | 1 - 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 3a18bf86..d418726e 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -210,6 +210,69 @@ def test_sem_extract(setup_models, model): ), f"Number of Championships '{row['Number of Championships']}' not found in '{row['Number of Championships_quote']}'" +################################################################################ +# CoT tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1")) +def test_filter_operation_cot(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on an easy dataframe + data = { + "Text": [ + "I had two apples, then I gave away one", + "My friend gave me an apple", + "I gave away both of my apples", + "I gave away my apple, then a friend gave me his apple, then I threw my apple away", + ] + } + df = pd.DataFrame(data) + user_instruction = "{Text} I have at least one apple" + filtered_df = df.sem_filter(user_instruction, strategy="cot") + expected_df = pd.DataFrame({"Text": ["I had two apples, then I gave away one", "My friend gave me an apple"]}) + assert filtered_df.equals(expected_df) + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1")) +def test_filter_operation_cot_fewshot(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on an easy dataframe + data = { + "Sequence": [ + "Five, Four, Three", + "A, B, C", + "Pond, Lake, Ocean", + ] + } 
+ df = pd.DataFrame(data) + examples = { + "Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "villiage, town, city"], + "Answer": [True, True, True], + "Reasoning": [ + "1, 2, 3 is an increasing sequence of numbers", + "penny, nickel, dime, quarter is an increasing sequence of coins", + "villiage, town, city is an increasing sequence of settlements", + ], + } + examples_df = pd.DataFrame(examples) + + user_instruction = "{Sequence} is increasing" + filtered_df = df.sem_filter(user_instruction, strategy="cot", examples=examples_df) + expected_df = pd.DataFrame( + { + "Sequence": [ + "A, B, C", + "Pond, Lake, Ocean", + ] + }, + index=[1, 2], + ) + assert filtered_df.equals(expected_df) + + ################################################################################ # Cascade tests ################################################################################ diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 7a8cf4b0..9ffcc64c 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -211,7 +211,6 @@ def __call__( examples_answers = examples["Answer"].tolist() if strategy == "cot": - return_explanations = True cot_reasoning = examples["Reasoning"].tolist() pos_cascade_threshold, neg_cascade_threshold = None, None From cfaf28e818840be5bd3caf306480d22a0f6e05a4 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Sat, 18 Jan 2025 11:39:04 -0800 Subject: [PATCH 12/18] added a test to validate that examples of reasoning are not always needed with CoT --- .github/tests/lm_tests.py | 32 ++++++++++++++++++++++++++++++++ lotus/sem_ops/sem_filter.py | 6 +++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index d418726e..325b6693 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -272,6 +272,38 @@ def test_filter_operation_cot_fewshot(setup_models, model): ) assert filtered_df.equals(expected_df) +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1")) +def test_filter_operation_cot_fewshot_no_reasoning(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on an easy dataframe + data = { + "Sequence": [ + "Five, Four, Three", + "A, B, C", + "Pond, Lake, Ocean", + ] + } + df = pd.DataFrame(data) + examples = { + "Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "villiage, town, city"], + "Answer": [True, True, True], + } + examples_df = pd.DataFrame(examples) + + user_instruction = "{Sequence} is increasing" + filtered_df = df.sem_filter(user_instruction, strategy="cot", examples=examples_df) + expected_df = pd.DataFrame( + { + "Sequence": [ + "A, B, C", + "Pond, Lake, Ocean", + ] + }, + index=[1, 2], + ) + assert filtered_df.equals(expected_df) ################################################################################ # Cascade tests diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 9ffcc64c..9291163c 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -210,7 +210,7 @@ def __call__( examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li) examples_answers = examples["Answer"].tolist() - if strategy == "cot": + if strategy == "cot" and "Reasoning" in examples.columns: cot_reasoning = examples["Reasoning"].tolist() pos_cascade_threshold, neg_cascade_threshold = None, None @@ -224,8 +224,8 @@ def __call__( helper_examples_multimodal_data = task_instructions.df2multimodal_info(helper_examples, 
col_li) helper_examples_answers = helper_examples["Answer"].tolist() - if helper_strategy == "cot": - helper_cot_reasoning = examples["Reasoning"].tolist() + if helper_strategy == "cot" and "Reasoning" in helper_examples.columns: + helper_cot_reasoning = helper_examples["Reasoning"].tolist() if cascade_args and lotus.settings.helper_lm: if helper_strategy == "cot": From 3401dcfae954aa6dafc144d238a1ad5a50afc4d7 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Sat, 18 Jan 2025 12:43:12 -0800 Subject: [PATCH 13/18] added a guidelines for llama3.1 CI test --- .github/tests/lm_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 325b6693..31662616 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -260,7 +260,7 @@ def test_filter_operation_cot_fewshot(setup_models, model): examples_df = pd.DataFrame(examples) user_instruction = "{Sequence} is increasing" - filtered_df = df.sem_filter(user_instruction, strategy="cot", examples=examples_df) + filtered_df = df.sem_filter(user_instruction, strategy="cot", examples=examples_df, additional_cot_instructions="Assume the most typical or logical case.") expected_df = pd.DataFrame( { "Sequence": [ From c4c22d1eb2424d32a97efa40fe6403599fb91281 Mon Sep 17 00:00:00 2001 From: liana313 <54730332+liana313@users.noreply.github.com> Date: Mon, 10 Feb 2025 11:31:56 -0800 Subject: [PATCH 14/18] =?UTF-8?q?added=20example,=20and=20some=20options?= =?UTF-8?q?=20to=20filter=20for=20returning=20all=20labels=20an=E2=80=A6?= =?UTF-8?q?=20(#100)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feat: adds example for cot, and adds options to filter to allow all rows to be returned with labels and explanations --- examples/op_examples/filter_cot.py | 28 ++++++++++++++++++++++ lotus/sem_ops/sem_filter.py | 37 +++++++++++++++++++++--------- 2 files changed, 54 insertions(+), 11 deletions(-) create mode 100644 examples/op_examples/filter_cot.py diff --git a/examples/op_examples/filter_cot.py b/examples/op_examples/filter_cot.py new file mode 100644 index 00000000..4058aaa5 --- /dev/null +++ b/examples/op_examples/filter_cot.py @@ -0,0 +1,28 @@ +import pandas as pd + +import lotus +from lotus.models import LM + +lm = LM(model="gpt-4o-mini") + +lotus.settings.configure(lm=lm) + + + + +# Test filter operation on an easy dataframe +data = { + "Text": [ + "I had two apples, then I gave away one", + "My friend gave me an apple", + "I gave away both of my apples", + "I gave away my apple, then a friend gave me his apple, then I threw my apple away", + ] +} +df = pd.DataFrame(data) +user_instruction = "{Text} I have at least one apple" +filtered_df = df.sem_filter(user_instruction, strategy="cot", return_all=True) +# filtered_df = df.sem_filter(user_instruction, strategy="cot", return_all=True, return_explanations=True) # uncomment to see reasoning chains + +print(filtered_df) +# print(filtered_df) \ No newline at end of file diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 9291163c..e31d581b 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -149,6 +149,7 @@ def __call__( user_instruction: str, return_raw_outputs: bool = False, return_explanations: bool = False, + return_all: bool = False, default: bool = True, suffix: str = "_filter", examples: pd.DataFrame | None = None, @@ -368,19 +369,33 @@ def __call__( raw_outputs = output.raw_outputs explanations = output.explanations - # find 
indices where output is True - ids = [i for i, x in enumerate(outputs) if x] - idx_ids = [self._obj.index[i] for i, x in enumerate(outputs) if x] - lotus.logger.debug(f"ids: {ids}") - lotus.logger.debug(f"idx_ids: {idx_ids}") + if return_all == False: + # find indices where output is True + ids = [i for i, x in enumerate(outputs) if x] + idx_ids = [self._obj.index[i] for i, x in enumerate(outputs) if x] + lotus.logger.debug(f"ids: {ids}") + lotus.logger.debug(f"idx_ids: {idx_ids}") - [outputs[i] for i in ids] - filtered_explanations = [explanations[i] for i in ids] - filtered_raw_outputs = [raw_outputs[i] for i in ids] - lotus.logger.debug(f"filtered_raw_outputs: {filtered_raw_outputs}") + [outputs[i] for i in ids] + filtered_explanations = [explanations[i] for i in ids] + filtered_raw_outputs = [raw_outputs[i] for i in ids] + lotus.logger.debug(f"filtered_raw_outputs: {filtered_raw_outputs}") - new_df = self._obj.iloc[ids] - new_df.attrs["index_dirs"] = self._obj.attrs.get("index_dirs", None) + new_df = self._obj.iloc[ids] + new_df.attrs["index_dirs"] = self._obj.attrs.get("index_dirs", None) + else: + def get_out_col_name(df, col_name): + if col_name in df.columns: + i = 1 + while f"{col_name}_{i}" in new_df.columns: + i +=1 + return f"{col_name}_{i}" + else: + return col_name + new_df = self._obj.copy() + new_df[get_out_col_name(new_df, "filter_label")] = outputs + filtered_explanations = explanations + filtered_raw_outputs = raw_outputs # return rows where output is True if return_explanations and return_raw_outputs: From 4140102498700a3323f203b16de881bf8aa2cfa0 Mon Sep 17 00:00:00 2001 From: liana313 <54730332+liana313@users.noreply.github.com> Date: Mon, 10 Feb 2025 11:39:10 -0800 Subject: [PATCH 15/18] Liana/filter example (#101) From f4c3c172a990dc67f3b06f120149a5b696505dde Mon Sep 17 00:00:00 2001 From: liana313 <54730332+liana313@users.noreply.github.com> Date: Mon, 10 Feb 2025 11:42:24 -0800 Subject: [PATCH 16/18] fix format (#102) fix format From 73d5a076b81f15171760e1d28a2be8a3cac5a13f Mon Sep 17 00:00:00 2001 From: liana313 <54730332+liana313@users.noreply.github.com> Date: Mon, 10 Feb 2025 11:46:16 -0800 Subject: [PATCH 17/18] Fix format (#103) --- lotus/sem_ops/sem_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index e31d581b..38e6cc68 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -388,7 +388,7 @@ def get_out_col_name(df, col_name): if col_name in df.columns: i = 1 while f"{col_name}_{i}" in new_df.columns: - i +=1 + i += 1 return f"{col_name}_{i}" else: return col_name From b43292b9a90fc98b3387a109035c59328ba54024 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Wed, 12 Feb 2025 18:31:36 -0800 Subject: [PATCH 18/18] ran linting and formatting --- .github/tests/lm_tests.py | 9 ++++++++- .github/tests/multimodality_tests.py | 6 +++--- examples/op_examples/filter.py | 3 ++- examples/op_examples/multimodal_ops/filter.py | 4 +--- lotus/sem_ops/sem_filter.py | 4 +++- lotus/utils.py | 4 +--- 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 31662616..4a492770 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -260,7 +260,12 @@ def test_filter_operation_cot_fewshot(setup_models, model): examples_df = pd.DataFrame(examples) user_instruction = "{Sequence} is increasing" - filtered_df = df.sem_filter(user_instruction, strategy="cot", examples=examples_df, 
additional_cot_instructions="Assume the most typical or logical case.") + filtered_df = df.sem_filter( + user_instruction, + strategy="cot", + examples=examples_df, + additional_cot_instructions="Assume the most typical or logical case.", + ) expected_df = pd.DataFrame( { "Sequence": [ @@ -272,6 +277,7 @@ def test_filter_operation_cot_fewshot(setup_models, model): ) assert filtered_df.equals(expected_df) + @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1")) def test_filter_operation_cot_fewshot_no_reasoning(setup_models, model): lm = setup_models[model] @@ -305,6 +311,7 @@ def test_filter_operation_cot_fewshot_no_reasoning(setup_models, model): ) assert filtered_df.equals(expected_df) + ################################################################################ # Cascade tests ################################################################################ diff --git a/.github/tests/multimodality_tests.py b/.github/tests/multimodality_tests.py index c6311d81..68117a6b 100644 --- a/.github/tests/multimodality_tests.py +++ b/.github/tests/multimodality_tests.py @@ -138,7 +138,8 @@ def test_topk_operation(setup_models, model): top_2_actual = set(sorted_df["image"].values) assert top_2_expected == top_2_actual - + + @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) def test_topk_with_groupby_operation(setup_models, model): image_url = [ @@ -153,8 +154,7 @@ def test_topk_with_groupby_operation(setup_models, model): df = image_df.join(element_df, how="cross") df.sem_topk("the {image} is most likely an {element}", K=1, group_by=["element"]) - assert(len(set(df["element"])) == 2) - + assert len(set(df["element"])) == 2 @pytest.mark.parametrize("model", get_enabled("clip-ViT-B-32")) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index a1acc00d..f244051a 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,6 +6,7 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) + data = { "Course Name": [ "Probability and Random Processes", @@ -16,5 +17,5 @@ } df = pd.DataFrame(data) user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction) +df = df.sem_filter(user_instruction, strategy="cot") print(df) diff --git a/examples/op_examples/multimodal_ops/filter.py b/examples/op_examples/multimodal_ops/filter.py index d2d2e597..3fbb0fdb 100644 --- a/examples/op_examples/multimodal_ops/filter.py +++ b/examples/op_examples/multimodal_ops/filter.py @@ -15,9 +15,7 @@ labels = [os.path.splitext(image)[0] for image in image_file_names] image_paths = [os.path.join("images", image) for image in image_file_names] -df = pd.DataFrame({"image": ImageArray(image_paths), - "label": labels, - "image_path": image_paths}) +df = pd.DataFrame({"image": ImageArray(image_paths), "label": labels, "image_path": image_paths}) df = df.sem_filter("{image} represents number 1") print(df) diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 38e6cc68..cc7586a4 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -369,7 +369,7 @@ def __call__( raw_outputs = output.raw_outputs explanations = output.explanations - if return_all == False: + if not return_all: # find indices where output is True ids = [i for i, x in enumerate(outputs) if x] idx_ids = [self._obj.index[i] for i, x in enumerate(outputs) if x] @@ -384,6 +384,7 @@ def __call__( new_df = self._obj.iloc[ids] new_df.attrs["index_dirs"] = self._obj.attrs.get("index_dirs", None) else: + 
def get_out_col_name(df, col_name): if col_name in df.columns: i = 1 @@ -392,6 +393,7 @@ def get_out_col_name(df, col_name): return f"{col_name}_{i}" else: return col_name + new_df = self._obj.copy() new_df[get_out_col_name(new_df, "filter_label")] = outputs filtered_explanations = explanations diff --git a/lotus/utils.py b/lotus/utils.py index b3e68ac9..a8928005 100644 --- a/lotus/utils.py +++ b/lotus/utils.py @@ -29,8 +29,6 @@ def ret( verbose: bool = False, method: str = "kmeans", ) -> list[int]: - - import faiss """Cluster by column, and return a series in the dataframe with cluster-ids""" @@ -64,7 +62,7 @@ def ret( # get nearest centroid to each vector scores, indices = kmeans.index.search(vec_set, 1) - + # get the cluster centroids # centroids = kmeans.centroids # return indices.flatten(), scores.flatten(), centroids
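Taken together, the series gives sem_filter an opt-in chain-of-thought path: strategy="cot" switches the system prompt to the Reasoning:/Answer: format, cot_postprocessor recovers the reasoning text plus the True/False label, and the later patches make the few-shot Reasoning column optional, expose additional_cot_instructions, and add return_all (patch 14) so every row can be kept with its label instead of dropping the False ones. A minimal end-to-end sketch of that surface, assuming the gpt-4o-mini setup used in the bundled examples, an available OpenAI key, and purely illustrative course data:

    import pandas as pd

    import lotus
    from lotus.models import LM

    # Same model configuration as examples/op_examples/filter.py; assumes an API key is set.
    lotus.settings.configure(lm=LM(model="gpt-4o-mini"))

    df = pd.DataFrame({"Course Name": ["Probability and Random Processes", "Cooking"]})

    # Few-shot examples; after patch 12 the Reasoning column is optional, and the
    # rows below are illustrative, not taken from the repo's own test data.
    examples = pd.DataFrame(
        {
            "Course Name": ["Linear Algebra", "Art History"],
            "Answer": [True, False],
            "Reasoning": [
                "Linear Algebra is a math course, so it clearly requires a lot of math.",
                "Art History is a humanities course with essentially no math.",
            ],
        }
    )

    out = df.sem_filter(
        "{Course Name} requires a lot of math",
        strategy="cot",
        examples=examples,
        additional_cot_instructions="Assume a typical undergraduate syllabus.",
        return_all=True,           # keep every row with its True/False label (patch 14)
        return_explanations=True,  # surface the parsed reasoning, as in examples/op_examples/filter_cot.py
    )
    print(out)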