Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit

Permalink
Merge pull request #139 from pinecone-io/function-calling-jsonschema-…
Browse files Browse the repository at this point in the history
…validation

validate json schema for OpenAI function calling result
  • Loading branch information
igiloh-pinecone authored Nov 6, 2023
2 parents 0ff074a + 7ef1b54 commit 9868f41
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ types-tqdm = "^4.61.0"
tqdm = "^4.66.1"
gunicorn = "^21.2.0"
types-pyyaml = "^6.0.12.12"
jsonschema = "^4.2.0"
types-jsonschema = "^4.2.0"


[tool.poetry.group.dev.dependencies]
Expand Down
9 changes: 7 additions & 2 deletions src/canopy/llm/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union, Iterable, Optional, Any, Dict, List

import jsonschema
import openai
import json
from tenacity import (
Expand Down Expand Up @@ -102,7 +103,8 @@ def streaming_iterator(response):
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
retry=retry_if_exception_type(
OPEN_AI_TRANSIENT_EXCEPTIONS + (json.decoder.JSONDecodeError,)
OPEN_AI_TRANSIENT_EXCEPTIONS + (json.decoder.JSONDecodeError,
jsonschema.ValidationError)
),
)
def enforced_function_call(self,
Expand Down Expand Up @@ -172,7 +174,10 @@ def enforced_function_call(self,
)

result = chat_completion.choices[0].message.function_call
return json.loads(result["arguments"])
arguments = json.loads(result["arguments"])

jsonschema.validate(instance=arguments, schema=function.parameters.dict())
return arguments

async def achat_completion(self,
messages: Messages, *, stream: bool = False,
Expand Down
23 changes: 22 additions & 1 deletion tests/system/llm/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from unittest.mock import patch
from unittest.mock import patch, MagicMock

import jsonschema
import pytest


Expand Down Expand Up @@ -184,3 +186,22 @@ def test_enforce_function_api_failure_populates(mock_api_call,
with pytest.raises(Exception, match="API call failed"):
openai_llm.enforced_function_call(messages=messages,
function=function_query_knowledgebase)

@staticmethod
@patch("openai.ChatCompletion")
def test_enforce_function_wrong_output_schema(chat_completion,
openai_llm,
messages,
function_query_knowledgebase):
chat_completion.create.return_value = MagicMock(
choices=[MagicMock(
message=MagicMock(
function_call={"arguments": "{\"key\": \"value\"}"}))])

with pytest.raises(jsonschema.ValidationError,
match="'queries' is a required property"):
openai_llm.enforced_function_call(messages=messages,
function=function_query_knowledgebase)

assert chat_completion.create.call_count == 3, \
"retry did not happen as expected"

0 comments on commit 9868f41

Please sign in to comment.