Skip to content

Commit

Permalink
New sem_Extract operator (#38)
Browse files Browse the repository at this point in the history
New Sem_Extract Operator that schemas a list of documents and creates a
DataFrame and provides the option to get quotes from context.

---------

Co-authored-by: liana313 <[email protected]>
Co-authored-by: Sid Jha <[email protected]>
  • Loading branch information
3 people authored Nov 23, 2024
1 parent 51c68f4 commit c85b8ed
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 75 deletions.
41 changes: 41 additions & 0 deletions .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,47 @@ def test_group_by_with_agg(setup_models, model):
assert set(cleaned_df["final_output"].values[1].lower().strip(".,!?\"'").split(", ")) == {"michael", "dwight"}


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_sem_extract(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm)

data = {
"Text": [
"Lionel Messi is a good soccer player, he has won the World Cup 5 times",
"Michael Jordan is a good basketball player, he has won the NBA championships 6 times",
"Tiger Woods is a good golf player, he has won the Master championships 4 times",
"Tom Brady is a good football player, he has won the NFL championships 7 times",
]
}
df = pd.DataFrame(data)
input_cols = ["Text"]
output_cols = {
"Name": None,
"Sport": None,
"Number of Championships": None,
}
df = df.sem_extract(input_cols, output_cols, extract_quotes=True)

expected_values = {
"Name": ["lionel messi", "michael jordan", "tiger woods", "tom brady"],
"Sport": ["soccer", "basketball", "golf", "football"],
"Number of Championships": ["5", "6", "4", "7"],
}

for col in output_cols:
assert [str(val).strip().lower() for val in df[col].tolist()] == expected_values[col]

for idx, row in df.iterrows():
assert row["Name"] in row["Name_quote"], f"Name '{row['Name']}' not found in '{row['Name_quote']}'"
assert (
row["Sport"].lower() in row["Sport_quote"].lower()
), f"Sport '{row['Sport']}' not found in '{row['Sport_quote']}'"
assert (
str(row["Number of Championships"]) in row["Number of Championships_quote"]
), f"Number of Championships '{row['Number of Championships']}' not found in '{row['Number of Championships_quote']}'"


################################################################################
# Cascade tests
################################################################################
Expand Down
36 changes: 36 additions & 0 deletions examples/op_examples/extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pandas as pd

import lotus
from lotus.models import LM

lm = LM(model="gpt-4o-mini")
lotus.settings.configure(lm=lm)

df = pd.DataFrame(
{
"description": [
"Yoshi is 25 years old",
"Bowser is 45 years old",
"Luigi is 15 years old",
]
}
)
input_cols = ["description"]

# A description can be specified for each output column
output_cols = {
"masked_col_1": "The name of the person",
"masked_col_2": "The age of the person",
}

# you can optionally set extract_quotes=True to return quotes that support each output
new_df = df.sem_extract(input_cols, output_cols, extract_quotes=True)
print(new_df)

# A description can also be omitted for each output column
output_cols = {
"name": None,
"age": None,
}
new_df = df.sem_extract(input_cols, output_cols)
print(new_df)
57 changes: 29 additions & 28 deletions lotus/sem_ops/postprocessors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import json

import lotus
from lotus.types import SemanticExtractPostprocessOutput, SemanticFilterPostprocessOutput, SemanticMapPostprocessOutput
from lotus.types import (
SemanticExtractPostprocessOutput,
SemanticFilterPostprocessOutput,
SemanticMapPostprocessOutput,
)


def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput:
Expand Down Expand Up @@ -52,6 +56,30 @@ def map_postprocess(llm_answers: list[str], cot_reasoning: bool = False) -> Sema
return SemanticMapPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)


def extract_postprocess(llm_answers: list[str]) -> SemanticExtractPostprocessOutput:
"""
Postprocess the output of the extract operator to extract the schema.
Args:
llm_answers (list[str]): The list of llm answers containging the extract.
Returns:
SemanticExtractPostprocessOutput
"""
extract_data = []
for llm_answer in llm_answers:
try:
output = json.loads(llm_answer)
except json.JSONDecodeError:
lotus.logger.info(f"\t Failed to parse: {llm_answer}")
output = {}

output = {key: str(value) for key, value in output.items()}
extract_data.append(output)

return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data)


def filter_postprocess_cot(llm_answers: list[str], default: bool) -> SemanticFilterPostprocessOutput:
"""
Postprocess the output of the filter operator with CoT reasoning.
Expand Down Expand Up @@ -121,30 +149,3 @@ def filter_postprocess(
outputs.append(default)

return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)


def extract_postprocess(llm_answers: list[str]) -> SemanticExtractPostprocessOutput:
"""
Postprocess the output of the extract operator, which we assume to
be a JSONL with an answer and quotes field.
Args:
llm_answers (list[str]): The list of llm answers.
Returns:
SemanticExtractPostprocessOutput
"""
answers = []
quotes = []

for json_string in llm_answers:
try:
data = json.loads(json_string)
answers.append(data["answer"])
quotes.append(data["quotes"])
except Exception as e:
lotus.logger.error(f"Failed to parse JSON: {e}")
answers.append(None)
quotes.append(None)

return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=answers, quotes=quotes)
71 changes: 40 additions & 31 deletions lotus/sem_ops/sem_extract.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any, Callable
from typing import Callable

import pandas as pd

import lotus
from lotus.models import LM
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput

Expand All @@ -11,94 +12,102 @@

def sem_extract(
docs: list[str],
model: lotus.models.LM,
user_instruction: str,
model: LM,
output_cols: dict[str, str | None],
extract_quotes: bool = False,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
) -> SemanticExtractOutput:
"""
Extracts from a list of documents using a model.
Extracts attributes and values from a list of documents using a model.
Args:
docs (list[str]): The list of documents to extract from.
model (lotus.models.LM): The model to use.
user_instruction (str): The user instruction for extract.
output_cols (dict[str, str | None]): A mapping from desired output column names to optional descriptions.
extract_quotes (bool): Whether to extract quotes for the output columns. Defaults to False.
postprocessor (Callable): The postprocessor for the model outputs. Defaults to extract_postprocess.
Returns:
SemanticExtractOutput: The outputs, raw outputs, and quotes.
"""

# prepare model inputs
inputs = []
for doc in docs:
prompt = lotus.templates.task_instructions.extract_formatter(doc, user_instruction)
prompt = task_instructions.extract_formatter(doc, output_cols, extract_quotes)
lotus.logger.debug(f"input to model: {prompt}")
lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}")
inputs.append(prompt)

# call model
lm_output: LMOutput = model(inputs)
lm_output: LMOutput = model(inputs, response_format={"type": "json_object"})

# post process results
postprocess_output = postprocessor(lm_output.outputs)
lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
lotus.logger.debug(f"quotes: {postprocess_output.quotes}")

return SemanticExtractOutput(**postprocess_output.model_dump())


@pd.api.extensions.register_dataframe_accessor("sem_extract")
class SemExtractDataframe:
"""DataFrame accessor for semantic extract."""

def __init__(self, pandas_obj: Any):
class SemExtractDataFrame:
def __init__(self, pandas_obj: pd.DataFrame):
self._validate(pandas_obj)
self._obj = pandas_obj

@staticmethod
def _validate(obj: Any) -> None:
def _validate(obj: pd.DataFrame) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

def __call__(
self,
user_instruction: str,
input_cols: list[str],
output_cols: dict[str, str | None],
extract_quotes: bool = False,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
return_raw_outputs: bool = False,
suffix: str = "_extract",
) -> pd.DataFrame:
"""
Applies semantic extract over a dataframe.
Extracts the attributes and values of a dataframe.
Args:
user_instruction (str): The user instruction for extract.
input_cols (list[str]): The columns that a model should extract from.
output_cols (dict[str, str | None]): A mapping from desired output column names to optional descriptions.
extract_quotes (bool): Whether to extract quotes for the output columns. Defaults to False.
postprocessor (Callable): The postprocessor for the model outputs. Defaults to extract_postprocess.
return_raw_outputs (bool): Whether to return raw outputs. Defaults to False.
suffix (str): The suffix for the new columns. Defaults to "_extract".
Returns:
pd.DataFrame: The dataframe with the new extracted values.
pd.DataFrame: The dataframe with the new mapped columns.
"""
col_li = lotus.nl_expression.parse_cols(user_instruction)

# check that column exists
for column in col_li:
for column in input_cols:
if column not in self._obj.columns:
raise ValueError(f"Column {column} not found in DataFrame")

df_txt = task_instructions.df2text(self._obj, col_li)
formatted_usr_instr = lotus.nl_expression.nle2str(user_instruction, col_li)
docs = task_instructions.df2text(self._obj, input_cols)

output = sem_extract(
df_txt,
lotus.settings.lm,
formatted_usr_instr,
out = sem_extract(
docs=docs,
model=lotus.settings.lm,
output_cols=output_cols,
extract_quotes=extract_quotes,
postprocessor=postprocessor,
)

new_df = self._obj
new_df["answers" + suffix] = output.outputs
new_df["quotes" + suffix] = output.quotes
new_df = self._obj.copy()
for i, output_dict in enumerate(out.outputs):
for key, value in output_dict.items():
if key not in new_df.columns:
new_df[key] = None
new_df.loc[i, key] = value

if return_raw_outputs:
new_df["raw_output" + suffix] = output.raw_outputs
new_df["raw_output"] = out.raw_outputs

new_df = new_df.reset_index(drop=True)

return new_df
27 changes: 21 additions & 6 deletions lotus/templates/task_instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,33 @@ def map_formatter(
return messages


def extract_formatter(df_text: str, user_instruction: str) -> list[dict[str, str]]:
def extract_formatter(
df_text: str, output_cols: dict[str, str | None], extract_quotes: bool = True
) -> list[dict[str, str]]:
output_col_names = list(output_cols.keys())
# Set the description to be the key if no value is provided
output_cols_with_desc: dict[str, str] = {col: col if desc is None else desc for col, desc in output_cols.items()}

all_fields = output_col_names
if extract_quotes:
quote_fields = [f"{col}_quote" for col in output_col_names]
all_fields += quote_fields

fields_str = ", ".join(all_fields)

sys_instruction = (
"The user will provide an instruction and some relevant context.\n"
"Your job is to extract the information requested in the instruction.\n"
"Write the response in JSONL format in a single line with the following fields:\n"
"""{"answer": "your answer", "quotes": "quote from context supporting your answer"}"""
"The user will provide the columns that need to be extracted and some relevant context.\n"
f"Your job is to extract these columns and provide only a concise value for each field "
f"and the corresponding full quote for each field in the '{', '.join([f'{col}_quote' for col in output_col_names])}' fields.\n"
f"Here is a description of each field: {output_cols_with_desc}\n"
f"The response should be valid JSON format with the following fields: {fields_str}.\n"
)

messages = [
{"role": "system", "content": sys_instruction},
{
"role": "user",
"content": f"Context:\n{df_text}\n\nInstruction: {user_instruction}",
"content": f"Context:\n{df_text}",
},
]
return messages
Expand Down
19 changes: 9 additions & 10 deletions lotus/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ class SemanticMapOutput(SemanticMapPostprocessOutput):
pass


class SemanticExtractPostprocessOutput(BaseModel):
raw_outputs: list[str]
outputs: list[dict[str, str]]


class SemanticExtractOutput(SemanticExtractPostprocessOutput):
pass


class SemanticFilterPostprocessOutput(BaseModel):
raw_outputs: list[str]
outputs: list[bool]
Expand All @@ -70,16 +79,6 @@ class SemanticAggOutput(BaseModel):
outputs: list[str]


class SemanticExtractPostprocessOutput(BaseModel):
raw_outputs: list[str]
outputs: list[str]
quotes: list[str | None]


class SemanticExtractOutput(SemanticExtractPostprocessOutput):
pass


class SemanticJoinOutput(StatsMixin):
join_results: list[tuple[int, int, str | None]]
filter_outputs: list[bool]
Expand Down

0 comments on commit c85b8ed

Please sign in to comment.