Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sidjha1 committed Nov 23, 2024
1 parent 654290d commit 85d096c
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 24 deletions.
6 changes: 5 additions & 1 deletion .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ def test_sem_extract(setup_models, model):
}
df = pd.DataFrame(data)
input_cols = ["Text"]
output_cols = ["Name", "Sport", "Number of Championships"]
output_cols = {
"Name": None,
"Sport": None,
"Number of Championships": None,
}
df = df.sem_extract(input_cols, output_cols)

expected_values = {
Expand Down
17 changes: 15 additions & 2 deletions examples/op_examples/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import lotus
from lotus.models import LM

lm = LM(model="gpt-4o")
lm = LM(model="gpt-4o-mini")
lotus.settings.configure(lm=lm)

df = pd.DataFrame(
Expand All @@ -16,7 +16,20 @@
}
)
input_cols = ["description"]
output_cols = ["name", "age"]

# A description can be specified for each output column
output_cols = {
"masked_col_1": "The name of the person",
"masked_col_2": "The age of the person",
}

new_df = df.sem_extract(input_cols, output_cols)
print(new_df)

# A description can also be omitted for each output column
output_cols = {
"name": None,
"age": None,
}
new_df = df.sem_extract(input_cols, output_cols)
print(new_df)
10 changes: 3 additions & 7 deletions lotus/sem_ops/postprocessors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import re

import lotus
from lotus.types import (
Expand Down Expand Up @@ -68,14 +67,11 @@ def extract_postprocess(llm_answers: list[str]) -> SemanticExtractPostprocessOut
SemanticExtractPostprocessOutput
"""
extract_data = []
for answers in llm_answers:
cleaned_answers = re.findall(r"(\{.*\})", answers, re.DOTALL)[0]
cleaned_answers = re.sub(r"\\(?![\"\\/bfnrt])", r"\\\\", cleaned_answers)

for llm_answer in llm_answers:
try:
output = json.loads(cleaned_answers)
output = json.loads(llm_answer)
except json.JSONDecodeError:
lotus.logger.info(f"\t Failed to parse: {cleaned_answers}")
lotus.logger.info(f"\t Failed to parse: {llm_answer}")
output = {}

output = {key: str(value) for key, value in output.items()}
Expand Down
13 changes: 8 additions & 5 deletions lotus/sem_ops/sem_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def sem_extract(
docs: list[str],
model: LM,
output_cols: list[str],
output_cols: dict[str, str | None],
extract_quotes: bool = True,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
) -> SemanticExtractOutput:
Expand All @@ -23,7 +23,7 @@ def sem_extract(
Args:
docs (list[str]): The list of documents to extract from.
model (lotus.models.LM): The model to use.
output_cols (list[str]): The columns that a model should extract.
output_cols (dict[str, str | None]): A mapping from desired output column names to optional descriptions.
extract_quotes (bool, optional): Whether to extract quotes for user_instruction. Defaults to True.
postprocessor (Callable): The postprocessor for the model outputs. Defaults to extract_postprocess.
Expand All @@ -40,7 +40,7 @@ def sem_extract(
inputs.append(prompt)

# call model
lm_output: LMOutput = model(inputs)
lm_output: LMOutput = model(inputs, response_format={"type": "json_object"})

# post process results
postprocess_output = postprocessor(lm_output.outputs)
Expand All @@ -64,7 +64,7 @@ def _validate(obj: pd.DataFrame) -> None:
def __call__(
self,
input_cols: list[str],
output_cols: list[str],
output_cols: dict[str, str | None],
extract_quotes: bool = True,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
return_raw_outputs: bool = False,
Expand All @@ -74,7 +74,7 @@ def __call__(
Args:
input_cols (list[str]): The columns that a model should extract from.
output_cols (list[str]): The columns that a model should extract.
output_cols (dict[str, str | None]): A mapping from desired output column names to optional descriptions.
extract_quotes (bool, optional): Whether to extract quotes for user_instruction. Defaults to True.
postprocessor (Callable): The postprocessor for the model outputs. Defaults to extract_postprocess.
return_raw_outputs (bool): Whether to return raw outputs. Defaults to False.
Expand Down Expand Up @@ -105,6 +105,9 @@ def __call__(
new_df[key] = None
new_df.loc[i, key] = value

if return_raw_outputs:
new_df["raw_output"] = out.raw_outputs

new_df = new_df.reset_index(drop=True)

return new_df
23 changes: 14 additions & 9 deletions lotus/templates/task_instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,21 +203,26 @@ def map_formatter(
return messages


def extract_formatter(df_text: str, output_cols: list[str], extract_quotes: bool = True) -> list[dict[str, str]]:
def extract_formatter(
df_text: str, output_cols: dict[str, str | None], extract_quotes: bool = True
) -> list[dict[str, str]]:
output_col_names = list(output_cols.keys())
# Set the description to be the key if no value is provided
output_cols_with_desc: dict[str, str] = {col: col if desc is None else desc for col, desc in output_cols.items()}

all_fields = output_col_names
if extract_quotes:
quote_fields = [f"{col}_quote" for col in output_cols]
all_fields = output_cols + quote_fields
else:
all_fields = output_cols
quote_fields = [f"{col}_quote" for col in output_col_names]
all_fields += quote_fields

fields_str = ", ".join(all_fields)

sys_instruction = (
"The user will provide the columns that need to be extracted and some relevant context.\n"
f"Your job is to extract these columns and provide only the concise subject or topic as the value for each field "
f"and the corresponding full quote for each field in the '{', '.join([f'{col}_quote' for col in output_cols])}' fields.\n"
f"The response should be in JSONL in a single line format with the following fields: {fields_str}.\n"
"Only respond in JSONL format and no other text. Your output will be parsed with json.loads.\n"
f"Your job is to extract these columns and provide only a concise value for each field "
f"and the corresponding full quote for each field in the '{', '.join([f'{col}_quote' for col in output_col_names])}' fields.\n"
f"Here is a description of each field: {output_cols_with_desc}\n"
f"The response should be valid JSON format with the following fields: {fields_str}.\n"
)

messages = [
Expand Down

0 comments on commit 85d096c

Please sign in to comment.