From fd4a17abf8a645d83032f828b437408051bac980 Mon Sep 17 00:00:00 2001
From: Jordan Wu <101218661+jordan-definitive@users.noreply.github.com>
Date: Fri, 5 Jan 2024 15:27:20 -0800
Subject: [PATCH] Clean up function and library classes: clearer names and interfaces

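Rename the core abstractions so their names describe what they are rather
than the storage pattern behind them:

  IBaseFunction -> IFunction
  FunctionCRUD  -> split into IFunctionLibrary (functions/base.py) and
                   BaseFileLibrary (file-backed helpers in functions/crud.py)
  LocalCRUD     -> LocalFunctionLibrary
  PythonCRUD    -> PythonLibrary
  OpenAPICRUD   -> OpenAPILibrary
  aread_all()   -> get_all_functions()

A library now exposes a single async get_all_functions() method returning
IFunction instances; read()/list_ids() remain only on the file-backed base.

A minimal sketch of the renamed surface, assuming the example server's
find_email_by_name_function and a hypothetical "my_library" directory under
library/ (actual wiring may differ):

    from openassistants.core.assistant import Assistant
    from openassistants.functions.crud import OpenAPILibrary, PythonLibrary

    # Wrap plain IFunction instances in a Python-defined library.
    custom_python_lib = PythonLibrary(functions=[find_email_by_name_function])

    # Build a library of functions from an OpenAPI spec.
    openapi_lib = OpenAPILibrary(
        spec="https://petstore3.swagger.io/api/v3/openapi.json",
        base_url="https://petstore3.swagger.io/api/v3",
    )

    # Assistant accepts IFunctionLibrary instances or string ids; string ids
    # are loaded from local YAML files via LocalFunctionLibrary.
    assistant = Assistant(
        libraries=["my_library", custom_python_lib, openapi_lib]
    )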
---
 .../fast-api-server/fast_api_server/main.py   |  6 +-
 .../openassistants/contrib/index_function.py  |  4 +-
 .../openassistants/core/assistant.py          | 39 +++++++------
 .../openassistants/eval/interaction.py        |  6 +-
 .../openassistants/functions/base.py          | 10 +++-
 .../openassistants/functions/crud.py          | 55 ++++++++++---------
 .../llm_function_calling/entity_resolution.py |  4 +-
 .../llm_function_calling/infilling.py         |  8 +--
 .../llm_function_calling/selection.py         | 10 ++--
 9 files changed, 78 insertions(+), 64 deletions(-)

diff --git a/examples/fast-api-server/fast_api_server/main.py b/examples/fast-api-server/fast_api_server/main.py
index 885dc18..a938601 100644
--- a/examples/fast-api-server/fast_api_server/main.py
+++ b/examples/fast-api-server/fast_api_server/main.py
@@ -5,7 +5,7 @@
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.storage import RedisStore
 from openassistants.core.assistant import Assistant
-from openassistants.functions.crud import OpenAPICRUD, PythonCRUD
+from openassistants.functions.crud import OpenAPILibrary, PythonLibrary
 from openassistants.utils.langchain_util import LangChainCachedEmbeddings
 from openassistants_fastapi import RouteAssistants, create_router
 
@@ -15,9 +15,9 @@
 
 
 # create a library with the custom function
-custom_python_lib = PythonCRUD(functions=[find_email_by_name_function])
+custom_python_lib = PythonLibrary(functions=[find_email_by_name_function])
 
-openapi_lib = OpenAPICRUD(
+openapi_lib = OpenAPILibrary(
     spec="https://petstore3.swagger.io/api/v3/openapi.json",
     base_url="https://petstore3.swagger.io/api/v3",
 )
diff --git a/packages/openassistants/openassistants/contrib/index_function.py b/packages/openassistants/openassistants/contrib/index_function.py
index f81de83..3e93c41 100644
--- a/packages/openassistants/openassistants/contrib/index_function.py
+++ b/packages/openassistants/openassistants/contrib/index_function.py
@@ -4,14 +4,14 @@
 from openassistants.functions.base import (
     BaseFunction,
     FunctionExecutionDependency,
-    IBaseFunction,
+    IFunction,
 )
 from openassistants.functions.utils import AsyncStreamVersion
 
 
 class IndexFunction(BaseFunction):
     type: Literal["IndexFunction"] = "IndexFunction"
-    functions: Callable[[], Awaitable[List[IBaseFunction]]]
+    functions: Callable[[], Awaitable[List[IFunction]]]
 
     async def execute(
         self, deps: FunctionExecutionDependency
diff --git a/packages/openassistants/openassistants/core/assistant.py b/packages/openassistants/openassistants/core/assistant.py
index 0dd37c2..406a23c 100644
--- a/packages/openassistants/openassistants/core/assistant.py
+++ b/packages/openassistants/openassistants/core/assistant.py
@@ -15,10 +15,11 @@
 from openassistants.data_models.function_input import FunctionCall, FunctionInputRequest
 from openassistants.functions.base import (
     FunctionExecutionDependency,
-    IBaseFunction,
     IEntity,
+    IFunction,
+    IFunctionLibrary,
 )
-from openassistants.functions.crud import FunctionCRUD, LocalCRUD, PythonCRUD
+from openassistants.functions.crud import LocalFunctionLibrary, PythonLibrary
 from openassistants.llm_function_calling.entity_resolution import resolve_entities
 from openassistants.llm_function_calling.fallback import perform_general_qa
 from openassistants.llm_function_calling.infilling import (
@@ -37,14 +38,14 @@ class Assistant:
     function_summarization: BaseChatModel
     function_fallback: BaseChatModel
     entity_embedding_model: Embeddings
-    function_libraries: List[FunctionCRUD]
+    function_libraries: List[IFunctionLibrary]
     scope_description: str
 
-    _cached_all_functions: List[IBaseFunction]
+    _cached_all_functions: List[IFunction]
 
     def __init__(
         self,
-        libraries: List[str | FunctionCRUD],
+        libraries: List[str | IFunctionLibrary],
         function_identification: Optional[BaseChatModel] = None,
         function_infilling: Optional[BaseChatModel] = None,
         function_summarization: Optional[BaseChatModel] = None,
@@ -76,12 +77,14 @@ def __init__(
             entity_embedding_model or LangChainCachedEmbeddings(OpenAIEmbeddings())
         )
         self.function_libraries = [
-            library if isinstance(library, FunctionCRUD) else LocalCRUD(library)
+            library
+            if isinstance(library, IFunctionLibrary)
+            else LocalFunctionLibrary(library)
             for library in libraries
         ]
 
         if add_index:
-            index_func: IBaseFunction = IndexFunction(
+            index_func: IFunction = IndexFunction(
                 id="index",
                 display_name="List functions",
                 description=(
@@ -96,19 +99,19 @@ def __init__(
                 functions=self.get_all_functions,
             )
 
-            self.function_libraries.append(PythonCRUD(functions=[index_func]))
+            self.function_libraries.append(PythonLibrary(functions=[index_func]))
 
         self._cached_all_functions = []
 
-    async def get_all_functions(self) -> List[IBaseFunction]:
+    async def get_all_functions(self) -> List[IFunction]:
         if not self._cached_all_functions:
-            functions = []
+            functions: List[IFunction] = []
             for library in self.function_libraries:
-                functions.extend(await library.aread_all())
+                functions.extend(await library.get_all_functions())
             self._cached_all_functions = functions
         return self._cached_all_functions
 
-    async def get_function_by_id(self, function_id: str) -> Optional[IBaseFunction]:
+    async def get_function_by_id(self, function_id: str) -> Optional[IFunction]:
         functions = await self.get_all_functions()
         for function in functions:
             if function.get_id() == function_id:
@@ -117,7 +120,7 @@ async def get_function_by_id(self, function_id: str) -> Optional[IBaseFunction]:
 
     async def execute_function(
         self,
-        function: IBaseFunction,
+        function: IFunction,
         func_args: Dict[str, Any],
         dependencies: Dict[str, Any],
     ):
@@ -143,7 +146,7 @@ async def do_infilling(
         self,
         dependencies: dict,
         message: OpasUserMessage,
-        selected_function: IBaseFunction,
+        selected_function: IFunction,
         args_json_schema: dict,
         entities_info: Dict[str, List[IEntity]],
     ) -> Tuple[bool, dict]:
@@ -199,12 +202,12 @@ async def do_infilling(
     async def handle_user_plaintext(
         self,
         message: OpasUserMessage,
-        all_functions: List[IBaseFunction],
+        all_functions: List[IFunction],
         dependencies: Dict[str, Any],
         autorun: bool,
         force_select_function: Optional[str],
     ) -> AsyncStreamVersion[List[OpasMessage]]:
-        selected_function: Optional[IBaseFunction] = None
+        selected_function: Optional[IFunction] = None
         # perform entity resolution
         chat_history: List[OpasMessage] = dependencies.get("chat_history")  # type: ignore
 
@@ -304,13 +307,13 @@ async def handle_user_plaintext(
     async def handle_user_input(
         self,
         message: OpasUserMessage,
-        all_functions: List[IBaseFunction],
+        all_functions: List[IFunction],
         dependencies: Dict[str, Any],
     ) -> AsyncStreamVersion[List[OpasMessage]]:
         if message.input_response is None:
             raise ValueError("message must have input_response")
 
-        selected_function: Optional[IBaseFunction] = None
+        selected_function: Optional[IFunction] = None
 
         for f in all_functions:
             if f.get_id() == message.input_response.name:
diff --git a/packages/openassistants/openassistants/eval/interaction.py b/packages/openassistants/openassistants/eval/interaction.py
index e8b39a0..8826d41 100644
--- a/packages/openassistants/openassistants/eval/interaction.py
+++ b/packages/openassistants/openassistants/eval/interaction.py
@@ -12,7 +12,7 @@
 )
 from openassistants.data_models.function_input import FunctionCall
 from openassistants.data_models.function_output import DataFrameOutput, TextOutput
-from openassistants.functions.base import IBaseFunction
+from openassistants.functions.base import IFunction
 from openassistants.utils.async_utils import last_value
 from pydantic import BaseModel, ConfigDict
 
@@ -88,7 +88,7 @@ async def run_function_invocation(
     async def get_function(
         self,
         assistant: Assistant,
-    ) -> IBaseFunction:
+    ) -> IFunction:
         base_function = await assistant.get_function_by_id(self.function)
         if base_function is None:
             raise ValueError("Function not found")
@@ -159,7 +159,7 @@ class FunctionInteractionResponseNode(BaseModel):
     user_input_response: OpasUserMessage
     assistant_function_invocation: OpasAssistantMessage
     function_response: OpasFunctionMessage
-    function_spec: IBaseFunction
+    function_spec: IFunction
 
 
 class FunctionInteractionResponse(FunctionInteractionResponseNode):
diff --git a/packages/openassistants/openassistants/functions/base.py b/packages/openassistants/openassistants/functions/base.py
index 2008db9..8d53bc8 100644
--- a/packages/openassistants/openassistants/functions/base.py
+++ b/packages/openassistants/openassistants/functions/base.py
@@ -35,7 +35,7 @@ def get_entities(self) -> Sequence[IEntity]:
         pass
 
 
-class IBaseFunction(abc.ABC):
+class IFunction(abc.ABC):
     @abc.abstractmethod
     def get_id(self) -> str:
         pass
@@ -125,7 +125,7 @@ class BaseFunctionParameters(BaseModel):
     json_schema: JSONSchema = EMPTY_JSON_SCHEMA
 
 
-class BaseFunction(IBaseFunction, BaseModel, abc.ABC):
+class BaseFunction(IFunction, BaseModel, abc.ABC):
     id: str
     type: str
     display_name: Optional[str] = None
@@ -157,3 +157,9 @@ def get_parameters_json_schema(self) -> JSONSchema:
 
     async def get_entity_configs(self) -> Mapping[str, IEntityConfig]:
         return {}
+
+
+class IFunctionLibrary(abc.ABC):
+    @abc.abstractmethod
+    async def get_all_functions(self) -> Sequence[IFunction]:
+        pass
diff --git a/packages/openassistants/openassistants/functions/crud.py b/packages/openassistants/openassistants/functions/crud.py
index 9cd5dc6..e682fb1 100644
--- a/packages/openassistants/openassistants/functions/crud.py
+++ b/packages/openassistants/openassistants/functions/crud.py
@@ -3,7 +3,17 @@
 import json
 from json.decoder import JSONDecodeError
 from pathlib import Path
-from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+    Annotated,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 from langchain.chains.openai_functions.openapi import openapi_spec_to_openai_fn
 from langchain_community.utilities.openapi import OpenAPISpec
@@ -19,7 +29,8 @@
 from openassistants.functions.base import (
     BaseFunction,
     BaseFunctionParameters,
-    IBaseFunction,
+    IFunction,
+    IFunctionLibrary,
 )
 from openassistants.utils import yaml as yaml_utils
 from pydantic import Field, TypeAdapter
@@ -36,27 +47,32 @@
 ]
 
 
-class FunctionCRUD(abc.ABC):
+class BaseFileLibrary(IFunctionLibrary, abc.ABC):
     @abc.abstractmethod
-    def read(self, slug: str) -> Optional[IBaseFunction]:
+    def read(self, slug: str) -> Optional[IFunction]:
         pass
 
     @abc.abstractmethod
     def list_ids(self) -> List[str]:
         pass
 
-    async def aread(self, function_id: str) -> Optional[IBaseFunction]:
+    async def aread(self, function_id: str) -> Optional[IFunction]:
         return await run_in_threadpool(self.read, function_id)
 
     async def alist_ids(self) -> List[str]:
         return await run_in_threadpool(self.list_ids)
 
-    async def aread_all(self) -> List[IBaseFunction]:
+    async def get_all_functions(self) -> Sequence[IFunction]:
         ids = await self.alist_ids()
-        return await asyncio.gather(*[self.aread(f_id) for f_id in ids])  # type: ignore
+        funcs: List[IFunction | None] = await asyncio.gather(  # type: ignore
+            *[self.aread(f_id) for f_id in ids]
+        )
+        if None in funcs:
+            raise RuntimeError("Failed to load all functions")
+        return funcs  # type: ignore
 
 
-class LocalCRUD(FunctionCRUD):
+class LocalFunctionLibrary(BaseFileLibrary):
     def __init__(self, library_id: str, directory: str = "library"):
         self.library_id = library_id
         self.directory = Path(directory) / library_id
@@ -74,32 +90,21 @@ def read(self, function_id: str) -> Optional[BaseFunction]:
         except Exception as e:
             raise RuntimeError(f"Failed to load: {function_id}") from e
 
-    async def aread_all(self) -> List[BaseFunction]:
-        ids = self.list_ids()
-        return [self.read(f_id) for f_id in ids]  # type: ignore
-
     def list_ids(self) -> List[str]:
         return [
             file.stem for file in self.directory.iterdir() if file.suffix == ".yaml"
         ]
 
 
-class PythonCRUD(FunctionCRUD):
-    def __init__(self, functions: List[IBaseFunction]):
+class PythonLibrary(IFunctionLibrary):
+    def __init__(self, functions: Sequence[IFunction]):
         self.functions = functions
 
-    def read(self, slug: str) -> Optional[IBaseFunction]:
-        for function in self.functions:
-            if function.get_id() == slug:
-                return function
-
-        return None
-
-    def list_ids(self) -> List[str]:
-        return [function.get_id() for function in self.functions]
+    async def get_all_functions(self) -> Sequence[IFunction]:
+        return self.functions
 
 
-class OpenAPICRUD(PythonCRUD):
+class OpenAPILibrary(PythonLibrary):
     openapi: OpenAPISpec
 
     @staticmethod
@@ -159,6 +164,6 @@ def __init__(self, spec: Union[OpenAPISpec, str], base_url: Optional[str]):
             self.openapi.servers[0].url = base_url
 
         openai_functions = openapi_spec_to_openai_fn(self.openapi)
-        functions = OpenAPICRUD.openai_fns_to_openapi_function(openai_functions)
+        functions = OpenAPILibrary.openai_fns_to_openapi_function(openai_functions)
 
         super().__init__(functions)
diff --git a/packages/openassistants/openassistants/llm_function_calling/entity_resolution.py b/packages/openassistants/openassistants/llm_function_calling/entity_resolution.py
index 9a043a6..08bc4cf 100644
--- a/packages/openassistants/openassistants/llm_function_calling/entity_resolution.py
+++ b/packages/openassistants/openassistants/llm_function_calling/entity_resolution.py
@@ -7,9 +7,9 @@
 from langchain.vectorstores.usearch import USearch
 from openassistants.data_models.chat_messages import OpasMessage
 from openassistants.functions.base import (
-    IBaseFunction,
     IEntity,
     IEntityConfig,
+    IFunction,
 )
 from openassistants.llm_function_calling.infilling import generate_arguments
 
@@ -64,7 +64,7 @@ async def _get_entities(
 
 
 async def resolve_entities(
-    function: IBaseFunction,
+    function: IFunction,
     function_infilling_llm: BaseChatModel,
     embeddings: Embeddings,
     user_query: str,
diff --git a/packages/openassistants/openassistants/llm_function_calling/infilling.py b/packages/openassistants/openassistants/llm_function_calling/infilling.py
index b499ea5..322086c 100644
--- a/packages/openassistants/openassistants/llm_function_calling/infilling.py
+++ b/packages/openassistants/openassistants/llm_function_calling/infilling.py
@@ -4,14 +4,14 @@
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema.messages import HumanMessage
 from openassistants.data_models.chat_messages import OpasMessage
-from openassistants.functions.base import IBaseFunction, IEntity
+from openassistants.functions.base import IEntity, IFunction
 from openassistants.llm_function_calling.utils import (
     build_chat_history_prompt,
     generate_to_json,
 )
 
 
-async def generate_argument_decisions_schema(function: IBaseFunction):
+async def generate_argument_decisions_schema(function: IFunction):
     # Start with the base schema
     json_schema = function.get_parameters_json_schema()
 
@@ -50,7 +50,7 @@ class NestedObject(TypedDict):
 
 
 async def generate_argument_decisions(
-    function: IBaseFunction,
+    function: IFunction,
     chat: BaseChatModel,
     user_query: str,
     chat_history: List[OpasMessage],
@@ -93,7 +93,7 @@ def entity_to_json_schema_obj(entity: IEntity):
 
 
 async def generate_arguments(
-    function: IBaseFunction,
+    function: IFunction,
     chat: BaseChatModel,
     user_query: str,
     chat_history: List[OpasMessage],
diff --git a/packages/openassistants/openassistants/llm_function_calling/selection.py b/packages/openassistants/openassistants/llm_function_calling/selection.py
index 782d4a0..32e0ccb 100644
--- a/packages/openassistants/openassistants/llm_function_calling/selection.py
+++ b/packages/openassistants/openassistants/llm_function_calling/selection.py
@@ -3,7 +3,7 @@
 
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema.messages import HumanMessage
-from openassistants.functions.base import IBaseFunction
+from openassistants.functions.base import IFunction
 from openassistants.llm_function_calling.utils import (
     chunk_list_by_max_size,
     generate_to_json,
@@ -12,7 +12,7 @@
 
 
 async def filter_functions(
-    chat: BaseChatModel, functions: List[IBaseFunction], user_query: str
+    chat: BaseChatModel, functions: List[IFunction], user_query: str
 ) -> Optional[str]:
     functions_text = "\n".join([f.get_signature() for f in functions])
     json_schema = {
@@ -44,13 +44,13 @@ async def filter_functions(
 
 
 class SelectFunctionResult(BaseModel):
-    function: Optional[InstanceOf[IBaseFunction]] = None
-    suggested_functions: Optional[List[InstanceOf[IBaseFunction]]] = None
+    function: Optional[InstanceOf[IFunction]] = None
+    suggested_functions: Optional[List[InstanceOf[IFunction]]] = None
 
 
 async def select_function(
     chat: BaseChatModel,
-    functions: List[IBaseFunction],
+    functions: List[IFunction],
     user_query: str,
     chunk_size: int = 4,
 ) -> SelectFunctionResult: