diff --git a/llmclient/llms.py b/llmclient/llms.py index 65d7ffa..79d89fe 100644 --- a/llmclient/llms.py +++ b/llmclient/llms.py @@ -109,13 +109,15 @@ def sum_logprobs(choice: litellm.utils.Choices) -> float | None: def validate_json_completion( - completion: litellm.ModelResponse, output_type: type[BaseModel] | JSONSchema + completion: litellm.ModelResponse, + output_type: type[BaseModel] | TypeAdapter | JSONSchema, ) -> None: """Validate a completion against a JSON schema. Args: completion: The completion to validate. - output_type: A JSON schema or a Pydantic model to validate the completion. + output_type: A Pydantic model, Pydantic type adapter, or a JSON schema to + validate the completion. """ try: for choice in completion.choices: @@ -131,6 +133,8 @@ def validate_json_completion( litellm.litellm_core_utils.json_validation_rule.validate_schema( schema=dict(output_type), response=choice.message.content ) + elif isinstance(output_type, TypeAdapter): + output_type.validate_json(choice.message.content) else: output_type.model_validate_json(choice.message.content) except ValidationError as err: @@ -737,7 +741,7 @@ async def call( # noqa: C901, PLR0915 self, messages: list[Message], callbacks: list[Callable] | None = None, - output_type: type[BaseModel] | JSONSchema | None = None, + output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = None, tools: list[Tool] | None = None, tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED, **chat_kwargs, @@ -798,7 +802,10 @@ async def call( # noqa: C901, PLR0915 }, } elif output_type is not None: # Use JSON mode - schema = json.dumps(output_type.model_json_schema(mode="serialization")) + if isinstance(output_type, TypeAdapter): + schema: str = json.dumps(output_type.json_schema()) + else: + schema = json.dumps(output_type.model_json_schema()) schema_msg = f"Respond following this JSON schema:\n\n{schema}" # Get the system prompt and its index, or the index to add it i, system_prompt = next( @@ -941,7 +948,7 @@ async def call_single( self, messages: list[Message], callbacks: list[Callable] | None = None, - output_type: type[BaseModel] | None = None, + output_type: type[BaseModel] | TypeAdapter | None = None, tools: list[Tool] | None = None, tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED, **chat_kwargs, diff --git a/tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[json-mode].yaml b/tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[json-mode-base-model].yaml similarity index 64% rename from tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[json-mode].yaml rename to tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[json-mode-base-model].yaml index e209bb5..d0c731f 100644 --- a/tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[json-mode].yaml +++ b/tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[json-mode-base-model].yaml @@ -2,9 +2,9 @@ interactions: - request: body: '{"messages":[{"role":"system","content":"Respond following this JSON schema:\n\n{\"properties\": - {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"age\": {\"title\": - \"Age\", \"type\": \"integer\"}}, \"required\": [\"name\", \"age\"], \"title\": - \"DummyOutputSchema\", \"type\": \"object\"}"},{"role":"user","content":"My + {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"age\": {\"description\": + \"Age in years.\", \"title\": \"Age\", \"type\": \"integer\"}}, \"required\": + [\"name\", \"age\"], \"title\": \"DummyOutputSchema\", \"type\": \"object\"}"},{"role":"user","content":"My name is Claude and I am 1 year old. What is my name and age?"}],"model":"gpt-3.5-turbo","n":2,"response_format":{"type":"json_object"}}' headers: accept: @@ -14,7 +14,7 @@ interactions: connection: - keep-alive content-length: - - "465" + - "501" content-type: - application/json host: @@ -44,19 +44,19 @@ interactions: response: body: string: !!binary | - H4sIAAAAAAAAA9xTy27bMBC86yuIPctBFNmwo1sRwDkG8DUqBJpcS0z4CrlqbRj+94LyQwqSAj33 - osPMzmCGuzpmjIGSUDEQHSdhvJ79EF0I9LFZvzwf9ofVXPziv9frjTHb3fMj5Enhtm8o6Kq6E854 - jaScPdMiICdMrsWyLJer1XyxGAjjJOokaz3NyrvFjPqwdbP74mFxUXZOCYxQsdeMMcaOwzdltBL3 - ULH7/IoYjJG3CNVtiDEITicEeIwqErcE+UgKZwntEPtY2wTVYLnBGipWw5PmvcQa8ivF24Epanua - ugTc9ZGnFrbX+oKfbrG0a31w23jhb/hOWRW7JiCPzqYIkZyHbCL+0rX4H7tmjP0cVt1/agQ+OOOp - IfeONhk+lmc7GI9rJMv5hSRHXI948bDMv7FrJBJXOk7eDwQXHcpROh4W76VyE2K6o69pvvM+F1e2 - /Rf7kRACPaFsfECpxOfG41jA9O/9bez2yENgiIdIaJqdsi0GH9RwEcMuT9kfAAAA//8DAIF37Xz8 - AwAA + H4sIAAAAAAAAAwAAAP//3FOxTsMwFNzzFdabU9QQWlA2BOqAGGABCYIi135NTR3b2C8SUPXfkUPa + BAESM0uGu3enO7+XbcIYKAkFA7HmJBqnJ+eGb+a3D1fX/OahvX6vz28X/v7ybiHMy42BNCrs8hkF + 7VVHwjZOIynb08IjJ4yu2Wk+z8+O56cnHdFYiTrKakeT/Gg2odYv7WSaHc965doqgQEK9pgwxti2 + +8aMRuIrFGya7pEGQ+A1QnEYYgy81REBHoIKxA1BOpDCGkLTxd6WJkIlGN5gCQUr4ULzVmIJ6Z7i + dcdkpdmNXTyu2sBjC9Nq3eO7Qyxta+ftMvT8AV8po8K68siDNTFCIOsgGYm/dc3+Y9eEsadu1e2X + RuC8bRxVZDdoomE27evDcF0Dm5/0JFnieqTKZ+kPfpVE4kqH0QOC4GKNcpAOl8VbqeyIGC/pe5qf + vD+bK1P/xX4ghEBHKCvnUSrxtfEw5jH+fL+NHV65CwzhLRA21UqZGr3zqjuJbpm75AMAAP//AwCz + swXL/QMAAA== headers: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8ef8eb547a107af2-SJC + - 8ff052c38a1d2513-SJC Connection: - keep-alive Content-Encoding: @@ -64,14 +64,14 @@ interactions: Content-Type: - application/json Date: - - Mon, 09 Dec 2024 23:54:15 GMT + - Thu, 09 Jan 2025 00:31:14 GMT Server: - cloudflare Set-Cookie: - - __cf_bm=0Iu69WUF4xDO0vFgaZtty9xUYWYu0lXQJbhdSA2auis-1733788455-1.0.1.1-C_LcbIbrGSDzIlYAOpyRBM7f4J4jywuThItbkN5OYxTafbo1mZrOsdaTFIpVIdmCcfnMLRLLiZudJUgCmUYZJA; - path=/; expires=Tue, 10-Dec-24 00:24:15 GMT; domain=.api.openai.com; HttpOnly; + - __cf_bm=yMDJIJGoTgvr0XyTnHvqjykpwbTij_VNggbHa0u7mjQ-1736382674-1.0.1.1-1BbGRnMcdB.agM3NFbDHdv.rTrCdpbxCQlFER5FxoM0sEo9eSenHk34Mjks9Mw4MylroAxyHfKngc86iIObF1w; + path=/; expires=Thu, 09-Jan-25 01:01:14 GMT; domain=.api.openai.com; HttpOnly; Secure; SameSite=None - - _cfuvid=xNK6lxi6_xb_YkBT8_ajuNZBOXHEpxhoPSIPlaIH7WU-1733788455477-0.0.1.1-604800000; + - _cfuvid=jJ509aH157xZVGEjAnT.8yM8V_vvXauvl90X2nnE.a8-1736382674852-0.0.1.1-604800000; path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None Transfer-Encoding: - chunked @@ -84,7 +84,7 @@ interactions: openai-organization: - future-house-xr4tdh openai-processing-ms: - - "246" + - "282" openai-version: - "2020-10-01" strict-transport-security: @@ -96,13 +96,13 @@ interactions: x-ratelimit-remaining-requests: - "11999" x-ratelimit-remaining-tokens: - - "999894" + - "999886" x-ratelimit-reset-requests: - 5ms x-ratelimit-reset-tokens: - 6ms x-request-id: - - req_a693eeeb97516e218a0ab89620e52988 + - req_32759180c4d54e6ba640880e0a70a889 status: code: 200 message: OK diff --git a/tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[json-mode-type-adapter].yaml b/tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[json-mode-type-adapter].yaml new file mode 100644 index 0000000..ed39207 --- /dev/null +++ b/tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[json-mode-type-adapter].yaml @@ -0,0 +1,103 @@ +interactions: + - request: + body: + '{"messages":[{"role":"system","content":"Respond following this JSON schema:\n\n{\"properties\": + {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"age\": {\"description\": + \"Age in years.\", \"title\": \"Age\", \"type\": \"integer\"}}, \"required\": + [\"name\", \"age\"], \"title\": \"DummyOutputSchema\", \"type\": \"object\"}"},{"role":"user","content":"My + name is Claude and I am 1 year old. What is my name and age?"}],"model":"gpt-4o","n":2,"response_format":{"type":"json_object"}}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - "494" + content-type: + - application/json + host: + - api.openai.com + user-agent: + - AsyncOpenAI/Python 1.56.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - async:asyncio + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.56.2 + x-stainless-raw-response: + - "true" + x-stainless-retry-count: + - "1" + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.12.7 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//3FPBTuMwFLznK6x3TlAS2oTNDSFAHEFaCbRZRa79khoc27IdaUvVf0dO + QxMEK3HmksPMm8mM/byPCAHBoSLAttSz3sjkUtGXUt5J2d48PK5390+vu9XvaytuH153vyAOCr15 + RubfVWdM90aiF1odaWaRegyuWXlenF/kRbkeiV5zlEHWGZ+sdJKn+SpJL5K0mIRbLRg6qMifiBBC + 9uM3RFQc/0FF0vgd6dE52iFUpyFCwGoZEKDOCeep8hDPJNPKoxpT72tQtMcaKlLDlaQDxxpiUgPt + RjA7LIUW28HRkFsNUk744ZRE6s5YvXETf8JboYTbNhap0yr81XltIFqIP9XLfki9iJC/44UOH0qA + sbo3vvH6BVUwzNKpMcwrNLN5MZFeeyoXqryMv/BrOHoqpFucGTDKtshn6bw/dOBCL4jlvXxO85X3 + sblQ3XfsZ4IxNB55YyxywT42nscshhf2v7HTKY+Bwe2cx75pherQGiuOW9CaZlPyYt1mdL2B6BC9 + AQAA//8DAOyaJNbtAwAA + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8ff052c9b8e6fa8a-SJC + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Thu, 09 Jan 2025 00:31:16 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - future-house-xr4tdh + openai-processing-ms: + - "640" + openai-version: + - "2020-10-01" + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - "10000" + x-ratelimit-limit-tokens: + - "30000000" + x-ratelimit-remaining-requests: + - "9999" + x-ratelimit-remaining-tokens: + - "29999886" + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_c993df701dab27752a7cafe3ba505f63 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[structured-outputs].yaml b/tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[structured-outputs].yaml index d04338a..6b2a679 100644 --- a/tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[structured-outputs].yaml +++ b/tests/cassettes/TestMultipleCompletionLLMModel.test_output_schema[structured-outputs].yaml @@ -2,7 +2,8 @@ interactions: - request: body: '{"messages":[{"role":"user","content":"My name is Claude and I am 1 year - old. What is my name and age?"}],"model":"gpt-4o","n":2,"response_format":{"type":"json_schema","json_schema":{"strict":true,"schema":{"properties":{"name":{"title":"Name","type":"string"},"age":{"title":"Age","type":"integer"}},"required":["name","age"],"title":"DummyOutputSchema","type":"object","additionalProperties":false},"name":"DummyOutputSchema"}}}' + old. What is my name and age?"}],"model":"gpt-4o","n":2,"response_format":{"type":"json_schema","json_schema":{"strict":true,"schema":{"properties":{"name":{"title":"Name","type":"string"},"age":{"description":"Age + in years.","title":"Age","type":"integer"}},"required":["name","age"],"title":"DummyOutputSchema","type":"object","additionalProperties":false},"name":"DummyOutputSchema"}}}' headers: accept: - application/json @@ -11,7 +12,7 @@ interactions: connection: - keep-alive content-length: - - "431" + - "461" content-type: - application/json host: @@ -31,7 +32,7 @@ interactions: x-stainless-raw-response: - "true" x-stainless-retry-count: - - "0" + - "1" x-stainless-runtime: - CPython x-stainless-runtime-version: @@ -41,19 +42,19 @@ interactions: response: body: string: !!binary | - H4sIAAAAAAAAAwAAAP//3FNNa9wwFLz7V4h3tos368Qb39qeSkO+IIQQF/NWemurkSVFkqFh2f8e - 5N2sHZJCz734MPNmPCM9bRPGQAqoGPAOA++tyr7yzjnx7ao76x+6HzfXz7YoLi+W5c/b+7tnSKPC - rH8TD2+qL9z0VlGQRu9p7ggDRddFuVyWq1VxejoSvRGkoqy1IStMdpKfFFm+yvKzg7AzkpOHij0m - jDG2Hb8xohb0ByqWp29IT95jS1AdhxgDZ1REAL2XPqAOkE4kNzqQHlNva9DYUw1VDd8VDoJqSGvA - NkKL3VzlaDN4jKH1oNQB3x1jKNNaZ9b+wB/xjdTSd40j9EbHX/pgLCQz8Ydui/+hW8LYr/Eqh3cN - wDrT29AE80Q6GpbnezuYdmciF6sDGUxANeHnZfqJWyMooFR+dlzAkXckJuW0NzgIaWbE/Eo+hvnM - e99b6vZf7CeCc7KBRGMdCcnfF57GHMWX9bex4xmPgcG/+EB9s5G6JWed3C/Axja85JjTGjlCskte - AQAA//8DAOP4nprlAwAA + H4sIAAAAAAAAAwAAAP//3FM9b9swEN31K4ibrUJSbNXWFnQpOnRLGiAqBJo8SUwpHkFSQQLD/72g + /CEZaYHOXTi8d+/xHe94SBgDJaFiIHoexGB1em/406saNv33r5mn9tG8/nj49liOD/uXew2rqKD9 + C4pwUX0SNFiNQZE50cIhDxhd88935d023+W7iRhIoo6yzoZ0TWmRFes026ZZeRb2pAR6qNhzwhhj + h+mMEY3EN6hYtrogA3rPO4TqWsQYONIRAe698oGbAKuZFGQCmin1oQbDB6yhquGL5qPEGlY18C5C + +XGpctiOnsfQZtT6jB+vMTR11tHen/kr3iqjfN845J5MvNIHspAsxB96y/+H3hLGfk6jHG86AOto + sKEJ9AtNNNxuTnYw785MFucxQ6DA9Yzn2UV1Y9dIDFxpv3gvEFz0KGfpvDh8lIoWxHImH9P8yfvU + uDLdv9jPhBBoA8rGOpRK3HY8lzmMX+tvZddHngKDf/cBh6ZVpkNnnTptQGubTVtkZVls1jtIjslv + AAAA//8DAFdJovPmAwAA headers: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8ef8eb5459995c1d-SJC + - 8ff04054ae4b9e5c-SJC Connection: - keep-alive Content-Encoding: @@ -61,15 +62,9 @@ interactions: Content-Type: - application/json Date: - - Mon, 09 Dec 2024 23:54:15 GMT + - Thu, 09 Jan 2025 00:18:40 GMT Server: - cloudflare - Set-Cookie: - - __cf_bm=eThvnPKsEWq81agLsGvOWG3bDSV7z.O._JEvK_LDWKo-1733788455-1.0.1.1-4BqFkZTxOoot56DBhhcY4KTbLV6Dn33HJ7X7MttciM9xBNhDaDEIYxZny..80Wu2vmkDSL5bw7MqQa.zI_1D0g; - path=/; expires=Tue, 10-Dec-24 00:24:15 GMT; domain=.api.openai.com; HttpOnly; - Secure; SameSite=None - - _cfuvid=B45jEr7m1koJPP4VpyiYJCVjXWCenQuz8H_umfGeEBQ-1733788455794-0.0.1.1-604800000; - path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None Transfer-Encoding: - chunked X-Content-Type-Options: @@ -81,7 +76,7 @@ interactions: openai-organization: - future-house-xr4tdh openai-processing-ms: - - "602" + - "892" openai-version: - "2020-10-01" strict-transport-security: @@ -99,7 +94,7 @@ interactions: x-ratelimit-reset-tokens: - 0s x-request-id: - - req_ef98d385b8292c3cd3fe3cb3fd0f31c2 + - req_7ff1ffa34df34287ed396ab6176cbd29 status: code: 200 message: OK diff --git a/tests/test_llms.py b/tests/test_llms.py index 754080b..ffb7cf6 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -7,7 +7,7 @@ import numpy as np import pytest from aviary.core import Message, Tool, ToolRequestMessage -from pydantic import BaseModel +from pydantic import BaseModel, Field, TypeAdapter, computed_field from llmclient.exceptions import JSONSchemaValidationError from llmclient.llms import ( @@ -261,7 +261,12 @@ def test_pickling(self, tmp_path: pathlib.Path) -> None: class DummyOutputSchema(BaseModel): name: str - age: int + age: int = Field(description="Age in years.") + + @computed_field # type: ignore[prop-decorator] + @property + def name_and_age(self) -> str: # So we can test computed_field is not included + return f"{self.name}, {self.age}" class TestMultipleCompletionLLMModel: @@ -364,7 +369,10 @@ def play(move: int | None) -> None: @pytest.mark.parametrize( ("model_name", "output_type"), [ - pytest.param("gpt-3.5-turbo", DummyOutputSchema, id="json-mode"), + pytest.param("gpt-3.5-turbo", DummyOutputSchema, id="json-mode-base-model"), + pytest.param( + "gpt-4o", TypeAdapter(DummyOutputSchema), id="json-mode-type-adapter" + ), pytest.param( "gpt-4o", DummyOutputSchema.model_json_schema(), id="structured-outputs" ),