Skip to content

Commit

Permalink
Mc/toolmaxretry (#323)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcasty authored Jan 5, 2024
1 parent 05fcf0b commit c16d622
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
29 changes: 20 additions & 9 deletions instructor/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
"content": message.content or "",
}
if message.tool_calls is not None:
ret["tool_calls"] = message.model_dump()["tool_calls"]
ret["content"] += json.dumps(message.model_dump()["tool_calls"])
if message.function_call is not None:
ret["content"] += json.dumps(message.model_dump()["function_call"])
Expand Down Expand Up @@ -240,6 +241,15 @@ async def retry_async(
logger.exception(f"Retrying, exception: {e}")
logger.debug(f"Error response: {response}")
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
if mode == Mode.TOOLS:
kwargs["messages"].append(
{
"role": "tool",
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"name": response.choices[0].message.tool_calls[0].function.name,
"content": "failure"
}
)
kwargs["messages"].append(
{
"role": "user",
Expand Down Expand Up @@ -286,6 +296,15 @@ def retry_sync(
logger.exception(f"Retrying, exception: {e}")
logger.debug(f"Error response: {response}")
kwargs["messages"].append(dump_message(response.choices[0].message))
if mode == Mode.TOOLS:
kwargs["messages"].append(
{
"role": "tool",
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"name": response.choices[0].message.tool_calls[0].function.name,
"content": "failure"
}
)
kwargs["messages"].append(
{
"role": "user",
Expand Down Expand Up @@ -323,10 +342,6 @@ async def new_chatcompletion_async(
*args,
**kwargs,
):
if mode == Mode.TOOLS:
max_retries = 0
logger.warning("max_retries is not supported when using tool calls")

response_model, new_kwargs = handle_response_model(
response_model=response_model, kwargs=kwargs, mode=mode
) # type: ignore
Expand All @@ -349,10 +364,6 @@ def new_chatcompletion_sync(
*args,
**kwargs,
):
if mode == Mode.TOOLS:
max_retries = 0
logger.warning("max_retries is not supported when using tool calls")

response_model, new_kwargs = handle_response_model(
response_model=response_model, kwargs=kwargs, mode=mode
) # type: ignore
Expand Down Expand Up @@ -406,4 +417,4 @@ def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
- `validation_context` parameter to validate the response using the pydantic model
- `strict` parameter to use strict json parsing
"""
return patch(client, mode=mode)
return patch(client, mode=mode)
7 changes: 5 additions & 2 deletions tests/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def test_override_docs():
{
"role": "assistant",
"content": 'Hello, world![{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]',
"tool_calls": [{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}],
},
),
(
Expand All @@ -110,6 +111,7 @@ def test_override_docs():
{
"role": "assistant",
"content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]',
"tool_calls": [{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}],
},
),
(
Expand Down Expand Up @@ -151,7 +153,7 @@ def test_override_docs():
"tool_calls and no content and function_call",
ChatCompletionMessage(
role="assistant",
content=None,
content="",
function_call=FunctionCall(arguments="", name="test_tool"),
tool_calls=[
ChatCompletionMessageToolCall(
Expand All @@ -164,6 +166,7 @@ def test_override_docs():
{
"role": "assistant",
"content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]{"arguments": "", "name": "test_tool"}',
"tool_calls": [{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]
},
),
],
Expand All @@ -173,4 +176,4 @@ def test_dump_message(
message: ChatCompletionMessage,
expected: ChatCompletionMessageParam,
):
assert dump_message(message) == expected, name_of_test
assert dump_message(message) == expected, name_of_test

0 comments on commit c16d622

Please sign in to comment.