From 5d56fb63830f5d286e0fa32524a450d6466510d6 Mon Sep 17 00:00:00 2001 From: Harshit Gupta Date: Sun, 17 Nov 2024 22:56:33 -0800 Subject: [PATCH] sem agg done --- lotus/sem_ops/sem_extract.py | 8 +-- lotus/templates/task_instructions.py | 79 ++++++---------------------- 2 files changed, 20 insertions(+), 67 deletions(-) diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py index 82e82a98..1deee258 100644 --- a/lotus/sem_ops/sem_extract.py +++ b/lotus/sem_ops/sem_extract.py @@ -10,7 +10,7 @@ def sem_extract( - docs: list[str], + docs: list[dict[str, Any]], model: lotus.models.LM, user_instruction: str, postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess, @@ -19,7 +19,7 @@ def sem_extract( Extracts from a list of documents using a model. Args: - docs (list[str]): The list of documents to extract from. + docs (list[dict[str, Any]]): The list of documents to extract from. model (lotus.models.LM): The model to use. user_instruction (str): The user instruction for extract. postprocessor (Callable): The postprocessor for the model outputs. Defaults to extract_postprocess. @@ -85,11 +85,11 @@ def __call__( if column not in self._obj.columns: raise ValueError(f"Column {column} not found in DataFrame") - df_txt = task_instructions.df2text(self._obj, col_li) + multimodal_data = task_instructions.df2multimodal_info(self._obj, col_li) formatted_usr_instr = lotus.nl_expression.nle2str(user_instruction, col_li) output = sem_extract( - df_txt, + multimodal_data, lotus.settings.lm, formatted_usr_instr, postprocessor=postprocessor, diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index afbb5aa3..283f0ef4 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -5,53 +5,9 @@ from lotus.dtype_extensions import ImageDtype -def filter_user_message_formatter( +def user_message_formatter( multimodal_data: dict[str, Any] | str, - user_instruction: str, -) -> dict[str, Any]: - if isinstance(multimodal_data, str): - text = multimodal_data - image_inputs: list[dict[str, str]] = [] - elif isinstance(multimodal_data, dict): - image_data: dict[str, str] = multimodal_data.get("image", {}) - _image_inputs: list[tuple[dict, dict]] = [ - ( - { - "type": "text", - "text": f"[{key.capitalize()}]: \n", - }, - { - "type": "image_url", - "image_url": {"url": base64_image}, - }, - ) - for key, base64_image in image_data.items() - ] - image_inputs = [m for image_input in _image_inputs for m in image_input] - text = multimodal_data["text"] or "" - else: - raise ValueError("multimodal_data must be a dictionary or a string") - - if not image_inputs or len(image_inputs) == 0: - return { - "role": "user", - "content": f"Context:\n{text}\n\nClaim: {user_instruction}", - } - return { - "role": "user", - "content": [ - { - "type": "text", - "text": f"Claim: {user_instruction}\n\nContext:\n{text}", - }, - ] - + image_inputs, - } - - -def map_user_message_formatter( - multimodal_data: dict[str, Any] | str, - user_instruction: str, + user_instruction_with_tag: str, ) -> dict[str, Any]: if isinstance(multimodal_data, str): text = multimodal_data @@ -79,14 +35,14 @@ def map_user_message_formatter( if not image_inputs or len(image_inputs) == 0: return { "role": "user", - "content": f"Context:\n{text}\n\nInstruction: {user_instruction}", + "content": f"Context:\n{text}\n\n{user_instruction_with_tag}", } return { "role": "user", "content": [ { "type": "text", - "text": f"nInstruction: {user_instruction}\n\nContext:\n{text}", + "text": f"{user_instruction_with_tag}\n\nContext:\n{text}", }, ] + image_inputs, @@ -115,7 +71,7 @@ def filter_formatter_cot( cot = cot_reasoning[idx] messages.extend( [ - filter_user_message_formatter(ex_multimodal_data, user_instruction), + user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), { "role": "assistant", "content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}", @@ -123,7 +79,7 @@ def filter_formatter_cot( ] ) - messages.append(filter_user_message_formatter(multimodal_data, user_instruction)) + messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) return messages @@ -140,7 +96,7 @@ def filter_formatter_zs_cot( {"role": "system", "content": sys_instruction}, ] - messages.append(filter_user_message_formatter(multimodal_data, user_instruction)) + messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) return messages @@ -177,12 +133,12 @@ def filter_formatter( ex_ans = examples_answer[i] messages.extend( [ - filter_user_message_formatter(ex_multimodal_data, user_instruction), + user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), {"role": "assistant", "content": str(ex_ans)}, ] ) - messages.append(filter_user_message_formatter(multimodal_data, user_instruction)) + messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) return messages @@ -208,7 +164,7 @@ def map_formatter_cot( cot = cot_reasoning[idx] messages.extend( [ - map_user_message_formatter(ex_df_txt, user_instruction), + user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"), { "role": "assistant", "content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}", @@ -216,7 +172,7 @@ def map_formatter_cot( ] ) - messages.append(map_user_message_formatter(multimodal_data, user_instruction)) + messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) return messages @@ -233,7 +189,7 @@ def map_formatter_zs_cot( {"role": "system", "content": sys_instruction}, ] - messages.append(map_user_message_formatter(multimodal_data, user_instruction)) + messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) return messages @@ -266,16 +222,16 @@ def map_formatter( for ex_df_txt, ex_ans in zip(examples_multimodal_data, examples_answer): messages.extend( [ - map_user_message_formatter(ex_df_txt, user_instruction), + user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"), {"role": "assistant", "content": str(ex_ans)}, ] ) - messages.append(map_user_message_formatter(multimodal_data, user_instruction)) + messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) return messages -def extract_formatter(df_text: str, user_instruction: str) -> list[dict[str, str]]: +def extract_formatter(multimodal_data: dict[str, Any], user_instruction: str) -> list[dict[str, str]]: sys_instruction = ( "The user will provide an instruction and some relevant context.\n" "Your job is to extract the information requested in the instruction.\n" @@ -284,10 +240,7 @@ def extract_formatter(df_text: str, user_instruction: str) -> list[dict[str, str ) messages = [ {"role": "system", "content": sys_instruction}, - { - "role": "user", - "content": f"Context:\n{df_text}\n\nInstruction: {user_instruction}", - }, + user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"), ] return messages