-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathschema_generator.py
127 lines (101 loc) · 4.73 KB
/
schema_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import inspect
import typing
from typing import get_args, get_origin
from docstring_parser import parse
from pydantic import BaseModel
def is_optional(annotation):
# Check if the annotation is a Union
if getattr(annotation, "__origin__", None) is typing.Union:
# Check if None is one of the options in the Union
return type(None) in annotation.__args__
return False
def optional_length(annotation):
    """Count the members of an Optional/Union annotation other than ``None``.

    Args:
        annotation: An annotation for which ``is_optional`` returns True.

    Returns:
        int: The number of union members excluding ``NoneType``.

    Raises:
        ValueError: If ``annotation`` is not an Optional type.
    """
    if not is_optional(annotation):
        raise ValueError("The annotation is not an Optional type")
    # NoneType is guaranteed to be one of the members here; leave it out.
    member_count = len(annotation.__args__)
    return member_count - 1
def type_to_json_schema_type(py_type):
    """Map a Python type annotation to a JSON schema ``type`` string.

    ``Optional[X]`` (``Union[X, None]``, ``Union[None, X]``, or ``X | None``
    where ``is_optional`` recognizes it) is unwrapped to ``X``. Supported
    targets are ``int`` -> "integer", ``str`` -> "string", ``bool`` ->
    "boolean", ``float`` -> "number", and a list of strings (``list[str]`` or
    ``typing.List[str]``) -> "array".

    Args:
        py_type: The annotation to convert.

    Returns:
        str: The JSON schema type name.

    Raises:
        ValueError: If an Optional wraps more than one non-None type, or if the
            type has no JSON schema equivalent in the table above.
    """
    if is_optional(py_type):
        # Select the non-None member explicitly instead of indexing [0]:
        # ``Union[None, int]`` / ``None | int`` list NoneType first.
        inner_types = [arg for arg in get_args(py_type) if arg is not type(None)]
        # Raise rather than assert so the check is not stripped under ``python -O``.
        if len(inner_types) != 1:
            raise ValueError(f"Optional type must have exactly one type argument, but got {py_type}")
        return type_to_json_schema_type(inner_types[0])

    # ``typing.List[str]`` neither equals nor hashes like ``list[str]``, so a
    # dict lookup would miss it; compare origin/args to accept both spellings.
    if get_origin(py_type) is list and get_args(py_type) == (str,):
        return "array"

    # Mapping of scalar Python types to JSON schema types.
    type_map = {
        int: "integer",
        str: "string",
        bool: "boolean",
        float: "number",
    }
    if py_type not in type_map:
        raise ValueError(f"Python type {py_type} has no corresponding JSON schema type")
    return type_map[py_type]
def pydantic_model_to_open_ai(model):
    """Convert a Pydantic model class into an OpenAI function-style schema dict.

    Property descriptions missing from ``model.model_json_schema()`` are filled
    in from the parameter entries of the model's docstring. ``required`` is
    recomputed as the sorted names of properties whose schema has no
    ``default`` key.

    Args:
        model: A Pydantic model class; its ``model_json_schema()`` is used.

    Returns:
        dict: ``{"name": ..., "description": ..., "parameters": ...}``, where
        ``parameters`` is the model's JSON schema minus its top-level ``title``
        and ``description`` keys.

    Raises:
        ValueError: If neither the JSON schema nor the model docstring
            provides a description.
    """
    schema = model.model_json_schema()
    # Classes without a docstring have ``__doc__ is None``.
    docstring = parse(model.__doc__ or "")
    parameters = {k: v for k, v in schema.items() if k not in ("title", "description")}
    properties = parameters.setdefault("properties", {})

    for param in docstring.params:
        name = param.arg_name
        description = param.description
        # Only fill gaps: a description already present in the JSON schema wins.
        if name in properties and description and "description" not in properties[name]:
            properties[name]["description"] = description

    # NOTE(review): fields whose default comes from a factory may not carry a
    # "default" key in the JSON schema and would then be listed as required —
    # confirm against the Pydantic version in use.
    parameters["required"] = sorted(name for name, prop in properties.items() if "default" not in prop)

    if "description" not in schema:
        if not docstring.short_description:
            # The original used a bare ``raise`` here, which fails with
            # "No active exception to reraise" instead of a useful message.
            raise ValueError(
                f"Pydantic model '{schema.get('title', model)}' has no description; "
                "add a docstring summary or a model-level description"
            )
        schema["description"] = docstring.short_description

    return {
        "name": schema["title"],
        "description": schema["description"],
        "parameters": parameters,
    }
def generate_schema(function):
    """Build an OpenAI function-calling schema from a Python function.

    The function's docstring supplies the top-level description and one
    description per parameter. Type annotations supply the JSON schema types;
    parameters annotated with a Pydantic ``BaseModel`` subclass are converted
    with ``pydantic_model_to_open_ai``.

    Args:
        function: The callable to describe. Every parameter except ``self``
            must be annotated and documented in the docstring.

    Returns:
        dict: ``{"name", "description", "parameters"}`` where ``parameters`` is
        ``{"type": "object", "properties": {...}, "required": [...]}``.
        A parameter is required when it has no default value.

    Raises:
        TypeError: If a parameter lacks a type annotation.
        ValueError: If a parameter lacks a docstring description, or its
            annotation has no JSON schema equivalent.
    """
    sig = inspect.signature(function)
    # ``__doc__`` is None for undocumented functions; parse an empty string instead.
    docstring = parse(function.__doc__ or "")

    schema = {
        "name": function.__name__,
        "description": docstring.short_description,
        "parameters": {"type": "object", "properties": {}, "required": []},
    }
    properties = schema["parameters"]["properties"]
    required = schema["parameters"]["required"]

    for param in sig.parameters.values():
        # Exclude the bound-instance parameter of methods.
        if param.name == "self":
            continue

        annotation = param.annotation
        if annotation is inspect.Parameter.empty:
            raise TypeError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a type annotation")

        param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
        if not param_doc or not param_doc.description:
            raise ValueError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a description in the docstring")

        if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
            properties[param.name] = pydantic_model_to_open_ai(annotation)
        else:
            properties[param.name] = {
                "type": type_to_json_schema_type(annotation),
                "description": param_doc.description,
            }

        # ``is`` avoids invoking a default value's custom ``__eq__``.
        if param.default is inspect.Parameter.empty:
            required.append(param.name)

        # Unwrap Optional[...] so Optional[list[str]] also gets an "items" entry.
        list_annotation = annotation
        if is_optional(list_annotation):
            non_none = [arg for arg in get_args(list_annotation) if arg is not type(None)]
            if len(non_none) == 1:
                list_annotation = non_none[0]
        if get_origin(list_annotation) is list:
            item_args = get_args(list_annotation)
            # Bare ``typing.List`` has origin ``list`` but no type arguments.
            if item_args and item_args[0] is str:
                properties[param.name]["items"] = {"type": "string"}

    return schema