Skip to content
This repository has been archived by the owner on Aug 13, 2024. It is now read-only.

Commit

Permalink
move AllFunctionTypes out of openassistant library (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
jordan-wu-97 authored Jan 8, 2024
1 parent 021a14a commit 9812a20
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 36 deletions.
39 changes: 37 additions & 2 deletions examples/fast-api-server/fast_api_server/main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,54 @@
import os
from typing import Annotated

from fastapi import FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from langchain.embeddings import OpenAIEmbeddings
from langchain.storage import RedisStore
from openassistants.contrib.advisor_function import AdvisorFunction
from openassistants.contrib.duckdb_query import DuckDBQueryFunction
from openassistants.contrib.langchain_ddg_tool import DuckDuckGoToolFunction
from openassistants.contrib.python_eval import PythonEvalFunction
from openassistants.contrib.sqlalchemy_query import QueryFunction
from openassistants.contrib.text_response import TextResponseFunction
from openassistants.core.assistant import Assistant
from openassistants.functions.crud import OpenAPILibrary, PythonLibrary
from openassistants.functions.base import IFunction
from openassistants.functions.crud import (
LocalYAMLLibrary,
OpenAPILibrary,
PythonLibrary,
)
from openassistants.utils.langchain_util import LangChainCachedEmbeddings
from openassistants_fastapi import RouteAssistants, create_router
from pydantic import Field, TypeAdapter

from fast_api_server.find_email_by_name_function import find_email_by_name_function

app = FastAPI()


# Specify all the function types that are allowed in the local YAML library
AllFunctionTypes = Annotated[
QueryFunction
| DuckDBQueryFunction
| PythonEvalFunction
| DuckDuckGoToolFunction
| TextResponseFunction
| AdvisorFunction,
Field(json_schema_extra={"discriminator": "type"}),
]


def model_parser(d: dict) -> IFunction:
return TypeAdapter(AllFunctionTypes).validate_python(d) # type: ignore


local_library = LocalYAMLLibrary(
"piedpiper",
model_parser,
"library",
)

# create a library with the custom function
custom_python_lib = PythonLibrary(functions=[find_email_by_name_function])

Expand All @@ -38,7 +73,7 @@


hooli_assistant = Assistant(
libraries=["piedpiper", custom_python_lib, openapi_lib],
libraries=[local_library, custom_python_lib, openapi_lib],
scope_description="""Only answer questions about Hooli company related matters.
You're also allowed to answer questions that refer to anything in the current chat history.""", # noqa: E501
entity_embedding_model=entity_embedding_model,
Expand Down
11 changes: 3 additions & 8 deletions packages/openassistants/openassistants/core/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
IFunction,
IFunctionLibrary,
)
from openassistants.functions.crud import LocalFunctionLibrary, PythonLibrary
from openassistants.functions.crud import PythonLibrary
from openassistants.llm_function_calling.entity_resolution import resolve_entities
from openassistants.llm_function_calling.fallback import perform_general_qa
from openassistants.llm_function_calling.infilling import (
Expand All @@ -45,7 +45,7 @@ class Assistant:

def __init__(
self,
libraries: List[str | IFunctionLibrary],
libraries: List[IFunctionLibrary],
function_identification: Optional[BaseChatModel] = None,
function_infilling: Optional[BaseChatModel] = None,
function_summarization: Optional[BaseChatModel] = None,
Expand Down Expand Up @@ -76,12 +76,7 @@ def __init__(
self.entity_embedding_model = (
entity_embedding_model or LangChainCachedEmbeddings(OpenAIEmbeddings())
)
self.function_libraries = [
library
if isinstance(library, IFunctionLibrary)
else LocalFunctionLibrary(library)
for library in libraries
]
self.function_libraries = libraries

if add_index:
index_func: IFunction = IndexFunction(
Expand Down
38 changes: 12 additions & 26 deletions packages/openassistants/openassistants/functions/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from json.decoder import JSONDecodeError
from pathlib import Path
from typing import (
Annotated,
Any,
Callable,
Dict,
Expand All @@ -17,35 +16,18 @@

from langchain.chains.openai_functions.openapi import openapi_spec_to_openai_fn
from langchain_community.utilities.openapi import OpenAPISpec
from openassistants.contrib.advisor_function import AdvisorFunction
from openassistants.contrib.duckdb_query import DuckDBQueryFunction
from openassistants.contrib.langchain_ddg_tool import DuckDuckGoToolFunction
from openassistants.contrib.python_callable import PythonCallableFunction
from openassistants.contrib.python_eval import PythonEvalFunction
from openassistants.contrib.sqlalchemy_query import QueryFunction
from openassistants.contrib.text_response import TextResponseFunction
from openassistants.data_models.function_output import TextOutput
from openassistants.data_models.json_schema import JSONSchema
from openassistants.functions.base import (
BaseFunction,
BaseFunctionParameters,
IFunction,
IFunctionLibrary,
)
from openassistants.utils import yaml as yaml_utils
from pydantic import Field, TypeAdapter
from pydantic import TypeAdapter
from starlette.concurrency import run_in_threadpool

AllFunctionTypes = Annotated[
QueryFunction
| DuckDBQueryFunction
| PythonEvalFunction
| DuckDuckGoToolFunction
| TextResponseFunction
| AdvisorFunction,
Field(json_schema_extra={"discriminator": "type"}),
]


class BaseFileLibrary(IFunctionLibrary, abc.ABC):
@abc.abstractmethod
Expand All @@ -72,19 +54,23 @@ async def get_all_functions(self) -> Sequence[IFunction]:
return funcs # type: ignore


class LocalFunctionLibrary(BaseFileLibrary):
def __init__(self, library_id: str, directory: str = "library"):
class LocalYAMLLibrary(BaseFileLibrary):
def __init__(
self,
library_id: str,
model_parser: Callable[[dict], IFunction],
directory: str = "library",
):
self.library_id = library_id
self.model_parser = model_parser
self.directory = Path(directory) / library_id

def read(self, function_id: str) -> Optional[BaseFunction]:
def read(self, function_id: str) -> Optional[IFunction]:
try:
if (yaml_file := self.directory / f"{function_id}.yaml").exists():
with yaml_file.open() as f:
parsed_yaml = yaml_utils.load(f)
return TypeAdapter(AllFunctionTypes).validate_python(
parsed_yaml | {"id": function_id}
) # type: ignore
yaml_dict = yaml_utils.load(f)
return self.model_parser(yaml_dict | {"id": function_id})
else:
return None
except Exception as e:
Expand Down

0 comments on commit 9812a20

Please sign in to comment.