JSON mode support for TypeAdapter and not adding computed_field (#31)

* Fixed JSON mode's model_json_schema invocation, with test

* Added TypeAdapter to output_schema support
jamesbraza authored Jan 10, 2025
1 parent ab9202c commit 02ba26c
Showing 5 changed files with 162 additions and 49 deletions.
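In practice, this commit lets callers pass a Pydantic TypeAdapter as output_type, alongside a BaseModel subclass or a raw JSON schema. A rough usage sketch follows; the MultipleCompletionLLMModel name is inferred from the test class further down, and the constructor and message details are illustrative assumptions rather than lines taken from this diff:

```python
import asyncio

from aviary.core import Message
from pydantic import BaseModel, Field, TypeAdapter

# Class name inferred from TestMultipleCompletionLLMModel in tests/test_llms.py;
# confirm against llmclient.llms before relying on it.
from llmclient.llms import MultipleCompletionLLMModel


class DummyOutputSchema(BaseModel):
    name: str
    age: int = Field(description="Age in years.")


async def main() -> None:
    model = MultipleCompletionLLMModel(name="gpt-4o")  # constructor kwargs are an assumption
    # JSON mode: a TypeAdapter is now accepted as output_type,
    # in addition to a BaseModel subclass or a JSON schema.
    results = await model.call(
        messages=[Message(role="user", content="My name is Claude and I am 1 year old.")],
        output_type=TypeAdapter(DummyOutputSchema),
    )
    print(results)


if __name__ == "__main__":
    asyncio.run(main())
```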
17 changes: 12 additions & 5 deletions llmclient/llms.py
@@ -109,13 +109,15 @@ def sum_logprobs(choice: litellm.utils.Choices) -> float | None:


def validate_json_completion(
completion: litellm.ModelResponse, output_type: type[BaseModel] | JSONSchema
completion: litellm.ModelResponse,
output_type: type[BaseModel] | TypeAdapter | JSONSchema,
) -> None:
"""Validate a completion against a JSON schema.
Args:
completion: The completion to validate.
output_type: A JSON schema or a Pydantic model to validate the completion.
output_type: A Pydantic model, Pydantic type adapter, or a JSON schema to
validate the completion.
"""
try:
for choice in completion.choices:
@@ -131,6 +133,8 @@ def validate_json_completion(
litellm.litellm_core_utils.json_validation_rule.validate_schema(
schema=dict(output_type), response=choice.message.content
)
elif isinstance(output_type, TypeAdapter):
output_type.validate_json(choice.message.content)
else:
output_type.model_validate_json(choice.message.content)
except ValidationError as err:
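As an aside on the new branch above: TypeAdapter.validate_json is the adapter counterpart of BaseModel.model_validate_json, so the raw completion text is validated the same way in either branch, and failures surface as pydantic.ValidationError, which the except clause above catches. A minimal, self-contained Pydantic example (independent of llmclient):

```python
from pydantic import BaseModel, Field, TypeAdapter, ValidationError


class DummyOutputSchema(BaseModel):
    name: str
    age: int = Field(description="Age in years.")


adapter = TypeAdapter(DummyOutputSchema)

# The two calls below validate the same raw JSON string equivalently.
parsed = adapter.validate_json('{"name": "Claude", "age": 1}')
assert parsed == DummyOutputSchema.model_validate_json('{"name": "Claude", "age": 1}')

# Invalid payloads raise pydantic.ValidationError.
try:
    adapter.validate_json('{"name": "Claude"}')  # missing required "age"
except ValidationError as err:
    print(err.error_count(), "validation error(s)")
```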
@@ -737,7 +741,7 @@ async def call( # noqa: C901, PLR0915
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | JSONSchema | None = None,
output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
**chat_kwargs,
@@ -798,7 +802,10 @@ async def call( # noqa: C901, PLR0915
},
}
elif output_type is not None: # Use JSON mode
schema = json.dumps(output_type.model_json_schema(mode="serialization"))
if isinstance(output_type, TypeAdapter):
schema: str = json.dumps(output_type.json_schema())
else:
schema = json.dumps(output_type.model_json_schema())
schema_msg = f"Respond following this JSON schema:\n\n{schema}"
# Get the system prompt and its index, or the index to add it
i, system_prompt = next(
@@ -941,7 +948,7 @@ async def call_single(
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | None = None,
output_type: type[BaseModel] | TypeAdapter | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
**chat_kwargs,
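Beyond the widened type hints, the behavioral fix in llms.py is in the JSON-mode schema construction above: dropping mode="serialization" means model_json_schema() runs in Pydantic's default validation mode, which leaves @computed_field properties out of the prompted schema, and a TypeAdapter supplies its schema via json_schema(). A standalone Pydantic check of that behavior (independent of llmclient):

```python
import json

from pydantic import BaseModel, Field, TypeAdapter, computed_field


class DummyOutputSchema(BaseModel):
    name: str
    age: int = Field(description="Age in years.")

    @computed_field  # type: ignore[prop-decorator]
    @property
    def name_and_age(self) -> str:
        return f"{self.name}, {self.age}"


# Validation-mode schema (the new default): no computed fields in "properties".
validation_schema = DummyOutputSchema.model_json_schema()
assert "name_and_age" not in validation_schema["properties"]

# Serialization-mode schema (what the old code requested): computed fields appear.
serialization_schema = DummyOutputSchema.model_json_schema(mode="serialization")
assert "name_and_age" in serialization_schema["properties"]

# TypeAdapter path: json_schema() also defaults to validation mode.
adapter_schema = TypeAdapter(DummyOutputSchema).json_schema()
assert "name_and_age" not in adapter_schema["properties"]

print(json.dumps(validation_schema, indent=2))
```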
(next changed file: a recorded test cassette)
@@ -2,9 +2,9 @@ interactions:
- request:
body:
'{"messages":[{"role":"system","content":"Respond following this JSON schema:\n\n{\"properties\":
{\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"age\": {\"title\":
\"Age\", \"type\": \"integer\"}}, \"required\": [\"name\", \"age\"], \"title\":
\"DummyOutputSchema\", \"type\": \"object\"}"},{"role":"user","content":"My
{\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"age\": {\"description\":
\"Age in years.\", \"title\": \"Age\", \"type\": \"integer\"}}, \"required\":
[\"name\", \"age\"], \"title\": \"DummyOutputSchema\", \"type\": \"object\"}"},{"role":"user","content":"My
name is Claude and I am 1 year old. What is my name and age?"}],"model":"gpt-3.5-turbo","n":2,"response_format":{"type":"json_object"}}'
headers:
accept:
@@ -14,7 +14,7 @@ interactions:
connection:
- keep-alive
content-length:
- "465"
- "501"
content-type:
- application/json
host:
@@ -44,34 +44,34 @@ interactions:
response:
body:
string: !!binary |
H4sIAAAAAAAAA9xTy27bMBC86yuIPctBFNmwo1sRwDkG8DUqBJpcS0z4CrlqbRj+94LyQwqSAj33
osPMzmCGuzpmjIGSUDEQHSdhvJ79EF0I9LFZvzwf9ofVXPziv9frjTHb3fMj5Enhtm8o6Kq6E854
jaScPdMiICdMrsWyLJer1XyxGAjjJOokaz3NyrvFjPqwdbP74mFxUXZOCYxQsdeMMcaOwzdltBL3
ULH7/IoYjJG3CNVtiDEITicEeIwqErcE+UgKZwntEPtY2wTVYLnBGipWw5PmvcQa8ivF24Epanua
ugTc9ZGnFrbX+oKfbrG0a31w23jhb/hOWRW7JiCPzqYIkZyHbCL+0rX4H7tmjP0cVt1/agQ+OOOp
IfeONhk+lmc7GI9rJMv5hSRHXI948bDMv7FrJBJXOk7eDwQXHcpROh4W76VyE2K6o69pvvM+F1e2
/Rf7kRACPaFsfECpxOfG41jA9O/9bez2yENgiIdIaJqdsi0GH9RwEcMuT9kfAAAA//8DAIF37Xz8
AwAA
H4sIAAAAAAAAAwAAAP//3FOxTsMwFNzzFdabU9QQWlA2BOqAGGABCYIi135NTR3b2C8SUPXfkUPa
BAESM0uGu3enO7+XbcIYKAkFA7HmJBqnJ+eGb+a3D1fX/OahvX6vz28X/v7ybiHMy42BNCrs8hkF
7VVHwjZOIynb08IjJ4yu2Wk+z8+O56cnHdFYiTrKakeT/Gg2odYv7WSaHc965doqgQEK9pgwxti2
+8aMRuIrFGya7pEGQ+A1QnEYYgy81REBHoIKxA1BOpDCGkLTxd6WJkIlGN5gCQUr4ULzVmIJ6Z7i
dcdkpdmNXTyu2sBjC9Nq3eO7Qyxta+ftMvT8AV8po8K68siDNTFCIOsgGYm/dc3+Y9eEsadu1e2X
RuC8bRxVZDdoomE27evDcF0Dm5/0JFnieqTKZ+kPfpVE4kqH0QOC4GKNcpAOl8VbqeyIGC/pe5qf
vD+bK1P/xX4ghEBHKCvnUSrxtfEw5jH+fL+NHV65CwzhLRA21UqZGr3zqjuJbpm75AMAAP//AwCz
swXL/QMAAA==
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8ef8eb547a107af2-SJC
- 8ff052c38a1d2513-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Mon, 09 Dec 2024 23:54:15 GMT
- Thu, 09 Jan 2025 00:31:14 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=0Iu69WUF4xDO0vFgaZtty9xUYWYu0lXQJbhdSA2auis-1733788455-1.0.1.1-C_LcbIbrGSDzIlYAOpyRBM7f4J4jywuThItbkN5OYxTafbo1mZrOsdaTFIpVIdmCcfnMLRLLiZudJUgCmUYZJA;
path=/; expires=Tue, 10-Dec-24 00:24:15 GMT; domain=.api.openai.com; HttpOnly;
- __cf_bm=yMDJIJGoTgvr0XyTnHvqjykpwbTij_VNggbHa0u7mjQ-1736382674-1.0.1.1-1BbGRnMcdB.agM3NFbDHdv.rTrCdpbxCQlFER5FxoM0sEo9eSenHk34Mjks9Mw4MylroAxyHfKngc86iIObF1w;
path=/; expires=Thu, 09-Jan-25 01:01:14 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=xNK6lxi6_xb_YkBT8_ajuNZBOXHEpxhoPSIPlaIH7WU-1733788455477-0.0.1.1-604800000;
- _cfuvid=jJ509aH157xZVGEjAnT.8yM8V_vvXauvl90X2nnE.a8-1736382674852-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
@@ -84,7 +84,7 @@ interactions:
openai-organization:
- future-house-xr4tdh
openai-processing-ms:
- "246"
- "282"
openai-version:
- "2020-10-01"
strict-transport-security:
@@ -96,13 +96,13 @@ interactions:
x-ratelimit-remaining-requests:
- "11999"
x-ratelimit-remaining-tokens:
- "999894"
- "999886"
x-ratelimit-reset-requests:
- 5ms
x-ratelimit-reset-tokens:
- 6ms
x-request-id:
- req_a693eeeb97516e218a0ab89620e52988
- req_32759180c4d54e6ba640880e0a70a889
status:
code: 200
message: OK
(next changed file: a newly added recorded test cassette)
@@ -0,0 +1,103 @@
interactions:
- request:
body:
'{"messages":[{"role":"system","content":"Respond following this JSON schema:\n\n{\"properties\":
{\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"age\": {\"description\":
\"Age in years.\", \"title\": \"Age\", \"type\": \"integer\"}}, \"required\":
[\"name\", \"age\"], \"title\": \"DummyOutputSchema\", \"type\": \"object\"}"},{"role":"user","content":"My
name is Claude and I am 1 year old. What is my name and age?"}],"model":"gpt-4o","n":2,"response_format":{"type":"json_object"}}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "494"
content-type:
- application/json
host:
- api.openai.com
user-agent:
- AsyncOpenAI/Python 1.56.2
x-stainless-arch:
- arm64
x-stainless-async:
- async:asyncio
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.56.2
x-stainless-raw-response:
- "true"
x-stainless-retry-count:
- "1"
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.7
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//3FPBTuMwFLznK6x3TlAS2oTNDSFAHEFaCbRZRa79khoc27IdaUvVf0dO
QxMEK3HmksPMm8mM/byPCAHBoSLAttSz3sjkUtGXUt5J2d48PK5390+vu9XvaytuH153vyAOCr15
RubfVWdM90aiF1odaWaRegyuWXlenF/kRbkeiV5zlEHWGZ+sdJKn+SpJL5K0mIRbLRg6qMifiBBC
9uM3RFQc/0FF0vgd6dE52iFUpyFCwGoZEKDOCeep8hDPJNPKoxpT72tQtMcaKlLDlaQDxxpiUgPt
RjA7LIUW28HRkFsNUk744ZRE6s5YvXETf8JboYTbNhap0yr81XltIFqIP9XLfki9iJC/44UOH0qA
sbo3vvH6BVUwzNKpMcwrNLN5MZFeeyoXqryMv/BrOHoqpFucGTDKtshn6bw/dOBCL4jlvXxO85X3
sblQ3XfsZ4IxNB55YyxywT42nscshhf2v7HTKY+Bwe2cx75pherQGiuOW9CaZlPyYt1mdL2B6BC9
AQAA//8DAOyaJNbtAwAA
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8ff052c9b8e6fa8a-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Thu, 09 Jan 2025 00:31:16 GMT
Server:
- cloudflare
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- future-house-xr4tdh
openai-processing-ms:
- "640"
openai-version:
- "2020-10-01"
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- "10000"
x-ratelimit-limit-tokens:
- "30000000"
x-ratelimit-remaining-requests:
- "9999"
x-ratelimit-remaining-tokens:
- "29999886"
x-ratelimit-reset-requests:
- 6ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_c993df701dab27752a7cafe3ba505f63
status:
code: 200
message: OK
version: 1
(next changed file: a recorded test cassette)
@@ -2,7 +2,8 @@ interactions:
- request:
body:
'{"messages":[{"role":"user","content":"My name is Claude and I am 1 year
old. What is my name and age?"}],"model":"gpt-4o","n":2,"response_format":{"type":"json_schema","json_schema":{"strict":true,"schema":{"properties":{"name":{"title":"Name","type":"string"},"age":{"title":"Age","type":"integer"}},"required":["name","age"],"title":"DummyOutputSchema","type":"object","additionalProperties":false},"name":"DummyOutputSchema"}}}'
old. What is my name and age?"}],"model":"gpt-4o","n":2,"response_format":{"type":"json_schema","json_schema":{"strict":true,"schema":{"properties":{"name":{"title":"Name","type":"string"},"age":{"description":"Age
in years.","title":"Age","type":"integer"}},"required":["name","age"],"title":"DummyOutputSchema","type":"object","additionalProperties":false},"name":"DummyOutputSchema"}}}'
headers:
accept:
- application/json
@@ -11,7 +12,7 @@ interactions:
connection:
- keep-alive
content-length:
- "431"
- "461"
content-type:
- application/json
host:
@@ -31,7 +32,7 @@ interactions:
x-stainless-raw-response:
- "true"
x-stainless-retry-count:
- "0"
- "1"
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
@@ -41,35 +42,29 @@ interactions:
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//3FNNa9wwFLz7V4h3tos368Qb39qeSkO+IIQQF/NWemurkSVFkqFh2f8e
5N2sHZJCz734MPNmPCM9bRPGQAqoGPAOA++tyr7yzjnx7ao76x+6HzfXz7YoLi+W5c/b+7tnSKPC
rH8TD2+qL9z0VlGQRu9p7ggDRddFuVyWq1VxejoSvRGkoqy1IStMdpKfFFm+yvKzg7AzkpOHij0m
jDG2Hb8xohb0ByqWp29IT95jS1AdhxgDZ1REAL2XPqAOkE4kNzqQHlNva9DYUw1VDd8VDoJqSGvA
NkKL3VzlaDN4jKH1oNQB3x1jKNNaZ9b+wB/xjdTSd40j9EbHX/pgLCQz8Ydui/+hW8LYr/Eqh3cN
wDrT29AE80Q6GpbnezuYdmciF6sDGUxANeHnZfqJWyMooFR+dlzAkXckJuW0NzgIaWbE/Eo+hvnM
e99b6vZf7CeCc7KBRGMdCcnfF57GHMWX9bex4xmPgcG/+EB9s5G6JWed3C/Axja85JjTGjlCskte
AQAA//8DAOP4nprlAwAA
H4sIAAAAAAAAAwAAAP//3FM9b9swEN31K4ibrUJSbNXWFnQpOnRLGiAqBJo8SUwpHkFSQQLD/72g
/CEZaYHOXTi8d+/xHe94SBgDJaFiIHoexGB1em/406saNv33r5mn9tG8/nj49liOD/uXew2rqKD9
C4pwUX0SNFiNQZE50cIhDxhd88935d023+W7iRhIoo6yzoZ0TWmRFes026ZZeRb2pAR6qNhzwhhj
h+mMEY3EN6hYtrogA3rPO4TqWsQYONIRAe698oGbAKuZFGQCmin1oQbDB6yhquGL5qPEGlY18C5C
+XGpctiOnsfQZtT6jB+vMTR11tHen/kr3iqjfN845J5MvNIHspAsxB96y/+H3hLGfk6jHG86AOto
sKEJ9AtNNNxuTnYw785MFucxQ6DA9Yzn2UV1Y9dIDFxpv3gvEFz0KGfpvDh8lIoWxHImH9P8yfvU
uDLdv9jPhBBoA8rGOpRK3HY8lzmMX+tvZddHngKDf/cBh6ZVpkNnnTptQGubTVtkZVls1jtIjslv
AAAA//8DAFdJovPmAwAA
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8ef8eb5459995c1d-SJC
- 8ff04054ae4b9e5c-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Mon, 09 Dec 2024 23:54:15 GMT
- Thu, 09 Jan 2025 00:18:40 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=eThvnPKsEWq81agLsGvOWG3bDSV7z.O._JEvK_LDWKo-1733788455-1.0.1.1-4BqFkZTxOoot56DBhhcY4KTbLV6Dn33HJ7X7MttciM9xBNhDaDEIYxZny..80Wu2vmkDSL5bw7MqQa.zI_1D0g;
path=/; expires=Tue, 10-Dec-24 00:24:15 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=B45jEr7m1koJPP4VpyiYJCVjXWCenQuz8H_umfGeEBQ-1733788455794-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
X-Content-Type-Options:
@@ -81,7 +76,7 @@ interactions:
openai-organization:
- future-house-xr4tdh
openai-processing-ms:
- "602"
- "892"
openai-version:
- "2020-10-01"
strict-transport-security:
@@ -99,7 +94,7 @@ interactions:
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_ef98d385b8292c3cd3fe3cb3fd0f31c2
- req_7ff1ffa34df34287ed396ab6176cbd29
status:
code: 200
message: OK
14 changes: 11 additions & 3 deletions tests/test_llms.py
@@ -7,7 +7,7 @@
import numpy as np
import pytest
from aviary.core import Message, Tool, ToolRequestMessage
from pydantic import BaseModel
from pydantic import BaseModel, Field, TypeAdapter, computed_field

from llmclient.exceptions import JSONSchemaValidationError
from llmclient.llms import (
@@ -261,7 +261,12 @@ def test_pickling(self, tmp_path: pathlib.Path) -> None:

class DummyOutputSchema(BaseModel):
name: str
age: int
age: int = Field(description="Age in years.")

@computed_field # type: ignore[prop-decorator]
@property
def name_and_age(self) -> str: # So we can test computed_field is not included
return f"{self.name}, {self.age}"


class TestMultipleCompletionLLMModel:
@@ -364,7 +369,10 @@ def play(move: int | None) -> None:
@pytest.mark.parametrize(
("model_name", "output_type"),
[
pytest.param("gpt-3.5-turbo", DummyOutputSchema, id="json-mode"),
pytest.param("gpt-3.5-turbo", DummyOutputSchema, id="json-mode-base-model"),
pytest.param(
"gpt-4o", TypeAdapter(DummyOutputSchema), id="json-mode-type-adapter"
),
pytest.param(
"gpt-4o", DummyOutputSchema.model_json_schema(), id="structured-outputs"
),
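The new json-mode-type-adapter case above wraps DummyOutputSchema in a TypeAdapter, but the same mechanism covers adapters over plain annotated types, which is a main reason to accept TypeAdapter at all. A small standalone Pydantic illustration, not part of this commit's test suite:

```python
from pydantic import TypeAdapter

# A TypeAdapter can wrap plain annotated types, not just BaseModel subclasses.
names_adapter = TypeAdapter(list[str])

# json_schema() yields the schema that JSON mode embeds in the system prompt.
print(names_adapter.json_schema())  # roughly: {'items': {'type': 'string'}, 'type': 'array'}

# validate_json() parses and validates the raw completion text.
validated = names_adapter.validate_json('["Claude", "Ada"]')
assert validated == ["Claude", "Ada"]
```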
