Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

validate json schema for OpenAI function calling result #139

Merged
merged 2 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ types-tqdm = "^4.61.0"
tqdm = "^4.66.1"
gunicorn = "^21.2.0"
types-pyyaml = "^6.0.12.12"
jsonschema = "^4.2.0"
types-jsonschema = "^4.2.0"


[tool.poetry.group.dev.dependencies]
Expand Down
9 changes: 7 additions & 2 deletions src/canopy/llm/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union, Iterable, Optional, Any, Dict, List

import jsonschema
import openai
import json
from tenacity import (
Expand Down Expand Up @@ -102,7 +103,8 @@ def streaming_iterator(response):
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
retry=retry_if_exception_type(
OPEN_AI_TRANSIENT_EXCEPTIONS + (json.decoder.JSONDecodeError,)
OPEN_AI_TRANSIENT_EXCEPTIONS + (json.decoder.JSONDecodeError,
jsonschema.ValidationError)
),
)
def enforced_function_call(self,
Expand Down Expand Up @@ -172,7 +174,10 @@ def enforced_function_call(self,
)

result = chat_completion.choices[0].message.function_call
return json.loads(result["arguments"])
arguments = json.loads(result["arguments"])

jsonschema.validate(instance=arguments, schema=function.parameters.dict())
return arguments

async def achat_completion(self,
messages: Messages, *, stream: bool = False,
Expand Down