From 9812a2063ed62fa710b6334e5c3de657d949d27b Mon Sep 17 00:00:00 2001 From: Jordan Wu <101218661+jordan-definitive@users.noreply.github.com> Date: Mon, 8 Jan 2024 11:50:33 -0800 Subject: [PATCH] move AllFunctionTypes out of openassistant library (#97) --- .../fast-api-server/fast_api_server/main.py | 39 ++++++++++++++++++- .../openassistants/core/assistant.py | 11 ++---- .../openassistants/functions/crud.py | 38 ++++++------------ 3 files changed, 52 insertions(+), 36 deletions(-) diff --git a/examples/fast-api-server/fast_api_server/main.py b/examples/fast-api-server/fast_api_server/main.py index a938601..d56e975 100644 --- a/examples/fast-api-server/fast_api_server/main.py +++ b/examples/fast-api-server/fast_api_server/main.py @@ -1,19 +1,54 @@ import os +from typing import Annotated from fastapi import FastAPI, status from fastapi.middleware.cors import CORSMiddleware from langchain.embeddings import OpenAIEmbeddings from langchain.storage import RedisStore +from openassistants.contrib.advisor_function import AdvisorFunction +from openassistants.contrib.duckdb_query import DuckDBQueryFunction +from openassistants.contrib.langchain_ddg_tool import DuckDuckGoToolFunction +from openassistants.contrib.python_eval import PythonEvalFunction +from openassistants.contrib.sqlalchemy_query import QueryFunction +from openassistants.contrib.text_response import TextResponseFunction from openassistants.core.assistant import Assistant -from openassistants.functions.crud import OpenAPILibrary, PythonLibrary +from openassistants.functions.base import IFunction +from openassistants.functions.crud import ( + LocalYAMLLibrary, + OpenAPILibrary, + PythonLibrary, +) from openassistants.utils.langchain_util import LangChainCachedEmbeddings from openassistants_fastapi import RouteAssistants, create_router +from pydantic import Field, TypeAdapter from fast_api_server.find_email_by_name_function import find_email_by_name_function app = FastAPI() +# Specify all the function types that are allowed in the local YAML library +AllFunctionTypes = Annotated[ + QueryFunction + | DuckDBQueryFunction + | PythonEvalFunction + | DuckDuckGoToolFunction + | TextResponseFunction + | AdvisorFunction, + Field(json_schema_extra={"discriminator": "type"}), +] + + +def model_parser(d: dict) -> IFunction: + return TypeAdapter(AllFunctionTypes).validate_python(d) # type: ignore + + +local_library = LocalYAMLLibrary( + "piedpiper", + model_parser, + "library", +) + # create a library with the custom function custom_python_lib = PythonLibrary(functions=[find_email_by_name_function]) @@ -38,7 +73,7 @@ hooli_assistant = Assistant( - libraries=["piedpiper", custom_python_lib, openapi_lib], + libraries=[local_library, custom_python_lib, openapi_lib], scope_description="""Only answer questions about Hooli company related matters. You're also allowed to answer questions that refer to anything in the current chat history.""", # noqa: E501 entity_embedding_model=entity_embedding_model, diff --git a/packages/openassistants/openassistants/core/assistant.py b/packages/openassistants/openassistants/core/assistant.py index 406a23c..faa2b42 100644 --- a/packages/openassistants/openassistants/core/assistant.py +++ b/packages/openassistants/openassistants/core/assistant.py @@ -19,7 +19,7 @@ IFunction, IFunctionLibrary, ) -from openassistants.functions.crud import LocalFunctionLibrary, PythonLibrary +from openassistants.functions.crud import PythonLibrary from openassistants.llm_function_calling.entity_resolution import resolve_entities from openassistants.llm_function_calling.fallback import perform_general_qa from openassistants.llm_function_calling.infilling import ( @@ -45,7 +45,7 @@ class Assistant: def __init__( self, - libraries: List[str | IFunctionLibrary], + libraries: List[IFunctionLibrary], function_identification: Optional[BaseChatModel] = None, function_infilling: Optional[BaseChatModel] = None, function_summarization: Optional[BaseChatModel] = None, @@ -76,12 +76,7 @@ def __init__( self.entity_embedding_model = ( entity_embedding_model or LangChainCachedEmbeddings(OpenAIEmbeddings()) ) - self.function_libraries = [ - library - if isinstance(library, IFunctionLibrary) - else LocalFunctionLibrary(library) - for library in libraries - ] + self.function_libraries = libraries if add_index: index_func: IFunction = IndexFunction( diff --git a/packages/openassistants/openassistants/functions/crud.py b/packages/openassistants/openassistants/functions/crud.py index e682fb1..d3e94f6 100644 --- a/packages/openassistants/openassistants/functions/crud.py +++ b/packages/openassistants/openassistants/functions/crud.py @@ -4,7 +4,6 @@ from json.decoder import JSONDecodeError from pathlib import Path from typing import ( - Annotated, Any, Callable, Dict, @@ -17,35 +16,18 @@ from langchain.chains.openai_functions.openapi import openapi_spec_to_openai_fn from langchain_community.utilities.openapi import OpenAPISpec -from openassistants.contrib.advisor_function import AdvisorFunction -from openassistants.contrib.duckdb_query import DuckDBQueryFunction -from openassistants.contrib.langchain_ddg_tool import DuckDuckGoToolFunction from openassistants.contrib.python_callable import PythonCallableFunction -from openassistants.contrib.python_eval import PythonEvalFunction -from openassistants.contrib.sqlalchemy_query import QueryFunction -from openassistants.contrib.text_response import TextResponseFunction from openassistants.data_models.function_output import TextOutput from openassistants.data_models.json_schema import JSONSchema from openassistants.functions.base import ( - BaseFunction, BaseFunctionParameters, IFunction, IFunctionLibrary, ) from openassistants.utils import yaml as yaml_utils -from pydantic import Field, TypeAdapter +from pydantic import TypeAdapter from starlette.concurrency import run_in_threadpool -AllFunctionTypes = Annotated[ - QueryFunction - | DuckDBQueryFunction - | PythonEvalFunction - | DuckDuckGoToolFunction - | TextResponseFunction - | AdvisorFunction, - Field(json_schema_extra={"discriminator": "type"}), -] - class BaseFileLibrary(IFunctionLibrary, abc.ABC): @abc.abstractmethod @@ -72,19 +54,23 @@ async def get_all_functions(self) -> Sequence[IFunction]: return funcs # type: ignore -class LocalFunctionLibrary(BaseFileLibrary): - def __init__(self, library_id: str, directory: str = "library"): +class LocalYAMLLibrary(BaseFileLibrary): + def __init__( + self, + library_id: str, + model_parser: Callable[[dict], IFunction], + directory: str = "library", + ): self.library_id = library_id + self.model_parser = model_parser self.directory = Path(directory) / library_id - def read(self, function_id: str) -> Optional[BaseFunction]: + def read(self, function_id: str) -> Optional[IFunction]: try: if (yaml_file := self.directory / f"{function_id}.yaml").exists(): with yaml_file.open() as f: - parsed_yaml = yaml_utils.load(f) - return TypeAdapter(AllFunctionTypes).validate_python( - parsed_yaml | {"id": function_id} - ) # type: ignore + yaml_dict = yaml_utils.load(f) + return self.model_parser(yaml_dict | {"id": function_id}) else: return None except Exception as e: