🔧 improve chain creation
shroominic committed Dec 26, 2023
1 parent a9b0fe1 commit 224f143
Showing 1 changed file with 86 additions and 75 deletions.
161 changes: 86 additions & 75 deletions src/funcchain/chain/creation.py
@@ -1,29 +1,25 @@
 from types import UnionType
-from typing import TypeVar, Type
+from typing import Type, TypeVar
 
+from langchain_core.callbacks import Callbacks
+from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_core.language_models import BaseChatModel
-from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 from langchain_core.output_parsers import BaseOutputParser
-from langchain_core.chat_history import BaseChatMessageHistory
-from langchain_core.runnables import (
-    RunnableSerializable,
-    RunnableWithFallbacks,
-)
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnableSerializable
 from PIL import Image
 from pydantic import BaseModel
 
-from funcchain._llms import ChatLlamaCpp
-
 from ..parser import MultiToolParser, ParserBaseModel, PydanticFuncParser
 from ..settings import FuncchainSettings
 from ..streaming import stream_handler
 from ..utils import (
-    parser_for,
     count_tokens,
     is_function_model,
     is_vision_model,
     multi_pydantic_to_functions,
+    parser_for,
     pydantic_to_functions,
     pydantic_to_grammar,
     univeral_model_selector,
@@ -44,7 +40,7 @@ def create_union_chain(
     system: str,
     memory: BaseChatMessageHistory,
     context: list[BaseMessage],
-    llm: BaseChatModel | RunnableWithFallbacks,
+    llm: BaseChatModel,
     input_kwargs: dict[str, str],
 ) -> RunnableSerializable[dict[str, str], BaseModel]:
     """
@@ -64,16 +60,7 @@
 
     functions = multi_pydantic_to_functions(output_types)
 
-    if isinstance(llm, RunnableWithFallbacks):
-        llm = llm.runnable.bind(**functions).with_fallbacks(
-            [
-                fallback.bind(**functions)
-                for fallback in llm.fallbacks
-                if hasattr(llm, "fallbacks")
-            ]
-        )
-    else:
-        llm = llm.bind(**functions)  # type: ignore
+    llm = llm.bind(**functions)  # type: ignore
 
     prompt = create_chat_prompt(
         system,
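
Note on the change above: with the RunnableWithFallbacks branch gone, the chain relies on the standard Runnable.bind() to attach OpenAI-style function-calling kwargs to every model invocation. A minimal self-contained sketch of that pattern (the schema below is illustrative, not the actual output of multi_pydantic_to_functions, and the ChatOpenAI import path varies across langchain versions):

    from langchain.chat_models import ChatOpenAI

    llm = ChatOpenAI(model="gpt-3.5-turbo")

    # bind() pre-fills invocation kwargs, so every request carries the
    # function schema and forces the model to answer via "Answer"
    bound = llm.bind(
        functions=[
            {
                "name": "Answer",
                "description": "Return the final structured answer.",
                "parameters": {
                    "type": "object",
                    "properties": {"text": {"type": "string"}},
                    "required": ["text"],
                },
            }
        ],
        function_call={"name": "Answer"},
    )
    # bound.invoke(...) returns an AIMessage whose additional_kwargs
    # carry a function_call with JSON arguments matching the schema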
@@ -93,24 +80,15 @@
 def create_pydanctic_chain(
     output_type: type[BaseModel],
     prompt: ChatPromptTemplate,
-    llm: BaseChatModel | RunnableWithFallbacks,
+    llm: BaseChatModel,
     input_kwargs: dict[str, str],
 ) -> RunnableSerializable[dict[str, str], BaseModel]:
+    # TODO: check these format_instructions
     input_kwargs["format_instructions"] = f"Extract to {output_type.__name__}."
     functions = pydantic_to_functions(output_type)
 
-    llm = (
-        llm.runnable.bind(**functions).with_fallbacks(  # type: ignore
-            [
-                fallback.bind(**functions)
-                for fallback in llm.fallbacks
-                if hasattr(llm, "fallbacks")
-            ]
-        )
-        if isinstance(llm, RunnableWithFallbacks)
-        else llm.bind(**functions)
-    )
+    llm = llm.bind(**functions)  # type: ignore
 
     return prompt | llm | PydanticFuncParser(pydantic_schema=output_type)
 
 
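The functions dict bound in create_pydanctic_chain comes from pydantic_to_functions, which this diff does not touch. Conceptually it derives an OpenAI function schema from the output model, roughly like this sketch (helper name and dict layout are assumptions, using pydantic v2's model_json_schema):

    from pydantic import BaseModel

    class Task(BaseModel):
        """A single todo item."""
        name: str
        difficulty: int

    def to_function_kwargs(model: type[BaseModel]) -> dict:
        # rough stand-in for funcchain's pydantic_to_functions
        return {
            "functions": [
                {
                    "name": model.__name__.lower(),
                    "description": model.__doc__ or "",
                    "parameters": model.model_json_schema(),
                }
            ],
            "function_call": {"name": model.__name__.lower()},
        }

    # to_function_kwargs(Task) can then be splatted into llm.bind(**...)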
@@ -127,13 +105,17 @@ def create_chain(
     Compile a langchain runnable chain from the funcchain syntax.
     """
     # large language model
-    llm = _gather_llm(settings)
+    _llm = _gather_llm(settings)
+    llm = _add_custom_callbacks(_llm, settings)
 
     parser = parser_for(output_type)
 
     # add format instructions for parser
-    if parser and not is_function_model(llm):
-        instruction = _add_format_instructions(
+    f_instructions = None
+    if parser and (settings.streaming or not is_function_model(llm)):
+        # streaming behavior is not supported for function models
+        # but for normal function models we do not need to add format instructions
+        instruction, f_instructions = _add_format_instructions(
             parser,
             instruction,
             input_kwargs,
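
For the non-function-model path, whatever the parser returns from get_format_instructions() is what fills the {format_instructions} placeholder appended to the instruction. A standalone illustration with a stock langchain parser (funcchain's parser_for may return a different parser class):

    from langchain.output_parsers import PydanticOutputParser
    from pydantic import BaseModel

    class Answer(BaseModel):
        text: str

    parser = PydanticOutputParser(pydantic_object=Answer)
    instruction = "Answer the question.\n{format_instructions}"
    # prints the instruction with the parser's JSON-schema description inlined
    print(instruction.format(format_instructions=parser.get_format_instructions()))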
@@ -151,29 +133,18 @@
     images = _handle_images(llm, input_kwargs)
 
     # create prompts
-    instruction_prompt = create_instruction_prompt(instruction, images, input_kwargs)
+    instruction_prompt = create_instruction_prompt(
+        instruction,
+        images,
+        input_kwargs,
+        format_instructions=f_instructions,
+    )
     chat_prompt = create_chat_prompt(system, instruction_prompt, context, memory)
 
     # add formatted instruction to chat history
     memory.add_message(instruction_prompt.format(**input_kwargs))
 
-    if isinstance(llm, ChatLlamaCpp):
-        if isinstance(output_type, UnionType):
-            # TODO: implement Union Type grammar
-            raise NotImplementedError(
-                "Union types are not yet supported for LlamaCpp models."
-            )
-        if issubclass(output_type, BaseModel) and not issubclass(
-            output_type, ParserBaseModel
-        ):
-            from llama_cpp import LlamaGrammar
-
-            grammar = pydantic_to_grammar(output_type)
-            setattr(
-                llm,
-                "grammar",
-                LlamaGrammar.from_string(grammar, verbose=False),
-            )
+    _inject_grammar_for_local_models(llm, output_type)
 
     # function model patches
     if is_function_model(llm):
@@ -191,21 +162,24 @@
         if issubclass(output_type, BaseModel) and not issubclass(
             output_type, ParserBaseModel
         ):
-            return create_pydanctic_chain(  # type: ignore
-                output_type,
-                chat_prompt,
-                llm,
-                input_kwargs,
-            )
-
+            if settings.streaming and hasattr(llm, "model_kwargs"):
+                llm.model_kwargs = {"response_format": {"type": "json_object"}}
+            else:
+                return create_pydanctic_chain(  # type: ignore
+                    output_type,
+                    chat_prompt,
+                    llm,
+                    input_kwargs,
+                )
     assert parser is not None
     return chat_prompt | llm | parser
 
+
 def _add_format_instructions(
     parser: BaseOutputParser,
     instruction: str,
     input_kwargs: dict[str, str],
-) -> str:
+) -> tuple[str, str | None]:
     """
     Add parsing format instructions
     to the instruction message and input_kwargs
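
Why the new streaming branch: a forced function call cannot be streamed token by token, so when streaming is enabled the chain switches OpenAI-style models into JSON mode instead; the token stream then concatenates to valid JSON that the parser can still handle at the end. A hedged sketch of the effect (assumes a model that accepts response_format via model_kwargs):

    # JSON mode replaces function calling when streaming is on
    llm.model_kwargs = {"response_format": {"type": "json_object"}}

    # chunks arrive incrementally, yet the joined output stays parseable
    for chunk in (chat_prompt | llm).stream({"input": "list three tasks"}):
        print(chunk.content, end="", flush=True)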
@@ -215,9 +189,9 @@ def _add_format_instructions(
         if format_instructions := parser.get_format_instructions():
             instruction += "\n{format_instructions}"
             input_kwargs["format_instructions"] = format_instructions
-            return instruction
+            return instruction, format_instructions
     except NotImplementedError:
-        return instruction
+        return instruction, None
 
 
 def _crop_large_inputs(
@@ -239,7 +213,7 @@
 
 
 def _handle_images(
-    llm: BaseChatModel | RunnableWithFallbacks,
+    llm: BaseChatModel,
     input_kwargs: dict[str, str],
 ) -> list[Image.Image]:
     """
@@ -256,12 +230,33 @@
     return images
 
 
+def _inject_grammar_for_local_models(llm: BaseChatModel, output_type: type) -> None:
+    """
+    Inject GBNF grammar into local models.
+    """
+    try:
+        from funcchain._llms import ChatOllama
+    except:  # noqa
+        pass
+    else:
+        if isinstance(llm, ChatOllama):
+            if isinstance(output_type, UnionType):
+                raise NotImplementedError(
+                    "Union types are not yet supported for LlamaCpp models."
+                )  # TODO: implement
+
+            if issubclass(output_type, BaseModel) and not issubclass(
+                output_type, ParserBaseModel
+            ):
+                llm.grammar = pydantic_to_grammar(output_type)
+            if issubclass(output_type, ParserBaseModel):
+                llm.grammar = output_type.custom_grammar()
+
+
 def _gather_llm(
     settings: FuncchainSettings,
-) -> BaseChatModel | RunnableWithFallbacks:
-    if isinstance(settings.llm, RunnableWithFallbacks) or isinstance(
-        settings.llm, BaseChatModel
-    ):
+) -> BaseChatModel:
+    if isinstance(settings.llm, BaseChatModel):
         llm = settings.llm
     else:
         llm = univeral_model_selector(settings)
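
pydantic_to_grammar (unchanged here) emits a llama.cpp GBNF grammar so a local model can only sample schema-valid JSON. For a model like Task(name: str, difficulty: int) the grammar is roughly of this shape (a hand-written approximation, not funcchain's exact output):

    # GBNF kept as a Python string for illustration
    grammar = r'''
    root   ::= "{" ws "\"name\"" ws ":" ws string "," ws "\"difficulty\"" ws ":" ws number ws "}"
    string ::= "\"" [^"]* "\""
    number ::= "-"? [0-9]+
    ws     ::= [ \t\n]*
    '''
    # llama-cpp-python loads such a grammar via LlamaGrammar.from_string(grammar),
    # which is what the removed ChatLlamaCpp branch in create_chain did inline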
@@ -271,12 +266,28 @@
             "No language model provided. Either set the llm environment variable or "
             "pass a model to the `chain` function."
         )
+    return llm
+
 
+def _add_custom_callbacks(
+    llm: BaseChatModel, settings: FuncchainSettings
+) -> BaseChatModel:
+    callbacks: Callbacks = []
+
     if handler := stream_handler.get():
+        callbacks = [handler]
+
+    if settings.console_stream:
+        from ..streaming import AsyncStreamHandler
+
+        callbacks = [
+            AsyncStreamHandler(print, {"end": "", "flush": True}),
+        ]
+
+    if callbacks:
+        settings.streaming = True
-        if isinstance(llm, RunnableWithFallbacks) and isinstance(
-            llm.runnable, BaseChatModel
-        ):
-            llm.runnable.callbacks = [handler]
-        elif isinstance(llm, BaseChatModel):
-            llm.callbacks = [handler]
+        if hasattr(llm, "streaming"):
+            llm.streaming = True
+        llm.callbacks = callbacks
+
     return llm
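Net effect of the split: _gather_llm only resolves the model, while _add_custom_callbacks owns streaming setup, so a registered stream handler or console_stream=True flips settings.streaming and attaches the callbacks directly to the chat model. A usage sketch against funcchain's public API (the exact surface may differ from this version):

    from funcchain import chain, settings

    settings.console_stream = True  # assumption: enables the AsyncStreamHandler path

    def summarize(text: str) -> str:
        """Summarize {text} in one sentence."""
        return chain()

    summarize("funcchain compiles typed functions into langchain runnables")
    # tokens print to stdout as they stream, via the callback attached above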