Skip to content

Commit

Permalink
🔧 fix primitive type parsing with openai
Browse files Browse the repository at this point in the history
  • Loading branch information
shroominic committed Feb 11, 2024
1 parent 16dbab0 commit f4ee97a
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 13 deletions.
19 changes: 14 additions & 5 deletions examples/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,22 @@

def classify(
    text: str,
) -> Literal["happy", "sad", "flirty"]:
    """
    Classify the emotional tone of *text*.

    Returns:
        One of "happy", "sad", or "flirty".
    """
    # funcchain's chain() infers the prompt and output parsing
    # from this function's signature and docstring.
    return chain()


if __name__ == "__main__":
    # Smoke-test the classifier: one example input per expected label.
    r = classify("Hey :)")
    print(r)
    assert r == "happy"

    r = classify("Hey :(")
    print(r)
    assert r == "sad"

    r = classify("Hey ;)")
    print(r)
    assert r == "flirty"
20 changes: 18 additions & 2 deletions src/funcchain/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from ..model.abilities import is_openai_function_model, is_vision_model
from ..model.defaults import univeral_model_selector
from ..parser.json_schema import RetryJsonPydanticParser
from ..parser.openai_functions import RetryOpenAIFunctionPydanticParser, RetryOpenAIFunctionPydanticUnionParser
from ..parser.openai_functions import (
RetryOpenAIFunctionPrimitiveTypeParser,
RetryOpenAIFunctionPydanticParser,
RetryOpenAIFunctionPydanticUnionParser,
)
from ..parser.primitive_types import RetryJsonPrimitiveTypeParser
from ..parser.schema_converter import pydantic_to_grammar
from ..parser.selector import parser_for
def patch_openai_function_to_pydantic(
    llm: BaseChatModel,
    output_type: type[BaseModel],
    input_kwargs: dict[str, str],
    primitive_type: bool = False,
) -> tuple[BaseChatModel, BaseGenerationOutputParser]:
    """
    Bind an OpenAI function-calling schema for *output_type* onto *llm*
    and select the matching retry parser.

    Args:
        llm: Chat model to patch; the unbound model is kept for retries.
        output_type: Pydantic model describing the expected output.
        input_kwargs: Prompt kwargs; "format_instructions" is injected here.
        primitive_type: True when *output_type* is a wrapper around a
            primitive value whose ``value`` field must be unwrapped by the
            parser (see RetryOpenAIFunctionPrimitiveTypeParser).

    Returns:
        The schema-bound llm and the parser for its function-call output.
    """
    input_kwargs["format_instructions"] = f"Extract to {output_type.__name__}."
    functions = pydantic_to_functions(output_type)

    # Keep a reference to the unbound model so retry attempts start clean.
    _llm = llm
    llm = llm.bind(**functions)  # type: ignore

    if not primitive_type:
        return llm, RetryOpenAIFunctionPydanticParser(pydantic_schema=output_type, retry=3, retry_llm=_llm)
    return llm, RetryOpenAIFunctionPrimitiveTypeParser(pydantic_schema=output_type, retry=3, retry_llm=_llm)


def create_chain(
Expand Down Expand Up @@ -183,11 +190,20 @@ def create_chain(
if isinstance(parser, RetryJsonPydanticParser) or isinstance(parser, RetryJsonPrimitiveTypeParser):
output_type = parser.pydantic_object
if issubclass(output_type, BaseModel) and not issubclass(output_type, ParserBaseModel):
# openai json streaming
if settings.streaming and hasattr(llm, "model_kwargs"):
llm.model_kwargs = {"response_format": {"type": "json_object"}}
# primitive types
elif isinstance(parser, RetryJsonPrimitiveTypeParser):
llm, parser = patch_openai_function_to_pydantic(llm, output_type, input_kwargs, primitive_type=True)
# pydantic types
else:
assert isinstance(parser, RetryJsonPydanticParser)
llm, parser = patch_openai_function_to_pydantic(llm, output_type, input_kwargs)
# custom parsers
elif issubclass(output_type, ParserBaseModel):
# todo maybe add custom openai function parsing
...

assert parser is not None
return chat_prompt | llm | parser
Expand Down
18 changes: 14 additions & 4 deletions src/funcchain/parser/openai_functions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import copy
from typing import Type, TypeVar
from typing import Generic, Type, TypeVar

from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseChatModel
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.runnables import Runnable
from pydantic import BaseModel, ValidationError

from ..schema.types import UniversalChatModel
from ..syntax.output_types import CodeBlock as CodeBlock
from ..utils.msg_tools import msg_to_str

Expand All @@ -18,7 +18,7 @@ class RetryOpenAIFunctionPydanticParser(BaseGenerationOutputParser[M]):
pydantic_schema: Type[M]
args_only: bool = False
retry: int
retry_llm: BaseChatModel | str | None = None
retry_llm: UniversalChatModel = None

def parse_result(self, result: list[Generation], *, partial: bool = False) -> M:
try:
Expand Down Expand Up @@ -69,7 +69,7 @@ class RetryOpenAIFunctionPydanticUnionParser(BaseGenerationOutputParser[M]):
output_types: list[type[M]]
args_only: bool = False
retry: int
retry_llm: BaseChatModel | str | None = None
retry_llm: UniversalChatModel = None

def parse_result(self, result: list[Generation], *, partial: bool = False) -> M:
try:
Expand Down Expand Up @@ -142,3 +142,13 @@ def retry_chain(self) -> Runnable:
llm=self.retry_llm,
settings_override={"retry_parse": self.retry - 1},
)


class RetryOpenAIFunctionPrimitiveTypeParser(RetryOpenAIFunctionPydanticParser, Generic[M]):
    """
    Parse primitive types by wrapping them in a pydantic model and parsing them.

    The wrapper model is expected to expose the primitive result on a
    ``value`` field (presumably built via ``create_model("Extract", value=...)``
    — confirm against the compiler's primitive-type path).

    Examples: int, float, bool, list[str], dict[str, int], Literal["a", "b", "c"], etc.
    """

    def parse_result(self, result: list[Generation], *, partial: bool = False) -> M:
        # Delegate to the pydantic parser, then unwrap the primitive from the
        # wrapper model's `value` field.
        return super().parse_result(result, partial=partial).value
10 changes: 8 additions & 2 deletions src/funcchain/parser/primitive_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
"""
from typing import Generic, TypeVar

from langchain_core.language_models import BaseChatModel
from pydantic import BaseModel, create_model

from ..schema.types import UniversalChatModel
from .json_schema import RetryJsonPydanticParser

M = TypeVar("M", bound=BaseModel)
Expand All @@ -21,7 +21,7 @@ def __init__(
self,
primitive_type: type,
retry: int = 1,
retry_llm: BaseChatModel | str | None = None,
retry_llm: UniversalChatModel = None,
) -> None:
super().__init__(
pydantic_object=create_model("Extract", value=(primitive_type, ...)),
Expand All @@ -30,4 +30,10 @@ def __init__(
)

def parse(self, text: str) -> M:
    """Parse *text* into the wrapper model and return its primitive `value`.

    Delegates JSON parsing to RetryJsonPydanticParser, then unwraps the
    single `value` field of the generated "Extract" wrapper model.
    """
    # Removed leftover debug print() calls that were committed by mistake.
    return super().parse(text).value

def get_format_instructions(self) -> str:
    """Return format instructions for the wrapped primitive type.

    TODO: override with a version optimized for primitive values;
    currently delegates unchanged to the parent parser.
    """
    return super().get_format_instructions()

0 comments on commit f4ee97a

Please sign in to comment.