sem_schema operator
StanChan03 committed Nov 21, 2024
1 parent 606d761 commit 0afffbb
Showing 5 changed files with 178 additions and 1 deletion.
18 changes: 18 additions & 0 deletions examples/op_examples/schema.py
@@ -0,0 +1,18 @@
from datasets import load_dataset

import lotus
from lotus.models import LM

lm = LM(model="ollama/llama3.1")

lotus.settings.configure(lm=lm)

dataset = load_dataset("CShorten/ML-ArXiv-Papers", split="train")
df = dataset.to_pandas().head(3)

columns = ["problem", "dataset", "results"]
col_descriptions = ["Description of the problem", "What dataset is used", "What results are obtained"]

user_instruction = "{abstract}"
new_df = df.sem_schema(user_instruction, columns, col_descriptions)
print(new_df)
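
Assuming the model follows the JSON instruction in the prompt, the printed DataFrame has one row per abstract and the three requested columns, roughly like this (the cell values are whatever the model extracts):

#    problem                        dataset      results
# 0  <problem of first abstract>    <dataset>    <results>
# 1  ...                            ...          ...
# 2  ...                            ...          ...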
27 changes: 26 additions & 1 deletion lotus/sem_ops/postprocessors.py
@@ -1,7 +1,13 @@
import json
import re

import lotus
from lotus.types import SemanticExtractPostprocessOutput, SemanticFilterPostprocessOutput, SemanticMapPostprocessOutput
from lotus.types import (
    SemanticExtractPostprocessOutput,
    SemanticFilterPostprocessOutput,
    SemanticMapPostprocessOutput,
    SemanticSchemaPostprocessOutput,
)


def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput:
@@ -52,6 +58,25 @@ def map_postprocess(llm_answers: list[str], cot_reasoning: bool = False) -> SemanticMapPostprocessOutput:
    return SemanticMapPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)


def schema_postprocess(llm_answers: list[str]) -> SemanticSchemaPostprocessOutput:
    """
    Postprocess the output of the schema operator to extract the schema.
    Args:
        llm_answers (list[str]): The list of LLM answers containing the schema.
    Returns:
        SemanticSchemaPostprocessOutput
    """
    schema_data = []
    for answer in llm_answers:
        # Pull the outermost JSON object out of the reply (greedy match from the first '{' to the last '}').
        cleaned_answer = re.findall(r"(\{.*\})", answer, re.DOTALL)[0]
        # Escape stray backslashes that do not start a valid JSON escape sequence so json.loads can parse the string.
        cleaned_answer = re.sub(r"\\(?![\"\\/bfnrt])", r"\\\\", cleaned_answer)
        output = json.loads(cleaned_answer)
        schema_data.append(output)
    return SemanticSchemaPostprocessOutput(raw_outputs=llm_answers, outputs=schema_data)
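
For illustration (the reply text and field values below are hypothetical, not part of the commit), the postprocessor turns a raw model answer into a parsed record like this:

reply = 'Here is the JSON: {"problem": "image classification", "dataset": "CIFAR-10", "results": "95% accuracy"}'
parsed = schema_postprocess([reply])
# parsed.outputs      -> [{"problem": "image classification", "dataset": "CIFAR-10", "results": "95% accuracy"}]
# parsed.raw_outputs  -> [reply]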


def filter_postprocess_cot(llm_answers: list[str], default: bool) -> SemanticFilterPostprocessOutput:
    """
    Postprocess the output of the filter operator with CoT reasoning.
109 changes: 109 additions & 0 deletions lotus/sem_ops/sem_schema.py
@@ -0,0 +1,109 @@
from typing import Callable

import pandas as pd

import lotus
from lotus.models import LM
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticSchemaOutput, SemanticSchemaPostprocessOutput

from .postprocessors import schema_postprocess


def sem_schema(
    docs: list[str],
    model: LM,
    columns: list[str],
    col_description: list[str],
    postprocessor: Callable[[list[str]], SemanticSchemaPostprocessOutput] = schema_postprocess,
) -> SemanticSchemaOutput:
    """
    Extracts the given schema columns from a list of documents using a model.
    Args:
        docs (list[str]): The list of documents to extract from.
        model (lotus.models.LM): The model to use.
        columns (list[str]): The schema columns to extract.
        col_description (list[str]): The descriptions of the columns.
        postprocessor (Callable): The postprocessor for the model outputs. Defaults to schema_postprocess.
    Returns:
        SemanticSchemaOutput: The parsed outputs and the raw outputs.
    """

    # prepare model inputs
    inputs = []
    for doc in docs:
        prompt = task_instructions.schema_formatter(doc, columns, col_description)
        lotus.logger.debug(f"input to model: {prompt}")
        lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}")
        inputs.append(prompt)

    # call model
    lm_output: LMOutput = model(inputs)

    # post process results
    postprocess_output = postprocessor(lm_output.outputs)
    lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
    lotus.logger.debug(f"outputs: {postprocess_output.outputs}")

    return SemanticSchemaOutput(**postprocess_output.model_dump())
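
The postprocessor hook can be swapped out. Below is a minimal sketch of a custom postprocessor (the helper name is hypothetical, not part of the commit) that falls back to an empty record when an answer cannot be parsed, instead of raising:

def lenient_schema_postprocess(llm_answers: list[str]) -> SemanticSchemaPostprocessOutput:
    outputs: list[dict[str, str]] = []
    for answer in llm_answers:
        try:
            # Reuse the default parser one answer at a time.
            outputs.append(schema_postprocess([answer]).outputs[0])
        except (IndexError, ValueError):
            # No JSON object found, or json.loads failed.
            outputs.append({})
    return SemanticSchemaPostprocessOutput(raw_outputs=llm_answers, outputs=outputs)

It can then be passed via the postprocessor argument of sem_schema or of the DataFrame accessor defined below.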


@pd.api.extensions.register_dataframe_accessor("sem_schema")
class SemSchemaDataFrame:
    def __init__(self, pandas_obj: pd.DataFrame):
        self._validate(pandas_obj)
        self._obj = pandas_obj

    @staticmethod
    def _validate(obj: pd.DataFrame) -> None:
        if not isinstance(obj, pd.DataFrame):
            raise AttributeError("Must be a DataFrame")

    def __call__(
        self,
        user_instruction: str,
        columns: list[str],
        col_description: list[str],
        postprocessor: Callable[[list[str]], SemanticSchemaPostprocessOutput] = schema_postprocess,
        return_raw_outputs: bool = False,
    ) -> pd.DataFrame:
        """
        Extracts the given schema columns from the documents in a dataframe.
        Args:
            user_instruction (str): The instruction naming the dataframe columns to use as the source documents.
            columns (list[str]): The schema columns to extract.
            col_description (list[str]): The descriptions of the columns.
            postprocessor (Callable): The postprocessor for the model outputs. Defaults to schema_postprocess.
            return_raw_outputs (bool): Whether to return raw outputs. Defaults to False.
        Returns:
            pd.DataFrame: A new dataframe with the extracted schema columns.
        """
        col_li = lotus.nl_expression.parse_cols(user_instruction)

        # check that column exists
        for column in col_li:
            if column not in self._obj.columns:
                raise ValueError(f"Column {column} not found in DataFrame")

        docs = task_instructions.df2text(self._obj, col_li)

        out = sem_schema(
            docs=docs,
            model=lotus.settings.lm,
            columns=columns,
            col_description=col_description,
            postprocessor=postprocessor,
        )

        new_df = pd.DataFrame()

        for column in columns:
            # One extracted value per document for this schema column.
            new_df[column] = [output.get(column) for output in out.outputs]

        new_df = new_df.reset_index(drop=True)

        return new_df
16 changes: 16 additions & 0 deletions lotus/templates/task_instructions.py
@@ -203,6 +203,22 @@ def map_formatter(
    return messages


def schema_formatter(df_text: str, columns: list[str], column_description: list[str]) -> list[dict[str, str]]:
    sys_instruction = (
        "The user will provide the columns that need to be extracted, their descriptions, and some relevant context.\n"
        f"Your job is to extract these columns from the context as a single line of JSON with the following fields: {columns}\n"
        f"The column descriptions are: {column_description}\n"
        "Only respond with the JSON and no other text. Your output will be parsed with json.loads."
    )
    messages = [
        {"role": "system", "content": sys_instruction},
        {
            "role": "user",
            "content": f"Context:\n{df_text}",
        },
    ]
    return messages
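
For illustration, a call like the following (argument values hypothetical) yields a two-message prompt: a system message with the extraction instructions and a user message carrying the serialized row:

msgs = schema_formatter(
    "abstract: We propose a new method ...",
    ["problem", "dataset"],
    ["Description of the problem", "What dataset is used"],
)
# msgs[0]["role"] == "system"    (instructions listing ['problem', 'dataset'] and their descriptions)
# msgs[1]["content"] == "Context:\nabstract: We propose a new method ..."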


def extract_formatter(df_text: str, user_instruction: str) -> list[dict[str, str]]:
    sys_instruction = (
        "The user will provide an instruction and some relevant context.\n"
9 changes: 9 additions & 0 deletions lotus/types.py
@@ -56,6 +56,15 @@ class SemanticMapOutput(SemanticMapPostprocessOutput):
    pass


class SemanticSchemaPostprocessOutput(BaseModel):
    raw_outputs: list[str]
    outputs: list[dict[str, str]]


class SemanticSchemaOutput(SemanticSchemaPostprocessOutput):
    pass
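
As a small illustration (mirroring the conversion at the end of sem_schema above, field values hypothetical), the output model can be rebuilt from the postprocess model via pydantic's model_dump():

pp = SemanticSchemaPostprocessOutput(raw_outputs=['{"problem": "..."}'], outputs=[{"problem": "..."}])
out = SemanticSchemaOutput(**pp.model_dump())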


class SemanticFilterPostprocessOutput(BaseModel):
    raw_outputs: list[str]
    outputs: list[bool]
