Skip to content

Commit

Permalink
sem agg done
Browse files Browse the repository at this point in the history
  • Loading branch information
harshitgupta412 committed Nov 18, 2024
1 parent ac8fd82 commit 5d56fb6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 67 deletions.
8 changes: 4 additions & 4 deletions lotus/sem_ops/sem_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def sem_extract(
docs: list[str],
docs: list[dict[str, Any]],
model: lotus.models.LM,
user_instruction: str,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
Expand All @@ -19,7 +19,7 @@ def sem_extract(
Extracts from a list of documents using a model.
Args:
docs (list[str]): The list of documents to extract from.
docs (list[dict[str, Any]]): The list of documents to extract from.
model (lotus.models.LM): The model to use.
user_instruction (str): The user instruction for extract.
postprocessor (Callable): The postprocessor for the model outputs. Defaults to extract_postprocess.
Expand Down Expand Up @@ -85,11 +85,11 @@ def __call__(
if column not in self._obj.columns:
raise ValueError(f"Column {column} not found in DataFrame")

df_txt = task_instructions.df2text(self._obj, col_li)
multimodal_data = task_instructions.df2multimodal_info(self._obj, col_li)
formatted_usr_instr = lotus.nl_expression.nle2str(user_instruction, col_li)

output = sem_extract(
df_txt,
multimodal_data,
lotus.settings.lm,
formatted_usr_instr,
postprocessor=postprocessor,
Expand Down
79 changes: 16 additions & 63 deletions lotus/templates/task_instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,53 +5,9 @@
from lotus.dtype_extensions import ImageDtype


def filter_user_message_formatter(
def user_message_formatter(
multimodal_data: dict[str, Any] | str,
user_instruction: str,
) -> dict[str, Any]:
if isinstance(multimodal_data, str):
text = multimodal_data
image_inputs: list[dict[str, str]] = []
elif isinstance(multimodal_data, dict):
image_data: dict[str, str] = multimodal_data.get("image", {})
_image_inputs: list[tuple[dict, dict]] = [
(
{
"type": "text",
"text": f"[{key.capitalize()}]: \n",
},
{
"type": "image_url",
"image_url": {"url": base64_image},
},
)
for key, base64_image in image_data.items()
]
image_inputs = [m for image_input in _image_inputs for m in image_input]
text = multimodal_data["text"] or ""
else:
raise ValueError("multimodal_data must be a dictionary or a string")

if not image_inputs or len(image_inputs) == 0:
return {
"role": "user",
"content": f"Context:\n{text}\n\nClaim: {user_instruction}",
}
return {
"role": "user",
"content": [
{
"type": "text",
"text": f"Claim: {user_instruction}\n\nContext:\n{text}",
},
]
+ image_inputs,
}


def map_user_message_formatter(
multimodal_data: dict[str, Any] | str,
user_instruction: str,
user_instruction_with_tag: str,
) -> dict[str, Any]:
if isinstance(multimodal_data, str):
text = multimodal_data
Expand Down Expand Up @@ -79,14 +35,14 @@ def map_user_message_formatter(
if not image_inputs or len(image_inputs) == 0:
return {
"role": "user",
"content": f"Context:\n{text}\n\nInstruction: {user_instruction}",
"content": f"Context:\n{text}\n\n{user_instruction_with_tag}",
}
return {
"role": "user",
"content": [
{
"type": "text",
"text": f"nInstruction: {user_instruction}\n\nContext:\n{text}",
"text": f"{user_instruction_with_tag}\n\nContext:\n{text}",
},
]
+ image_inputs,
Expand Down Expand Up @@ -115,15 +71,15 @@ def filter_formatter_cot(
cot = cot_reasoning[idx]
messages.extend(
[
filter_user_message_formatter(ex_multimodal_data, user_instruction),
user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"),
{
"role": "assistant",
"content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}",
},
]
)

messages.append(filter_user_message_formatter(multimodal_data, user_instruction))
messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
return messages


Expand All @@ -140,7 +96,7 @@ def filter_formatter_zs_cot(
{"role": "system", "content": sys_instruction},
]

messages.append(filter_user_message_formatter(multimodal_data, user_instruction))
messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
return messages


Expand Down Expand Up @@ -177,12 +133,12 @@ def filter_formatter(
ex_ans = examples_answer[i]
messages.extend(
[
filter_user_message_formatter(ex_multimodal_data, user_instruction),
user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"),
{"role": "assistant", "content": str(ex_ans)},
]
)

messages.append(filter_user_message_formatter(multimodal_data, user_instruction))
messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
return messages


Expand All @@ -208,15 +164,15 @@ def map_formatter_cot(
cot = cot_reasoning[idx]
messages.extend(
[
map_user_message_formatter(ex_df_txt, user_instruction),
user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"),
{
"role": "assistant",
"content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}",
},
]
)

messages.append(map_user_message_formatter(multimodal_data, user_instruction))
messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"))
return messages


Expand All @@ -233,7 +189,7 @@ def map_formatter_zs_cot(
{"role": "system", "content": sys_instruction},
]

messages.append(map_user_message_formatter(multimodal_data, user_instruction))
messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"))
return messages


Expand Down Expand Up @@ -266,16 +222,16 @@ def map_formatter(
for ex_df_txt, ex_ans in zip(examples_multimodal_data, examples_answer):
messages.extend(
[
map_user_message_formatter(ex_df_txt, user_instruction),
user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"),
{"role": "assistant", "content": str(ex_ans)},
]
)

messages.append(map_user_message_formatter(multimodal_data, user_instruction))
messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"))
return messages


def extract_formatter(df_text: str, user_instruction: str) -> list[dict[str, str]]:
def extract_formatter(multimodal_data: dict[str, Any], user_instruction: str) -> list[dict[str, str]]:
sys_instruction = (
"The user will provide an instruction and some relevant context.\n"
"Your job is to extract the information requested in the instruction.\n"
Expand All @@ -284,10 +240,7 @@ def extract_formatter(df_text: str, user_instruction: str) -> list[dict[str, str
)
messages = [
{"role": "system", "content": sys_instruction},
{
"role": "user",
"content": f"Context:\n{df_text}\n\nInstruction: {user_instruction}",
},
user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"),
]
return messages

Expand Down

0 comments on commit 5d56fb6

Please sign in to comment.