From 0afffbbcf127d9ef7b588e45fbfb5c36e06693d5 Mon Sep 17 00:00:00 2001
From: StanChan03
Date: Wed, 20 Nov 2024 20:40:44 -0800
Subject: [PATCH] sem_schema operator

---
 examples/op_examples/schema.py       |  18 +++++
 lotus/sem_ops/postprocessors.py      |  27 ++++++-
 lotus/sem_ops/sem_schema.py          | 109 +++++++++++++++++++++++++++
 lotus/templates/task_instructions.py |  16 ++++
 lotus/types.py                       |   9 +++
 5 files changed, 178 insertions(+), 1 deletion(-)
 create mode 100644 examples/op_examples/schema.py
 create mode 100644 lotus/sem_ops/sem_schema.py

diff --git a/examples/op_examples/schema.py b/examples/op_examples/schema.py
new file mode 100644
index 00000000..6316884e
--- /dev/null
+++ b/examples/op_examples/schema.py
@@ -0,0 +1,18 @@
+from datasets import load_dataset
+
+import lotus
+from lotus.models import LM
+
+lm = LM(model="ollama/llama3.1")
+
+lotus.settings.configure(lm=lm)
+
+dataset = load_dataset("CShorten/ML-ArXiv-Papers", split="train")
+df = dataset.to_pandas().head(3)
+
+columns = ["problem", "dataset", "results"]
+col_descriptions = ["Description of the problem", "What dataset is used", "What results are obtained"]
+
+user_instruction = "{abstract}"
+new_df = df.sem_schema(user_instruction, columns, col_descriptions)
+print(new_df)
diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py
index 559bb6ce..413d6d4b 100644
--- a/lotus/sem_ops/postprocessors.py
+++ b/lotus/sem_ops/postprocessors.py
@@ -1,7 +1,13 @@
 import json
+import re
 
 import lotus
-from lotus.types import SemanticExtractPostprocessOutput, SemanticFilterPostprocessOutput, SemanticMapPostprocessOutput
+from lotus.types import (
+    SemanticExtractPostprocessOutput,
+    SemanticFilterPostprocessOutput,
+    SemanticMapPostprocessOutput,
+    SemanticSchemaPostprocessOutput,
+)
 
 
 def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput:
@@ -52,6 +58,25 @@ def map_postprocess(llm_answers: list[str], cot_reasoning: bool = False) -> Sema
     return SemanticMapPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)
 
 
+def schema_postprocess(llm_answers: list[str]) -> SemanticSchemaPostprocessOutput:
+    """
+    Postprocess the output of the schema operator to extract the schema.
+
+    Args:
+        llm_answers (list[str]): The list of LLM answers containing the schema.
+
+    Returns:
+        SemanticSchemaPostprocessOutput
+    """
+    schema_data = []
+    for answer in llm_answers:
+        matches = re.findall(r"(\{.*\})", answer, re.DOTALL)
+        cleaned_answer = re.sub(r"\\(?![\"\\/bfnrt])", r"\\\\", matches[0] if matches else "{}")
+        output = json.loads(cleaned_answer)
+        schema_data.append(output)
+    return SemanticSchemaPostprocessOutput(raw_outputs=llm_answers, outputs=schema_data)
+
+
 def filter_postprocess_cot(llm_answers: list[str], default: bool) -> SemanticFilterPostprocessOutput:
     """
     Postprocess the output of the filter operator with CoT reasoning.
diff --git a/lotus/sem_ops/sem_schema.py b/lotus/sem_ops/sem_schema.py
new file mode 100644
index 00000000..30c428ff
--- /dev/null
+++ b/lotus/sem_ops/sem_schema.py
@@ -0,0 +1,109 @@
+from typing import Callable
+
+import pandas as pd
+
+import lotus
+from lotus.models import LM
+from lotus.templates import task_instructions
+from lotus.types import LMOutput, SemanticSchemaOutput, SemanticSchemaPostprocessOutput
+
+from .postprocessors import schema_postprocess
+
+
+def sem_schema(
+    docs: list[str],
+    model: LM,
+    columns: list[str],
+    col_description: list[str],
+    postprocessor: Callable[[list[str]], SemanticSchemaPostprocessOutput] = schema_postprocess,
+) -> SemanticSchemaOutput:
+    """
+    Extracts a structured schema from a list of documents using a model.
+
+    Args:
+        docs (list[str]): The list of documents to extract the schema from.
+        model (lotus.models.LM): The model to use.
+        columns (list[str]): The columns to extract.
+        col_description (list[str]): The descriptions of the columns.
+        postprocessor (Callable): The postprocessor for the model outputs. Defaults to schema_postprocess.
+
+    Returns:
+        SemanticSchemaOutput: The outputs and raw outputs.
+    """
+
+    # prepare model inputs
+    inputs = []
+    for doc in docs:
+        prompt = task_instructions.schema_formatter(doc, columns, col_description)
+        lotus.logger.debug(f"input to model: {prompt}")
+        lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}")
+        inputs.append(prompt)
+
+    # call model
+    lm_output: LMOutput = model(inputs)
+
+    # post process results
+    postprocess_output = postprocessor(lm_output.outputs)
+    lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
+    lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
+
+    return SemanticSchemaOutput(**postprocess_output.model_dump())
+
+
+@pd.api.extensions.register_dataframe_accessor("sem_schema")
+class SemSchemaDataFrame:
+    def __init__(self, pandas_obj: pd.DataFrame):
+        self._validate(pandas_obj)
+        self._obj = pandas_obj
+
+    @staticmethod
+    def _validate(obj: pd.DataFrame) -> None:
+        if not isinstance(obj, pd.DataFrame):
+            raise AttributeError("Must be a DataFrame")
+
+    def __call__(
+        self,
+        user_instruction: str,
+        columns: list[str],
+        col_description: list[str],
+        postprocessor: Callable[[list[str]], SemanticSchemaPostprocessOutput] = schema_postprocess,
+        return_raw_outputs: bool = False,
+    ) -> pd.DataFrame:
+        """
+        Extracts the schema attributes and values from a dataframe.
+
+        Args:
+            user_instruction (str): The natural language expression referencing the source column(s) to extract from.
+            columns (list[str]): The new columns to extract.
+            col_description (list[str]): The descriptions of the new columns.
+            postprocessor (Callable): The postprocessor for the model outputs. Defaults to schema_postprocess.
+            return_raw_outputs (bool): Whether to return raw outputs. Defaults to False.
+
+        Returns:
+            pd.DataFrame: The dataframe with the new extracted columns.
+        """
+        col_li = lotus.nl_expression.parse_cols(user_instruction)
+
+        # check that the referenced columns exist in the dataframe
+        for column in col_li:
+            if column not in self._obj.columns:
+                raise ValueError(f"Column {column} not found in DataFrame")
+
+        docs = task_instructions.df2text(self._obj, col_li)
+
+        out = sem_schema(
+            docs=docs,
+            model=lotus.settings.lm,
+            columns=columns,
+            col_description=col_description,
+            postprocessor=postprocessor,
+        )
+
+        # each entry in out.outputs is a dict mapping column name to extracted value
+        # for one document, so each dict becomes one row of the new dataframe
+        new_df = pd.DataFrame(out.outputs, columns=columns)
+
+        if return_raw_outputs:
+            new_df["raw_output"] = out.raw_outputs
+
+        return new_df
diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py
index 1371dfb7..1f2f4df5 100644
--- a/lotus/templates/task_instructions.py
+++ b/lotus/templates/task_instructions.py
@@ -203,6 +203,22 @@ def map_formatter(
     return messages
 
 
+def schema_formatter(df_text: str, columns: list[str], column_description: list[str]) -> list[dict[str, str]]:
+    sys_instruction = (
+        "The user will provide the columns that need to be extracted, their descriptions, and some relevant context.\n"
+        f"Your job is to extract these columns from the context as a single JSON object on one line. The fields and their descriptions are: {dict(zip(columns, column_description))}\n"
+        "Only respond with the JSON object and no other text. Your output will be parsed with json.loads"
+    )
+    messages = [
+        {"role": "system", "content": sys_instruction},
+        {
+            "role": "user",
+            "content": f"Context:\n{df_text}",
+        },
+    ]
+    return messages
+
+
 def extract_formatter(df_text: str, user_instruction: str) -> list[dict[str, str]]:
     sys_instruction = (
         "The user will provide an instruction and some relevant context.\n"
diff --git a/lotus/types.py b/lotus/types.py
index 1d7a3bcc..64d135df 100644
--- a/lotus/types.py
+++ b/lotus/types.py
@@ -56,6 +56,15 @@ class SemanticMapOutput(SemanticMapPostprocessOutput):
     pass
 
 
+class SemanticSchemaPostprocessOutput(BaseModel):
+    raw_outputs: list[str]
+    outputs: list[dict[str, str]]
+
+
+class SemanticSchemaOutput(SemanticSchemaPostprocessOutput):
+    pass
+
+
 class SemanticFilterPostprocessOutput(BaseModel):
     raw_outputs: list[str]
     outputs: list[bool]