diff --git a/pyproject.toml b/pyproject.toml index ed2d6ba0..aecb6b4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ types-tqdm = "^4.61.0" tqdm = "^4.66.1" gunicorn = "^21.2.0" types-pyyaml = "^6.0.12.12" +jsonschema = "^4.2.0" +types-jsonschema = "^4.2.0" [tool.poetry.group.dev.dependencies] diff --git a/src/canopy/llm/openai.py b/src/canopy/llm/openai.py index 208b0b5a..902dd077 100644 --- a/src/canopy/llm/openai.py +++ b/src/canopy/llm/openai.py @@ -1,5 +1,6 @@ from typing import Union, Iterable, Optional, Any, Dict, List +import jsonschema import openai import json from tenacity import ( @@ -102,7 +103,8 @@ def streaming_iterator(response): wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(3), retry=retry_if_exception_type( - OPEN_AI_TRANSIENT_EXCEPTIONS + (json.decoder.JSONDecodeError,) + OPEN_AI_TRANSIENT_EXCEPTIONS + (json.decoder.JSONDecodeError, + jsonschema.ValidationError) ), ) def enforced_function_call(self, @@ -172,7 +174,10 @@ def enforced_function_call(self, ) result = chat_completion.choices[0].message.function_call - return json.loads(result["arguments"]) + arguments = json.loads(result["arguments"]) + + jsonschema.validate(instance=arguments, schema=function.parameters.dict()) + return arguments async def achat_completion(self, messages: Messages, *, stream: bool = False, diff --git a/tests/system/llm/test_openai.py b/tests/system/llm/test_openai.py index 879eeca5..d2a11e97 100644 --- a/tests/system/llm/test_openai.py +++ b/tests/system/llm/test_openai.py @@ -1,4 +1,6 @@ -from unittest.mock import patch +from unittest.mock import patch, MagicMock + +import jsonschema import pytest @@ -184,3 +186,22 @@ def test_enforce_function_api_failure_populates(mock_api_call, with pytest.raises(Exception, match="API call failed"): openai_llm.enforced_function_call(messages=messages, function=function_query_knowledgebase) + + @staticmethod + @patch("openai.ChatCompletion") + def test_enforce_function_wrong_output_schema(chat_completion, + openai_llm, + messages, + function_query_knowledgebase): + chat_completion.create.return_value = MagicMock( + choices=[MagicMock( + message=MagicMock( + function_call={"arguments": "{\"key\": \"value\"}"}))]) + + with pytest.raises(jsonschema.ValidationError, + match="'queries' is a required property"): + openai_llm.enforced_function_call(messages=messages, + function=function_query_knowledgebase) + + assert chat_completion.create.call_count == 3, \ + "retry did not happen as expected"