Skip to content

Commit

Permalink
See if refusal attribute exists in ChatCompletionMessage before refer…
Browse files Browse the repository at this point in the history
…encing it (#962)

Co-authored-by: Ivan Leo <[email protected]>
  • Loading branch information
callmephilip and ivanleomk authored Aug 31, 2024
1 parent 02fcfe3 commit 06fc5a3
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 5 deletions.
9 changes: 6 additions & 3 deletions instructor/function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,15 @@ def parse_tools(
strict: Optional[bool] = None,
) -> BaseModel:
message = completion.choices[0].message
# this field seems to be missing when using instructor with some other tools (e.g. litellm)
# trying to fix this by adding a check
if hasattr(message, "refusal"):
assert (
message.refusal is None
), f"Unable to generate a response due to {message.refusal}"
assert (
len(message.tool_calls or []) == 1
), "Instructor does not support multiple tool calls, use List[Model] instead."
assert (
message.refusal is None
), f"Unable to generate a response due to {message.refusal}"
tool_call = message.tool_calls[0] # type: ignore
assert (
tool_call.function.name == cls.openai_schema["name"] # type: ignore[index]
Expand Down
108 changes: 106 additions & 2 deletions tests/test_function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import pytest
from anthropic.types import Message, Usage
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
from pydantic import BaseModel, ValidationError

import instructor
Expand Down Expand Up @@ -206,4 +208,106 @@ class Model(BaseModel):
def test_mode_functions_deprecation_warning() -> None:
from openai import OpenAI
with pytest.warns(DeprecationWarning, match="The FUNCTIONS mode is deprecated and will be removed in future versions"):
_ = instructor.from_openai(OpenAI(), mode=instructor.Mode.FUNCTIONS)
_ = instructor.from_openai(OpenAI(), mode=instructor.Mode.FUNCTIONS)

def test_refusal_attribute(test_model: type[OpenAISchema]):
completion = ChatCompletion(
id="test_id",
created=1234567890,
model="gpt-3.5-turbo",
object="chat.completion",
choices=[
Choice(
index=0,
message=ChatCompletionMessage(
content="test_content",
refusal="test_refusal",
role="assistant",
tool_calls=[],
),
finish_reason="stop",
logprobs=None,
)
],
)

try:

test_model.from_response(completion, mode=instructor.Mode.TOOLS)
except Exception as e:
assert "Unable to generate a response due to test_refusal" in str(e)


def test_no_refusal_attribute(test_model: type[OpenAISchema]):
completion = ChatCompletion(
id="test_id",
created=1234567890,
model="gpt-3.5-turbo",
object="chat.completion",
choices=[
Choice(
index=0,
message=ChatCompletionMessage(
content="test_content",
refusal=None,
role="assistant",
tool_calls=[
ChatCompletionMessageToolCall(
id="test_id",
function=Function(
name="TestModel",
arguments='{"data": "test_data", "name": "TestModel"}',
),
type="function",
)
],
),
finish_reason="stop",
logprobs=None,
)
],
)

resp = test_model.from_response(completion, mode=instructor.Mode.TOOLS)
assert resp.data == "test_data"
assert resp.name == "TestModel"


def test_missing_refusal_attribute(test_model: type[OpenAISchema]):
message_without_refusal_attribute = ChatCompletionMessage(
content="test_content",
refusal="test_refusal",
role="assistant",
tool_calls=[
ChatCompletionMessageToolCall(
id="test_id",
function=Function(
name="TestModel",
arguments='{"data": "test_data", "name": "TestModel"}',
),
type="function",
)
],
)

del message_without_refusal_attribute.refusal
assert not hasattr(message_without_refusal_attribute, "refusal")

completion = ChatCompletion(
id="test_id",
created=1234567890,
model="gpt-3.5-turbo",
object="chat.completion",
choices=[
Choice(
index=0,
message=message_without_refusal_attribute,
finish_reason="stop",
logprobs=None,
)
],
)

resp = test_model.from_response(completion, mode=instructor.Mode.TOOLS)
assert resp.data == "test_data"
assert resp.name == "TestModel"

0 comments on commit 06fc5a3

Please sign in to comment.