From fd4a17abf8a645d83032f828b437408051bac980 Mon Sep 17 00:00:00 2001 From: Jordan Wu <101218661+jordan-definitive@users.noreply.github.com> Date: Fri, 5 Jan 2024 15:27:20 -0800 Subject: [PATCH] cleanup function + library to have a clean name / interface --- .../fast-api-server/fast_api_server/main.py | 6 +- .../openassistants/contrib/index_function.py | 4 +- .../openassistants/core/assistant.py | 39 +++++++------ .../openassistants/eval/interaction.py | 6 +- .../openassistants/functions/base.py | 10 +++- .../openassistants/functions/crud.py | 55 ++++++++++--------- .../llm_function_calling/entity_resolution.py | 4 +- .../llm_function_calling/infilling.py | 8 +-- .../llm_function_calling/selection.py | 10 ++-- 9 files changed, 78 insertions(+), 64 deletions(-) diff --git a/examples/fast-api-server/fast_api_server/main.py b/examples/fast-api-server/fast_api_server/main.py index 885dc18..a938601 100644 --- a/examples/fast-api-server/fast_api_server/main.py +++ b/examples/fast-api-server/fast_api_server/main.py @@ -5,7 +5,7 @@ from langchain.embeddings import OpenAIEmbeddings from langchain.storage import RedisStore from openassistants.core.assistant import Assistant -from openassistants.functions.crud import OpenAPICRUD, PythonCRUD +from openassistants.functions.crud import OpenAPILibrary, PythonLibrary from openassistants.utils.langchain_util import LangChainCachedEmbeddings from openassistants_fastapi import RouteAssistants, create_router @@ -15,9 +15,9 @@ # create a library with the custom function -custom_python_lib = PythonCRUD(functions=[find_email_by_name_function]) +custom_python_lib = PythonLibrary(functions=[find_email_by_name_function]) -openapi_lib = OpenAPICRUD( +openapi_lib = OpenAPILibrary( spec="https://petstore3.swagger.io/api/v3/openapi.json", base_url="https://petstore3.swagger.io/api/v3", ) diff --git a/packages/openassistants/openassistants/contrib/index_function.py b/packages/openassistants/openassistants/contrib/index_function.py index f81de83..3e93c41 100644 --- a/packages/openassistants/openassistants/contrib/index_function.py +++ b/packages/openassistants/openassistants/contrib/index_function.py @@ -4,14 +4,14 @@ from openassistants.functions.base import ( BaseFunction, FunctionExecutionDependency, - IBaseFunction, + IFunction, ) from openassistants.functions.utils import AsyncStreamVersion class IndexFunction(BaseFunction): type: Literal["IndexFunction"] = "IndexFunction" - functions: Callable[[], Awaitable[List[IBaseFunction]]] + functions: Callable[[], Awaitable[List[IFunction]]] async def execute( self, deps: FunctionExecutionDependency diff --git a/packages/openassistants/openassistants/core/assistant.py b/packages/openassistants/openassistants/core/assistant.py index 0dd37c2..406a23c 100644 --- a/packages/openassistants/openassistants/core/assistant.py +++ b/packages/openassistants/openassistants/core/assistant.py @@ -15,10 +15,11 @@ from openassistants.data_models.function_input import FunctionCall, FunctionInputRequest from openassistants.functions.base import ( FunctionExecutionDependency, - IBaseFunction, IEntity, + IFunction, + IFunctionLibrary, ) -from openassistants.functions.crud import FunctionCRUD, LocalCRUD, PythonCRUD +from openassistants.functions.crud import LocalFunctionLibrary, PythonLibrary from openassistants.llm_function_calling.entity_resolution import resolve_entities from openassistants.llm_function_calling.fallback import perform_general_qa from openassistants.llm_function_calling.infilling import ( @@ -37,14 +38,14 @@ class Assistant: function_summarization: BaseChatModel function_fallback: BaseChatModel entity_embedding_model: Embeddings - function_libraries: List[FunctionCRUD] + function_libraries: List[IFunctionLibrary] scope_description: str - _cached_all_functions: List[IBaseFunction] + _cached_all_functions: List[IFunction] def __init__( self, - libraries: List[str | FunctionCRUD], + libraries: List[str | IFunctionLibrary], function_identification: Optional[BaseChatModel] = None, function_infilling: Optional[BaseChatModel] = None, function_summarization: Optional[BaseChatModel] = None, @@ -76,12 +77,14 @@ def __init__( entity_embedding_model or LangChainCachedEmbeddings(OpenAIEmbeddings()) ) self.function_libraries = [ - library if isinstance(library, FunctionCRUD) else LocalCRUD(library) + library + if isinstance(library, IFunctionLibrary) + else LocalFunctionLibrary(library) for library in libraries ] if add_index: - index_func: IBaseFunction = IndexFunction( + index_func: IFunction = IndexFunction( id="index", display_name="List functions", description=( @@ -96,19 +99,19 @@ def __init__( functions=self.get_all_functions, ) - self.function_libraries.append(PythonCRUD(functions=[index_func])) + self.function_libraries.append(PythonLibrary(functions=[index_func])) self._cached_all_functions = [] - async def get_all_functions(self) -> List[IBaseFunction]: + async def get_all_functions(self) -> List[IFunction]: if not self._cached_all_functions: - functions = [] + functions: List[IFunction] = [] for library in self.function_libraries: - functions.extend(await library.aread_all()) + functions.extend(await library.get_all_functions()) self._cached_all_functions = functions return self._cached_all_functions - async def get_function_by_id(self, function_id: str) -> Optional[IBaseFunction]: + async def get_function_by_id(self, function_id: str) -> Optional[IFunction]: functions = await self.get_all_functions() for function in functions: if function.get_id() == function_id: @@ -117,7 +120,7 @@ async def get_function_by_id(self, function_id: str) -> Optional[IBaseFunction]: async def execute_function( self, - function: IBaseFunction, + function: IFunction, func_args: Dict[str, Any], dependencies: Dict[str, Any], ): @@ -143,7 +146,7 @@ async def do_infilling( self, dependencies: dict, message: OpasUserMessage, - selected_function: IBaseFunction, + selected_function: IFunction, args_json_schema: dict, entities_info: Dict[str, List[IEntity]], ) -> Tuple[bool, dict]: @@ -199,12 +202,12 @@ async def do_infilling( async def handle_user_plaintext( self, message: OpasUserMessage, - all_functions: List[IBaseFunction], + all_functions: List[IFunction], dependencies: Dict[str, Any], autorun: bool, force_select_function: Optional[str], ) -> AsyncStreamVersion[List[OpasMessage]]: - selected_function: Optional[IBaseFunction] = None + selected_function: Optional[IFunction] = None # perform entity resolution chat_history: List[OpasMessage] = dependencies.get("chat_history") # type: ignore @@ -304,13 +307,13 @@ async def handle_user_plaintext( async def handle_user_input( self, message: OpasUserMessage, - all_functions: List[IBaseFunction], + all_functions: List[IFunction], dependencies: Dict[str, Any], ) -> AsyncStreamVersion[List[OpasMessage]]: if message.input_response is None: raise ValueError("message must have input_response") - selected_function: Optional[IBaseFunction] = None + selected_function: Optional[IFunction] = None for f in all_functions: if f.get_id() == message.input_response.name: diff --git a/packages/openassistants/openassistants/eval/interaction.py b/packages/openassistants/openassistants/eval/interaction.py index e8b39a0..8826d41 100644 --- a/packages/openassistants/openassistants/eval/interaction.py +++ b/packages/openassistants/openassistants/eval/interaction.py @@ -12,7 +12,7 @@ ) from openassistants.data_models.function_input import FunctionCall from openassistants.data_models.function_output import DataFrameOutput, TextOutput -from openassistants.functions.base import IBaseFunction +from openassistants.functions.base import IFunction from openassistants.utils.async_utils import last_value from pydantic import BaseModel, ConfigDict @@ -88,7 +88,7 @@ async def run_function_invocation( async def get_function( self, assistant: Assistant, - ) -> IBaseFunction: + ) -> IFunction: base_function = await assistant.get_function_by_id(self.function) if base_function is None: raise ValueError("Function not found") @@ -159,7 +159,7 @@ class FunctionInteractionResponseNode(BaseModel): user_input_response: OpasUserMessage assistant_function_invocation: OpasAssistantMessage function_response: OpasFunctionMessage - function_spec: IBaseFunction + function_spec: IFunction class FunctionInteractionResponse(FunctionInteractionResponseNode): diff --git a/packages/openassistants/openassistants/functions/base.py b/packages/openassistants/openassistants/functions/base.py index 2008db9..8d53bc8 100644 --- a/packages/openassistants/openassistants/functions/base.py +++ b/packages/openassistants/openassistants/functions/base.py @@ -35,7 +35,7 @@ def get_entities(self) -> Sequence[IEntity]: pass -class IBaseFunction(abc.ABC): +class IFunction(abc.ABC): @abc.abstractmethod def get_id(self) -> str: pass @@ -125,7 +125,7 @@ class BaseFunctionParameters(BaseModel): json_schema: JSONSchema = EMPTY_JSON_SCHEMA -class BaseFunction(IBaseFunction, BaseModel, abc.ABC): +class BaseFunction(IFunction, BaseModel, abc.ABC): id: str type: str display_name: Optional[str] = None @@ -157,3 +157,9 @@ def get_parameters_json_schema(self) -> JSONSchema: async def get_entity_configs(self) -> Mapping[str, IEntityConfig]: return {} + + +class IFunctionLibrary(abc.ABC): + @abc.abstractmethod + async def get_all_functions(self) -> Sequence[IFunction]: + pass diff --git a/packages/openassistants/openassistants/functions/crud.py b/packages/openassistants/openassistants/functions/crud.py index 9cd5dc6..e682fb1 100644 --- a/packages/openassistants/openassistants/functions/crud.py +++ b/packages/openassistants/openassistants/functions/crud.py @@ -3,7 +3,17 @@ import json from json.decoder import JSONDecodeError from pathlib import Path -from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import ( + Annotated, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, +) from langchain.chains.openai_functions.openapi import openapi_spec_to_openai_fn from langchain_community.utilities.openapi import OpenAPISpec @@ -19,7 +29,8 @@ from openassistants.functions.base import ( BaseFunction, BaseFunctionParameters, - IBaseFunction, + IFunction, + IFunctionLibrary, ) from openassistants.utils import yaml as yaml_utils from pydantic import Field, TypeAdapter @@ -36,27 +47,32 @@ ] -class FunctionCRUD(abc.ABC): +class BaseFileLibrary(IFunctionLibrary, abc.ABC): @abc.abstractmethod - def read(self, slug: str) -> Optional[IBaseFunction]: + def read(self, slug: str) -> Optional[IFunction]: pass @abc.abstractmethod def list_ids(self) -> List[str]: pass - async def aread(self, function_id: str) -> Optional[IBaseFunction]: + async def aread(self, function_id: str) -> Optional[IFunction]: return await run_in_threadpool(self.read, function_id) async def alist_ids(self) -> List[str]: return await run_in_threadpool(self.list_ids) - async def aread_all(self) -> List[IBaseFunction]: + async def get_all_functions(self) -> Sequence[IFunction]: ids = await self.alist_ids() - return await asyncio.gather(*[self.aread(f_id) for f_id in ids]) # type: ignore + funcs: List[IFunction | None] = await asyncio.gather( # type: ignore + *[self.aread(f_id) for f_id in ids] + ) + if None in funcs: + raise RuntimeError("Failed to load all functions") + return funcs # type: ignore -class LocalCRUD(FunctionCRUD): +class LocalFunctionLibrary(BaseFileLibrary): def __init__(self, library_id: str, directory: str = "library"): self.library_id = library_id self.directory = Path(directory) / library_id @@ -74,32 +90,21 @@ def read(self, function_id: str) -> Optional[BaseFunction]: except Exception as e: raise RuntimeError(f"Failed to load: {function_id}") from e - async def aread_all(self) -> List[BaseFunction]: - ids = self.list_ids() - return [self.read(f_id) for f_id in ids] # type: ignore - def list_ids(self) -> List[str]: return [ file.stem for file in self.directory.iterdir() if file.suffix == ".yaml" ] -class PythonCRUD(FunctionCRUD): - def __init__(self, functions: List[IBaseFunction]): +class PythonLibrary(IFunctionLibrary): + def __init__(self, functions: Sequence[IFunction]): self.functions = functions - def read(self, slug: str) -> Optional[IBaseFunction]: - for function in self.functions: - if function.get_id() == slug: - return function - - return None - - def list_ids(self) -> List[str]: - return [function.get_id() for function in self.functions] + async def get_all_functions(self) -> Sequence[IFunction]: + return self.functions -class OpenAPICRUD(PythonCRUD): +class OpenAPILibrary(PythonLibrary): openapi: OpenAPISpec @staticmethod @@ -159,6 +164,6 @@ def __init__(self, spec: Union[OpenAPISpec, str], base_url: Optional[str]): self.openapi.servers[0].url = base_url openai_functions = openapi_spec_to_openai_fn(self.openapi) - functions = OpenAPICRUD.openai_fns_to_openapi_function(openai_functions) + functions = OpenAPILibrary.openai_fns_to_openapi_function(openai_functions) super().__init__(functions) diff --git a/packages/openassistants/openassistants/llm_function_calling/entity_resolution.py b/packages/openassistants/openassistants/llm_function_calling/entity_resolution.py index 9a043a6..08bc4cf 100644 --- a/packages/openassistants/openassistants/llm_function_calling/entity_resolution.py +++ b/packages/openassistants/openassistants/llm_function_calling/entity_resolution.py @@ -7,9 +7,9 @@ from langchain.vectorstores.usearch import USearch from openassistants.data_models.chat_messages import OpasMessage from openassistants.functions.base import ( - IBaseFunction, IEntity, IEntityConfig, + IFunction, ) from openassistants.llm_function_calling.infilling import generate_arguments @@ -64,7 +64,7 @@ async def _get_entities( async def resolve_entities( - function: IBaseFunction, + function: IFunction, function_infilling_llm: BaseChatModel, embeddings: Embeddings, user_query: str, diff --git a/packages/openassistants/openassistants/llm_function_calling/infilling.py b/packages/openassistants/openassistants/llm_function_calling/infilling.py index b499ea5..322086c 100644 --- a/packages/openassistants/openassistants/llm_function_calling/infilling.py +++ b/packages/openassistants/openassistants/llm_function_calling/infilling.py @@ -4,14 +4,14 @@ from langchain.chat_models.base import BaseChatModel from langchain.schema.messages import HumanMessage from openassistants.data_models.chat_messages import OpasMessage -from openassistants.functions.base import IBaseFunction, IEntity +from openassistants.functions.base import IEntity, IFunction from openassistants.llm_function_calling.utils import ( build_chat_history_prompt, generate_to_json, ) -async def generate_argument_decisions_schema(function: IBaseFunction): +async def generate_argument_decisions_schema(function: IFunction): # Start with the base schema json_schema = function.get_parameters_json_schema() @@ -50,7 +50,7 @@ class NestedObject(TypedDict): async def generate_argument_decisions( - function: IBaseFunction, + function: IFunction, chat: BaseChatModel, user_query: str, chat_history: List[OpasMessage], @@ -93,7 +93,7 @@ def entity_to_json_schema_obj(entity: IEntity): async def generate_arguments( - function: IBaseFunction, + function: IFunction, chat: BaseChatModel, user_query: str, chat_history: List[OpasMessage], diff --git a/packages/openassistants/openassistants/llm_function_calling/selection.py b/packages/openassistants/openassistants/llm_function_calling/selection.py index 782d4a0..32e0ccb 100644 --- a/packages/openassistants/openassistants/llm_function_calling/selection.py +++ b/packages/openassistants/openassistants/llm_function_calling/selection.py @@ -3,7 +3,7 @@ from langchain.chat_models.base import BaseChatModel from langchain.schema.messages import HumanMessage -from openassistants.functions.base import IBaseFunction +from openassistants.functions.base import IFunction from openassistants.llm_function_calling.utils import ( chunk_list_by_max_size, generate_to_json, @@ -12,7 +12,7 @@ async def filter_functions( - chat: BaseChatModel, functions: List[IBaseFunction], user_query: str + chat: BaseChatModel, functions: List[IFunction], user_query: str ) -> Optional[str]: functions_text = "\n".join([f.get_signature() for f in functions]) json_schema = { @@ -44,13 +44,13 @@ async def filter_functions( class SelectFunctionResult(BaseModel): - function: Optional[InstanceOf[IBaseFunction]] = None - suggested_functions: Optional[List[InstanceOf[IBaseFunction]]] = None + function: Optional[InstanceOf[IFunction]] = None + suggested_functions: Optional[List[InstanceOf[IFunction]]] = None async def select_function( chat: BaseChatModel, - functions: List[IBaseFunction], + functions: List[IFunction], user_query: str, chunk_size: int = 4, ) -> SelectFunctionResult: