supports-parallel-tools param to OpenAI #821

Closed
63 changes: 45 additions & 18 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -72,6 +72,7 @@ class OpenAIModel(Model):
model_name: OpenAIModelName
client: AsyncOpenAI = field(repr=False)
system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
supports_parallel_tools: bool = field(default=True)

def __init__(
self,
@@ -82,6 +83,7 @@ def __init__(
openai_client: AsyncOpenAI | None = None,
http_client: AsyncHTTPClient | None = None,
system_prompt_role: OpenAISystemPromptRole | None = None,
supports_parallel_tools: bool = True,
):
"""Initialize an OpenAI model.

@@ -99,6 +101,8 @@ def __init__(
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
In the future, this may be inferred from the model name.
supports_parallel_tools: Whether the model supports the `parallel_tool_calls` request parameter.
    Set to `False` for backends that reject it (e.g. some OpenAI-compatible third-party models).
"""
self.model_name: OpenAIModelName = model_name
if openai_client is not None:
@@ -111,6 +115,7 @@ def __init__(
else:
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
self.system_prompt_role = system_prompt_role
self.supports_parallel_tools = supports_parallel_tools

async def agent_model(
self,
@@ -129,6 +134,8 @@ async def agent_model(
allow_text_result,
tools,
self.system_prompt_role,
self.supports_parallel_tools,
)

def name(self) -> str:
@@ -155,6 +162,7 @@ class OpenAIAgentModel(AgentModel):
allow_text_result: bool
tools: list[chat.ChatCompletionToolParam]
system_prompt_role: OpenAISystemPromptRole | None
supports_parallel_tools: bool = True

async def request(
self, messages: list[ModelMessage], model_settings: ModelSettings | None
@@ -195,24 +203,43 @@ async def _completions_create(

openai_messages = list(chain(*(self._map_message(m) for m in messages)))

return await self.client.chat.completions.create(
model=self.model_name,
messages=openai_messages,
n=1,
parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
tools=self.tools or NOT_GIVEN,
tool_choice=tool_choice or NOT_GIVEN,
stream=stream,
stream_options={'include_usage': True} if stream else NOT_GIVEN,
max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
temperature=model_settings.get('temperature', NOT_GIVEN),
top_p=model_settings.get('top_p', NOT_GIVEN),
timeout=model_settings.get('timeout', NOT_GIVEN),
seed=model_settings.get('seed', NOT_GIVEN),
presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
)
if self.supports_parallel_tools:
return await self.client.chat.completions.create(
model=self.model_name,
messages=openai_messages,
n=1,
parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
tools=self.tools or NOT_GIVEN,
tool_choice=tool_choice or NOT_GIVEN,
stream=stream,
stream_options={'include_usage': True} if stream else NOT_GIVEN,
max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
temperature=model_settings.get('temperature', NOT_GIVEN),
top_p=model_settings.get('top_p', NOT_GIVEN),
timeout=model_settings.get('timeout', NOT_GIVEN),
seed=model_settings.get('seed', NOT_GIVEN),
presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
)
else:
return await self.client.chat.completions.create(
model=self.model_name,
messages=openai_messages,
n=1,
tools=self.tools or NOT_GIVEN,
tool_choice=tool_choice or NOT_GIVEN,
stream=stream,
stream_options={'include_usage': True} if stream else NOT_GIVEN,
max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
temperature=model_settings.get('temperature', NOT_GIVEN),
top_p=model_settings.get('top_p', NOT_GIVEN),
timeout=model_settings.get('timeout', NOT_GIVEN),
seed=model_settings.get('seed', NOT_GIVEN),
presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
)

def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
"""Process a non-streamed response, and prepare a message to return."""
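A possible review note rather than part of the PR: the two `create(...)` calls differ only in whether `parallel_tool_calls` is passed, so computing the argument once would avoid keeping fifteen keyword arguments in sync (while preserving the new `True if self.tools else NOT_GIVEN` default). A sketch, assuming a hypothetical helper `_parallel_tool_calls_arg`:

```python
from openai import NOT_GIVEN  # the same sentinel the module already uses


def _parallel_tool_calls_arg(supports_parallel_tools: bool, tools, model_settings):
    """Value to pass for `parallel_tool_calls`; NOT_GIVEN makes the SDK omit the field."""
    if not supports_parallel_tools:
        return NOT_GIVEN  # backend rejects the parameter entirely
    return model_settings.get('parallel_tool_calls', True if tools else NOT_GIVEN)


# `_completions_create` then needs a single `create()` call:
#     return await self.client.chat.completions.create(
#         ...,
#         parallel_tool_calls=_parallel_tool_calls_arg(
#             self.supports_parallel_tools, self.tools, model_settings
#         ),
#         ...,
#     )
```

Note that with this shape the mocked call in the tests would see `parallel_tool_calls=NOT_GIVEN` rather than the key being absent, so the `not in kwargs` assertion in the new test below would need adjusting accordingly.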
35 changes: 35 additions & 0 deletions tests/models/test_openai.py
@@ -563,3 +563,38 @@ async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_calls

await agent.run('Hello')
assert get_mock_chat_completion_kwargs(mock_client)[0]['parallel_tool_calls'] == parallel_tool_calls

async def test_supports_parallel_tools(allow_model_requests: None):
"""Test that parallel_tool_calls parameter is only included when supports_parallel_tools is True."""
c = completion_message(
ChatCompletionMessage(
content=None,
role='assistant',
tool_calls=[
chat.ChatCompletionMessageToolCall(
id='123',
function=Function(arguments='{"response": [1, 2, 3]}', name='final_result'),
type='function',
)
],
)
)

# Test with supports_parallel_tools=True (default)
mock_client = MockOpenAI.create_mock(c)
m = OpenAIModel('gpt-4o', openai_client=mock_client)
agent = Agent(m, result_type=list[int])
await agent.run('Hello')

kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
assert 'parallel_tool_calls' in kwargs
assert kwargs['parallel_tool_calls'] is True

# Test with supports_parallel_tools=False
mock_client = MockOpenAI.create_mock(c)
m = OpenAIModel('gpt-4o', openai_client=mock_client, supports_parallel_tools=False)
agent = Agent(m, result_type=list[int])
await agent.run('Hello')

kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
assert 'parallel_tool_calls' not in kwargs
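
One behavior the test above doesn't pin down: per the `else` branch, an explicit `model_settings={'parallel_tool_calls': ...}` is silently dropped when `supports_parallel_tools=False`. A hedged sketch of an extra case (reusing the `c` completion message and mock helpers from this file; the test name is hypothetical):

```python
async def test_supports_parallel_tools_ignores_setting(allow_model_requests: None):
    """Hypothetical extra case: explicit settings are dropped when unsupported."""
    mock_client = MockOpenAI.create_mock(c)  # assumes the same `c` as above
    m = OpenAIModel('gpt-4o', openai_client=mock_client, supports_parallel_tools=False)
    agent = Agent(m, result_type=list[int])
    await agent.run('Hello', model_settings={'parallel_tool_calls': True})

    kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
    assert 'parallel_tool_calls' not in kwargs  # the explicit setting never reaches the API
```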