Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JSON mode support for TypeAdapter and not adding computed_field #31

Merged
merged 2 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions llmclient/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,15 @@ def sum_logprobs(choice: litellm.utils.Choices) -> float | None:


def validate_json_completion(
completion: litellm.ModelResponse, output_type: type[BaseModel] | JSONSchema
completion: litellm.ModelResponse,
output_type: type[BaseModel] | TypeAdapter | JSONSchema,
) -> None:
"""Validate a completion against a JSON schema.

Args:
completion: The completion to validate.
output_type: A JSON schema or a Pydantic model to validate the completion.
output_type: A Pydantic model, Pydantic type adapter, or a JSON schema to
validate the completion.
"""
try:
for choice in completion.choices:
Expand All @@ -131,6 +133,8 @@ def validate_json_completion(
litellm.litellm_core_utils.json_validation_rule.validate_schema(
schema=dict(output_type), response=choice.message.content
)
elif isinstance(output_type, TypeAdapter):
output_type.validate_json(choice.message.content)
else:
output_type.model_validate_json(choice.message.content)
except ValidationError as err:
Expand Down Expand Up @@ -737,7 +741,7 @@ async def call( # noqa: C901, PLR0915
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | JSONSchema | None = None,
output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
**chat_kwargs,
Expand Down Expand Up @@ -798,7 +802,10 @@ async def call( # noqa: C901, PLR0915
},
}
elif output_type is not None: # Use JSON mode
schema = json.dumps(output_type.model_json_schema(mode="serialization"))
if isinstance(output_type, TypeAdapter):
schema: str = json.dumps(output_type.json_schema())
else:
schema = json.dumps(output_type.model_json_schema())
schema_msg = f"Respond following this JSON schema:\n\n{schema}"
# Get the system prompt and its index, or the index to add it
i, system_prompt = next(
Expand Down Expand Up @@ -941,7 +948,7 @@ async def call_single(
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | None = None,
output_type: type[BaseModel] | TypeAdapter | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
**chat_kwargs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ interactions:
- request:
body:
'{"messages":[{"role":"system","content":"Respond following this JSON schema:\n\n{\"properties\":
{\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"age\": {\"title\":
\"Age\", \"type\": \"integer\"}}, \"required\": [\"name\", \"age\"], \"title\":
\"DummyOutputSchema\", \"type\": \"object\"}"},{"role":"user","content":"My
{\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"age\": {\"description\":
\"Age in years.\", \"title\": \"Age\", \"type\": \"integer\"}}, \"required\":
[\"name\", \"age\"], \"title\": \"DummyOutputSchema\", \"type\": \"object\"}"},{"role":"user","content":"My
name is Claude and I am 1 year old. What is my name and age?"}],"model":"gpt-3.5-turbo","n":2,"response_format":{"type":"json_object"}}'
headers:
accept:
Expand All @@ -14,7 +14,7 @@ interactions:
connection:
- keep-alive
content-length:
- "465"
- "501"
content-type:
- application/json
host:
Expand Down Expand Up @@ -44,34 +44,34 @@ interactions:
response:
body:
string: !!binary |
H4sIAAAAAAAAA9xTy27bMBC86yuIPctBFNmwo1sRwDkG8DUqBJpcS0z4CrlqbRj+94LyQwqSAj33
osPMzmCGuzpmjIGSUDEQHSdhvJ79EF0I9LFZvzwf9ofVXPziv9frjTHb3fMj5Enhtm8o6Kq6E854
jaScPdMiICdMrsWyLJer1XyxGAjjJOokaz3NyrvFjPqwdbP74mFxUXZOCYxQsdeMMcaOwzdltBL3
ULH7/IoYjJG3CNVtiDEITicEeIwqErcE+UgKZwntEPtY2wTVYLnBGipWw5PmvcQa8ivF24Epanua
ugTc9ZGnFrbX+oKfbrG0a31w23jhb/hOWRW7JiCPzqYIkZyHbCL+0rX4H7tmjP0cVt1/agQ+OOOp
IfeONhk+lmc7GI9rJMv5hSRHXI948bDMv7FrJBJXOk7eDwQXHcpROh4W76VyE2K6o69pvvM+F1e2
/Rf7kRACPaFsfECpxOfG41jA9O/9bez2yENgiIdIaJqdsi0GH9RwEcMuT9kfAAAA//8DAIF37Xz8
AwAA
H4sIAAAAAAAAAwAAAP//3FOxTsMwFNzzFdabU9QQWlA2BOqAGGABCYIi135NTR3b2C8SUPXfkUPa
BAESM0uGu3enO7+XbcIYKAkFA7HmJBqnJ+eGb+a3D1fX/OahvX6vz28X/v7ybiHMy42BNCrs8hkF
7VVHwjZOIynb08IjJ4yu2Wk+z8+O56cnHdFYiTrKakeT/Gg2odYv7WSaHc965doqgQEK9pgwxti2
+8aMRuIrFGya7pEGQ+A1QnEYYgy81REBHoIKxA1BOpDCGkLTxd6WJkIlGN5gCQUr4ULzVmIJ6Z7i
dcdkpdmNXTyu2sBjC9Nq3eO7Qyxta+ftMvT8AV8po8K68siDNTFCIOsgGYm/dc3+Y9eEsadu1e2X
RuC8bRxVZDdoomE27evDcF0Dm5/0JFnieqTKZ+kPfpVE4kqH0QOC4GKNcpAOl8VbqeyIGC/pe5qf
vD+bK1P/xX4ghEBHKCvnUSrxtfEw5jH+fL+NHV65CwzhLRA21UqZGr3zqjuJbpm75AMAAP//AwCz
swXL/QMAAA==
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8ef8eb547a107af2-SJC
- 8ff052c38a1d2513-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Mon, 09 Dec 2024 23:54:15 GMT
- Thu, 09 Jan 2025 00:31:14 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=0Iu69WUF4xDO0vFgaZtty9xUYWYu0lXQJbhdSA2auis-1733788455-1.0.1.1-C_LcbIbrGSDzIlYAOpyRBM7f4J4jywuThItbkN5OYxTafbo1mZrOsdaTFIpVIdmCcfnMLRLLiZudJUgCmUYZJA;
path=/; expires=Tue, 10-Dec-24 00:24:15 GMT; domain=.api.openai.com; HttpOnly;
- __cf_bm=yMDJIJGoTgvr0XyTnHvqjykpwbTij_VNggbHa0u7mjQ-1736382674-1.0.1.1-1BbGRnMcdB.agM3NFbDHdv.rTrCdpbxCQlFER5FxoM0sEo9eSenHk34Mjks9Mw4MylroAxyHfKngc86iIObF1w;
path=/; expires=Thu, 09-Jan-25 01:01:14 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=xNK6lxi6_xb_YkBT8_ajuNZBOXHEpxhoPSIPlaIH7WU-1733788455477-0.0.1.1-604800000;
- _cfuvid=jJ509aH157xZVGEjAnT.8yM8V_vvXauvl90X2nnE.a8-1736382674852-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
Expand All @@ -84,7 +84,7 @@ interactions:
openai-organization:
- future-house-xr4tdh
openai-processing-ms:
- "246"
- "282"
openai-version:
- "2020-10-01"
strict-transport-security:
Expand All @@ -96,13 +96,13 @@ interactions:
x-ratelimit-remaining-requests:
- "11999"
x-ratelimit-remaining-tokens:
- "999894"
- "999886"
x-ratelimit-reset-requests:
- 5ms
x-ratelimit-reset-tokens:
- 6ms
x-request-id:
- req_a693eeeb97516e218a0ab89620e52988
- req_32759180c4d54e6ba640880e0a70a889
status:
code: 200
message: OK
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
interactions:
- request:
body:
'{"messages":[{"role":"system","content":"Respond following this JSON schema:\n\n{\"properties\":
{\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"age\": {\"description\":
\"Age in years.\", \"title\": \"Age\", \"type\": \"integer\"}}, \"required\":
[\"name\", \"age\"], \"title\": \"DummyOutputSchema\", \"type\": \"object\"}"},{"role":"user","content":"My
name is Claude and I am 1 year old. What is my name and age?"}],"model":"gpt-4o","n":2,"response_format":{"type":"json_object"}}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "494"
content-type:
- application/json
host:
- api.openai.com
user-agent:
- AsyncOpenAI/Python 1.56.2
x-stainless-arch:
- arm64
x-stainless-async:
- async:asyncio
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.56.2
x-stainless-raw-response:
- "true"
x-stainless-retry-count:
- "1"
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.7
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//3FPBTuMwFLznK6x3TlAS2oTNDSFAHEFaCbRZRa79khoc27IdaUvVf0dO
QxMEK3HmksPMm8mM/byPCAHBoSLAttSz3sjkUtGXUt5J2d48PK5390+vu9XvaytuH153vyAOCr15
RubfVWdM90aiF1odaWaRegyuWXlenF/kRbkeiV5zlEHWGZ+sdJKn+SpJL5K0mIRbLRg6qMifiBBC
9uM3RFQc/0FF0vgd6dE52iFUpyFCwGoZEKDOCeep8hDPJNPKoxpT72tQtMcaKlLDlaQDxxpiUgPt
RjA7LIUW28HRkFsNUk744ZRE6s5YvXETf8JboYTbNhap0yr81XltIFqIP9XLfki9iJC/44UOH0qA
sbo3vvH6BVUwzNKpMcwrNLN5MZFeeyoXqryMv/BrOHoqpFucGTDKtshn6bw/dOBCL4jlvXxO85X3
sblQ3XfsZ4IxNB55YyxywT42nscshhf2v7HTKY+Bwe2cx75pherQGiuOW9CaZlPyYt1mdL2B6BC9
AQAA//8DAOyaJNbtAwAA
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8ff052c9b8e6fa8a-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Thu, 09 Jan 2025 00:31:16 GMT
Server:
- cloudflare
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- future-house-xr4tdh
openai-processing-ms:
- "640"
openai-version:
- "2020-10-01"
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- "10000"
x-ratelimit-limit-tokens:
- "30000000"
x-ratelimit-remaining-requests:
- "9999"
x-ratelimit-remaining-tokens:
- "29999886"
x-ratelimit-reset-requests:
- 6ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_c993df701dab27752a7cafe3ba505f63
status:
code: 200
message: OK
version: 1
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ interactions:
- request:
body:
'{"messages":[{"role":"user","content":"My name is Claude and I am 1 year
old. What is my name and age?"}],"model":"gpt-4o","n":2,"response_format":{"type":"json_schema","json_schema":{"strict":true,"schema":{"properties":{"name":{"title":"Name","type":"string"},"age":{"title":"Age","type":"integer"}},"required":["name","age"],"title":"DummyOutputSchema","type":"object","additionalProperties":false},"name":"DummyOutputSchema"}}}'
old. What is my name and age?"}],"model":"gpt-4o","n":2,"response_format":{"type":"json_schema","json_schema":{"strict":true,"schema":{"properties":{"name":{"title":"Name","type":"string"},"age":{"description":"Age
in years.","title":"Age","type":"integer"}},"required":["name","age"],"title":"DummyOutputSchema","type":"object","additionalProperties":false},"name":"DummyOutputSchema"}}}'
headers:
accept:
- application/json
Expand All @@ -11,7 +12,7 @@ interactions:
connection:
- keep-alive
content-length:
- "431"
- "461"
content-type:
- application/json
host:
Expand All @@ -31,7 +32,7 @@ interactions:
x-stainless-raw-response:
- "true"
x-stainless-retry-count:
- "0"
- "1"
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
Expand All @@ -41,35 +42,29 @@ interactions:
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//3FNNa9wwFLz7V4h3tos368Qb39qeSkO+IIQQF/NWemurkSVFkqFh2f8e
5N2sHZJCz734MPNmPCM9bRPGQAqoGPAOA++tyr7yzjnx7ao76x+6HzfXz7YoLi+W5c/b+7tnSKPC
rH8TD2+qL9z0VlGQRu9p7ggDRddFuVyWq1VxejoSvRGkoqy1IStMdpKfFFm+yvKzg7AzkpOHij0m
jDG2Hb8xohb0ByqWp29IT95jS1AdhxgDZ1REAL2XPqAOkE4kNzqQHlNva9DYUw1VDd8VDoJqSGvA
NkKL3VzlaDN4jKH1oNQB3x1jKNNaZ9b+wB/xjdTSd40j9EbHX/pgLCQz8Ydui/+hW8LYr/Eqh3cN
wDrT29AE80Q6GpbnezuYdmciF6sDGUxANeHnZfqJWyMooFR+dlzAkXckJuW0NzgIaWbE/Eo+hvnM
e99b6vZf7CeCc7KBRGMdCcnfF57GHMWX9bex4xmPgcG/+EB9s5G6JWed3C/Axja85JjTGjlCskte
AQAA//8DAOP4nprlAwAA
H4sIAAAAAAAAAwAAAP//3FM9b9swEN31K4ibrUJSbNXWFnQpOnRLGiAqBJo8SUwpHkFSQQLD/72g
/CEZaYHOXTi8d+/xHe94SBgDJaFiIHoexGB1em/406saNv33r5mn9tG8/nj49liOD/uXew2rqKD9
C4pwUX0SNFiNQZE50cIhDxhd88935d023+W7iRhIoo6yzoZ0TWmRFes026ZZeRb2pAR6qNhzwhhj
h+mMEY3EN6hYtrogA3rPO4TqWsQYONIRAe698oGbAKuZFGQCmin1oQbDB6yhquGL5qPEGlY18C5C
+XGpctiOnsfQZtT6jB+vMTR11tHen/kr3iqjfN845J5MvNIHspAsxB96y/+H3hLGfk6jHG86AOto
sKEJ9AtNNNxuTnYw785MFucxQ6DA9Yzn2UV1Y9dIDFxpv3gvEFz0KGfpvDh8lIoWxHImH9P8yfvU
uDLdv9jPhBBoA8rGOpRK3HY8lzmMX+tvZddHngKDf/cBh6ZVpkNnnTptQGubTVtkZVls1jtIjslv
AAAA//8DAFdJovPmAwAA
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8ef8eb5459995c1d-SJC
- 8ff04054ae4b9e5c-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Mon, 09 Dec 2024 23:54:15 GMT
- Thu, 09 Jan 2025 00:18:40 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=eThvnPKsEWq81agLsGvOWG3bDSV7z.O._JEvK_LDWKo-1733788455-1.0.1.1-4BqFkZTxOoot56DBhhcY4KTbLV6Dn33HJ7X7MttciM9xBNhDaDEIYxZny..80Wu2vmkDSL5bw7MqQa.zI_1D0g;
path=/; expires=Tue, 10-Dec-24 00:24:15 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=B45jEr7m1koJPP4VpyiYJCVjXWCenQuz8H_umfGeEBQ-1733788455794-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
X-Content-Type-Options:
Expand All @@ -81,7 +76,7 @@ interactions:
openai-organization:
- future-house-xr4tdh
openai-processing-ms:
- "602"
- "892"
openai-version:
- "2020-10-01"
strict-transport-security:
Expand All @@ -99,7 +94,7 @@ interactions:
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_ef98d385b8292c3cd3fe3cb3fd0f31c2
- req_7ff1ffa34df34287ed396ab6176cbd29
status:
code: 200
message: OK
Expand Down
14 changes: 11 additions & 3 deletions tests/test_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import pytest
from aviary.core import Message, Tool, ToolRequestMessage
from pydantic import BaseModel
from pydantic import BaseModel, Field, TypeAdapter, computed_field

from llmclient.exceptions import JSONSchemaValidationError
from llmclient.llms import (
Expand Down Expand Up @@ -261,7 +261,12 @@ def test_pickling(self, tmp_path: pathlib.Path) -> None:

class DummyOutputSchema(BaseModel):
name: str
age: int
age: int = Field(description="Age in years.")

@computed_field # type: ignore[prop-decorator]
@property
def name_and_age(self) -> str: # So we can test computed_field is not included
return f"{self.name}, {self.age}"


class TestMultipleCompletionLLMModel:
Expand Down Expand Up @@ -364,7 +369,10 @@ def play(move: int | None) -> None:
@pytest.mark.parametrize(
("model_name", "output_type"),
[
pytest.param("gpt-3.5-turbo", DummyOutputSchema, id="json-mode"),
pytest.param("gpt-3.5-turbo", DummyOutputSchema, id="json-mode-base-model"),
pytest.param(
"gpt-4o", TypeAdapter(DummyOutputSchema), id="json-mode-type-adapter"
),
pytest.param(
"gpt-4o", DummyOutputSchema.model_json_schema(), id="structured-outputs"
),
Expand Down
Loading