Skip to content
This repository was archived by the owner on Nov 13, 2024. It is now read-only.

validate json schema for OpenAI function calling result #139

Merged
merged 2 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ types-tqdm = "^4.61.0"
tqdm = "^4.66.1"
gunicorn = "^21.2.0"
types-pyyaml = "^6.0.12.12"
jsonschema = "^4.2.0"
types-jsonschema = "^4.2.0"


[tool.poetry.group.dev.dependencies]
Expand Down
9 changes: 7 additions & 2 deletions src/canopy/llm/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union, Iterable, Optional, Any, Dict, List

import jsonschema
import openai
import json
from tenacity import (
Expand Down Expand Up @@ -102,7 +103,8 @@ def streaming_iterator(response):
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
retry=retry_if_exception_type(
OPEN_AI_TRANSIENT_EXCEPTIONS + (json.decoder.JSONDecodeError,)
OPEN_AI_TRANSIENT_EXCEPTIONS + (json.decoder.JSONDecodeError,
jsonschema.ValidationError)
),
)
def enforced_function_call(self,
Expand Down Expand Up @@ -172,7 +174,10 @@ def enforced_function_call(self,
)

result = chat_completion.choices[0].message.function_call
return json.loads(result["arguments"])
arguments = json.loads(result["arguments"])

jsonschema.validate(instance=arguments, schema=function.parameters.dict())
return arguments

async def achat_completion(self,
messages: Messages, *, stream: bool = False,
Expand Down
23 changes: 22 additions & 1 deletion tests/system/llm/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from unittest.mock import patch
from unittest.mock import patch, MagicMock

import jsonschema
import pytest


Expand Down Expand Up @@ -184,3 +186,22 @@ def test_enforce_function_api_failure_populates(mock_api_call,
with pytest.raises(Exception, match="API call failed"):
openai_llm.enforced_function_call(messages=messages,
function=function_query_knowledgebase)

@staticmethod
@patch("openai.ChatCompletion")
def test_enforce_function_wrong_output_schema(chat_completion,
openai_llm,
messages,
function_query_knowledgebase):
chat_completion.create.return_value = MagicMock(
choices=[MagicMock(
message=MagicMock(
function_call={"arguments": "{\"key\": \"value\"}"}))])

with pytest.raises(jsonschema.ValidationError,
match="'queries' is a required property"):
openai_llm.enforced_function_call(messages=messages,
function=function_query_knowledgebase)

assert chat_completion.create.call_count == 3, \
"retry did not happen as expected"