Skip to content
This repository has been archived by the owner on Aug 13, 2024. It is now read-only.

Commit

Permalink
rework / clean up JSONSchema (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
jordan-wu-97 authored Dec 20, 2023
1 parent b6e1ec1 commit b76fbca
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ The following Python function:

class PythonEvalFunction(BaseFunction):
type: Literal["PythonEvalFunction"] = "PythonEvalFunction"
parameters: BaseJSONSchema
python_code: str

async def execute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import anyio
import pandas as pd
from openassistants.contrib.python_callable import PythonCallableFunction
from openassistants.data_models.function_input import BaseJSONSchema
from openassistants.data_models.function_output import FunctionOutput, TextOutput
from openassistants.functions.base import (
BaseFunctionParameters,
Entity,
EntityConfig,
FunctionExecutionDependency,
Expand Down Expand Up @@ -64,7 +64,7 @@ async def _get_entity_configs() -> dict[str, EntityConfig]:
"Find the email address for {employee}",
"What is {employee}'s email address?",
],
parameters=BaseJSONSchema(
parameters=BaseFunctionParameters(
json_schema={
"type": "object",
"properties": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,3 @@ async def execute(
output += f"""**{function.get_display_name()}**
{function.get_description()}\n\n"""
yield [TextOutput(text=output)]

def get_parameters_json_schema(self) -> dict:
return {"type": "object", "properties": {}}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
from typing import Dict, List, Literal, Optional, Sequence

from openassistants.data_models.function_input import BaseJSONSchema
from openassistants.data_models.function_output import FunctionOutput, TextOutput
from openassistants.functions.base import BaseFunction, FunctionExecutionDependency
from openassistants.functions.utils import AsyncStreamVersion
Expand Down Expand Up @@ -38,7 +37,6 @@ def ddgs_text(query: str, max_results: Optional[int] = None) -> List[Dict[str, s

class DuckDuckGoToolFunction(BaseFunction):
type: Literal["DuckDuckGoToolFunction"] = "DuckDuckGoToolFunction"
parameters: BaseJSONSchema

async def execute(
self, deps: FunctionExecutionDependency
Expand All @@ -65,6 +63,3 @@ async def execute(
raise RuntimeError(
f"Error while executing action function {self.id}. function raised: {e}"
) from e

def get_parameters_json_schema(self) -> dict:
return self.parameters.json_schema
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Awaitable, Callable, Mapping, Sequence

from openassistants.data_models.function_input import BaseJSONSchema
from openassistants.data_models.function_output import FunctionOutput
from openassistants.functions.base import (
BaseFunction,
Expand All @@ -15,8 +14,6 @@ class PythonCallableFunction(BaseFunction):
[FunctionExecutionDependency], AsyncStreamVersion[Sequence[FunctionOutput]]
]

parameters: BaseJSONSchema

get_entity_configs_callable: Callable[[], Awaitable[Mapping[str, IEntityConfig]]]

async def execute(
Expand All @@ -25,8 +22,5 @@ async def execute(
async for output in self.execute_callable(deps):
yield output

def get_parameters_json_schema(self) -> dict:
return self.parameters.json_schema

async def get_entity_configs(self) -> Mapping[str, IEntityConfig]:
return await self.get_entity_configs_callable()
5 changes: 0 additions & 5 deletions packages/openassistants/openassistants/contrib/python_eval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import inspect
from typing import Any, Callable, Dict, List, Literal, Sequence

from openassistants.data_models.function_input import BaseJSONSchema
from openassistants.data_models.function_output import FunctionOutput
from openassistants.functions.base import (
BaseFunction,
Expand All @@ -13,7 +12,6 @@

class PythonEvalFunction(BaseFunction):
type: Literal["PythonEvalFunction"] = "PythonEvalFunction"
parameters: BaseJSONSchema
python_code: str

async def execute(
Expand Down Expand Up @@ -47,6 +45,3 @@ async def execute(
raise RuntimeError(
f"Error while executing action function {self.id}. function raised: {e}"
) from e

def get_parameters_json_schema(self) -> dict:
return self.parameters.json_schema
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Annotated, Any, List, Literal, Sequence

import jsonschema
import pandas as pd
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from openassistants.data_models.chat_messages import (
Expand All @@ -9,7 +10,6 @@
OpasMessage,
)
from openassistants.data_models.function_input import (
BaseJSONSchema,
FunctionCall,
)
from openassistants.data_models.function_output import (
Expand Down Expand Up @@ -69,7 +69,6 @@ def _opas_to_summarization_lc(

class QueryFunction(BaseFunction):
type: Literal["QueryFunction"] = "QueryFunction"
parameters: BaseJSONSchema
sqls: List[str]
visualizations: List[str]
summarization: str
Expand Down Expand Up @@ -153,7 +152,10 @@ async def execute(
self,
deps: FunctionExecutionDependency,
) -> AsyncStreamVersion[Sequence[FunctionOutput]]:
self.parameters.validate_args(deps.arguments)
try:
jsonschema.validate(deps.arguments, self.get_parameters_json_schema())
except jsonschema.ValidationError as e:
raise ValueError(f"Invalid arguments:\n{str(e)}") from e

results: List[FunctionOutput] = []

Expand Down Expand Up @@ -203,6 +205,3 @@ async def execute(
)

yield results

def get_parameters_json_schema(self) -> dict:
return self.parameters.json_schema
Empty file.
Original file line number Diff line number Diff line change
@@ -1,47 +1,14 @@
from typing import Any, Dict

import jsonschema
from pydantic import BaseModel


class BaseJSONSchema(BaseModel):
"""
Validates a json_schema. top level must of the schema must be type object
"""

json_schema: Dict[str, Any]

def schema_validator(cls, values):
jsonschema.validate(
values["json_schema"], jsonschema.Draft202012Validator.META_SCHEMA
)
jsonschema.validate(
values["json_schema"],
{
"type": "object",
"properties": {
"type": {"type": "string", "enum": ["object"]},
"properties": {"type": "object"},
"required": {"type": "array", "items": {"type": "string"}},
},
"required": [
"type",
],
},
)
return values

def validate_args(self, args: dict):
try:
jsonschema.validate(args, self.json_schema)
except jsonschema.exceptions.ValidationError as e:
raise ValueError(f"invalid function arguments\n{e}")
from openassistants.data_models.json_schema import JSONSchema


class FunctionCall(BaseModel):
name: str
arguments: Dict[str, Any]


class FunctionInputRequest(FunctionCall, BaseJSONSchema):
pass
class FunctionInputRequest(FunctionCall):
json_schema: JSONSchema
40 changes: 40 additions & 0 deletions packages/openassistants/openassistants/data_models/json_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Annotated, Any, Dict

import jsonschema
from pydantic import AfterValidator, TypeAdapter


def _json_schema_meta_validator(value: Any):
try:
jsonschema.validate(value, jsonschema.Draft202012Validator.META_SCHEMA)
except jsonschema.exceptions.ValidationError as e:
raise ValueError(f"Invalid JSONSchema:\n{str(e)}\n") from e
if value.get("type") != "object":
raise ValueError(
f"JSONSchema must have type='object'. Got '{value.get('type')}'"
)
return value


JSONSchema = Annotated[Dict[str, Any], AfterValidator(_json_schema_meta_validator)]
"""
A JSONSchema is a dict that conforms to the JSONSchema specification.
In order to validate an arbitrary dict
```
from pydantic import TypeAdapter
json_schema = TypeAdapter(JSONSchema).validate_python(some_dict)
```
When used in a pydantic model as a field, the JSONSchema will be validated automatically.
""" # noqa: E501


EMPTY_JSON_SCHEMA = TypeAdapter(JSONSchema).validate_python(
{
"type": "object",
"properties": {},
"required": [],
}
)
11 changes: 10 additions & 1 deletion packages/openassistants/openassistants/functions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain_core.language_models import BaseChatModel
from openassistants.data_models.chat_messages import OpasMessage
from openassistants.data_models.function_output import FunctionOutput
from openassistants.data_models.json_schema import EMPTY_JSON_SCHEMA, JSONSchema
from openassistants.functions.utils import AsyncStreamVersion
from openassistants.utils.json_schema import PyRepr
from pydantic import BaseModel
Expand Down Expand Up @@ -56,7 +57,7 @@ def get_sample_questions(self) -> Sequence[str]:
pass

@abc.abstractmethod
def get_parameters_json_schema(self) -> dict:
def get_parameters_json_schema(self) -> JSONSchema:
"""
Get the json schema of the function's parameters
"""
Expand Down Expand Up @@ -120,13 +121,18 @@ def get_entities(self) -> Sequence[IEntity]:
return self.entities


class BaseFunctionParameters(BaseModel):
json_schema: JSONSchema = EMPTY_JSON_SCHEMA


class BaseFunction(IBaseFunction, BaseModel, abc.ABC):
id: str
type: str
display_name: Optional[str] = None
description: str
sample_questions: List[str] = []
confirm: bool = False
parameters: BaseFunctionParameters = BaseFunctionParameters()

def get_id(self) -> str:
return self.id
Expand All @@ -146,5 +152,8 @@ def get_sample_questions(self) -> Sequence[str]:
def get_confirm(self) -> bool:
return self.confirm

def get_parameters_json_schema(self) -> JSONSchema:
return self.parameters.json_schema

async def get_entity_configs(self) -> Mapping[str, IEntityConfig]:
return {}

0 comments on commit b76fbca

Please sign in to comment.