From 07c0e256f2012bb31e95d80dc8d1bc9ee87aa204 Mon Sep 17 00:00:00 2001 From: Shroominic Date: Mon, 12 Feb 2024 20:58:33 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20fix=20input=20dict=20type?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/funcchain/backend/compiler.py | 10 +++++----- src/funcchain/backend/meta_inspect.py | 4 ++-- src/funcchain/syntax/components/router.py | 2 +- src/funcchain/syntax/decorators.py | 6 +++--- src/funcchain/syntax/executable.py | 4 ++-- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/funcchain/backend/compiler.py b/src/funcchain/backend/compiler.py index 90a2981..8070e9a 100644 --- a/src/funcchain/backend/compiler.py +++ b/src/funcchain/backend/compiler.py @@ -45,7 +45,7 @@ def create_union_chain( context: list[BaseMessage], llm: BaseChatModel, input_kwargs: dict[str, Any], -) -> Runnable[dict[str, str], Any]: +) -> Runnable[dict[str, Any], Any]: """ Compile a langchain runnable chain from the funcchain syntax. """ @@ -78,7 +78,7 @@ def create_union_chain( def patch_openai_function_to_pydantic( llm: BaseChatModel, output_type: type[BaseModel], - input_kwargs: dict[str, str], + input_kwargs: dict[str, Any], primitive_type: bool = False, ) -> tuple[BaseChatModel, BaseGenerationOutputParser]: input_kwargs["format_instructions"] = f"Extract to {output_type.__name__}." @@ -101,7 +101,7 @@ def create_chain( settings: FuncchainSettings, input_args: list[tuple[str, type]], temp_images: list[Image] = [], -) -> Runnable[dict[str, str], ChainOutput]: +) -> Runnable[dict[str, Any], ChainOutput]: """ Compile a langchain runnable chain from the funcchain syntax. """ @@ -209,7 +209,7 @@ def create_chain( return chat_prompt | llm | parser -def compile_chain(signature: Signature, temp_images: list[Image] = []) -> Runnable[dict[str, str], ChainOutput]: +def compile_chain(signature: Signature, temp_images: list[Image] = []) -> Runnable[dict[str, Any], ChainOutput]: """ Compile a signature to a runnable chain. """ @@ -236,7 +236,7 @@ def compile_chain(signature: Signature, temp_images: list[Image] = []) -> Runnab def _add_format_instructions( parser: BaseOutputParser, instruction: str, - input_kwargs: dict[str, str], + input_kwargs: dict[str, Any], ) -> tuple[str, str | None]: """ Add parsing format instructions diff --git a/src/funcchain/backend/meta_inspect.py b/src/funcchain/backend/meta_inspect.py index 15f265f..1acf2dc 100644 --- a/src/funcchain/backend/meta_inspect.py +++ b/src/funcchain/backend/meta_inspect.py @@ -1,6 +1,6 @@ from inspect import FrameInfo, currentframe, getouterframes from types import FunctionType, UnionType -from typing import Optional +from typing import Any, Optional FUNC_DEPTH = 4 @@ -53,7 +53,7 @@ def get_output_types(f: Optional[FunctionType] = None) -> list[type]: raise ValueError("The funcchain must have a return type annotation") -def kwargs_from_parent() -> dict[str, str]: +def kwargs_from_parent() -> dict[str, Any]: """ Get the kwargs from the parent function. """ diff --git a/src/funcchain/syntax/components/router.py b/src/funcchain/syntax/components/router.py index a5fc1a7..3174dee 100644 --- a/src/funcchain/syntax/components/router.py +++ b/src/funcchain/syntax/components/router.py @@ -62,7 +62,7 @@ def runnable(self) -> RunnableSerializable[HumanMessage, AIMessage]: runnables={name: run["handler"] for name, run in self.routes.items()}, ) # maybe add auto conversion of strings to AI Messages/Chunks - def _selector(self) -> Runnable[dict[str, str], Any]: + def _selector(self) -> Runnable[dict[str, Any], Any]: RouteChoices = Enum( # type: ignore "RouteChoices", {r: r for r in self.routes.keys()}, diff --git a/src/funcchain/syntax/decorators.py b/src/funcchain/syntax/decorators.py index 8a1b317..a35b5d5 100644 --- a/src/funcchain/syntax/decorators.py +++ b/src/funcchain/syntax/decorators.py @@ -1,5 +1,5 @@ from types import FunctionType -from typing import Callable, Optional, TypeVar, Union, overload +from typing import Any, Callable, Optional, TypeVar, Union, overload from langchain_core.runnables import Runnable @@ -15,7 +15,7 @@ @overload def runnable( f: Callable[..., OutputT], -) -> Runnable[dict[str, str], OutputT]: +) -> Runnable[dict[str, Any], OutputT]: ... @@ -25,7 +25,7 @@ def runnable( llm: UniversalChatModel = None, settings: SettingsOverride = {}, auto_tune: bool = False, -) -> Callable[[Callable], Runnable[dict[str, str], OutputT]]: +) -> Callable[[Callable], Runnable[dict[str, Any], OutputT]]: ... diff --git a/src/funcchain/syntax/executable.py b/src/funcchain/syntax/executable.py index 3ce14e4..250fb25 100644 --- a/src/funcchain/syntax/executable.py +++ b/src/funcchain/syntax/executable.py @@ -110,7 +110,7 @@ async def achain( history=context, settings=settings, ) - chain: Runnable[dict[str, str], Any] = compile_chain(sig, temp_images) + chain: Runnable[dict[str, Any], Any] = compile_chain(sig, temp_images) result = await chain.ainvoke(input_kwargs, {"run_name": get_parent_frame(2).function, "callbacks": callbacks}) if memory and isinstance(result, str): @@ -132,7 +132,7 @@ def compile_runnable( llm: UniversalChatModel = None, system: str = "", settings_override: SettingsOverride = {}, -) -> Runnable[dict[str, str], ChainOut]: +) -> Runnable[dict[str, Any], ChainOut]: """ On the fly compilation of the funcchain syntax. """