diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py
index 1190a189..f451a8b5 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -72,6 +72,7 @@ class OpenAIModel(Model):
     model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)
     system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
+    supports_parallel_tools: bool = field(default=True)
 
     def __init__(
         self,
@@ -82,6 +83,7 @@ def __init__(
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
        system_prompt_role: OpenAISystemPromptRole | None = None,
+        supports_parallel_tools: bool = True,
    ):
        """Initialize an OpenAI model.
 
@@ -99,6 +101,8 @@ def __init__(
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
            system_prompt_role: The role to use for the system prompt message. If not provided, defaults to
                `'system'`. In the future, this may be inferred from the model name.
+            supports_parallel_tools: Whether the model supports parallel tool calls. Set to False for models
+                that don't support this parameter (e.g. some third-party models).
        """
        self.model_name: OpenAIModelName = model_name
        if openai_client is not None:
@@ -111,6 +115,7 @@ def __init__(
        else:
            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
        self.system_prompt_role = system_prompt_role
+        self.supports_parallel_tools = supports_parallel_tools
 
    async def agent_model(
        self,
@@ -129,6 +134,8 @@ async def agent_model(
            allow_text_result,
            tools,
            self.system_prompt_role,
+            self.supports_parallel_tools,
+
        )
 
    def name(self) -> str:
@@ -155,6 +162,7 @@ class OpenAIAgentModel(AgentModel):
    allow_text_result: bool
    tools: list[chat.ChatCompletionToolParam]
    system_prompt_role: OpenAISystemPromptRole | None
+    supports_parallel_tools: bool = True
 
    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
@@ -195,24 +203,43 @@ async def _completions_create(
 
        openai_messages = list(chain(*(self._map_message(m) for m in messages)))
 
-        return await self.client.chat.completions.create(
-            model=self.model_name,
-            messages=openai_messages,
-            n=1,
-            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=self.tools or NOT_GIVEN,
-            tool_choice=tool_choice or NOT_GIVEN,
-            stream=stream,
-            stream_options={'include_usage': True} if stream else NOT_GIVEN,
-            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
-            temperature=model_settings.get('temperature', NOT_GIVEN),
-            top_p=model_settings.get('top_p', NOT_GIVEN),
-            timeout=model_settings.get('timeout', NOT_GIVEN),
-            seed=model_settings.get('seed', NOT_GIVEN),
-            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
-            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
-            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
-        )
+        if self.supports_parallel_tools:
+            return await self.client.chat.completions.create(
+                model=self.model_name,
+                messages=openai_messages,
+                n=1,
+                parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
+                tools=self.tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                stream=stream,
+                stream_options={'include_usage': True} if stream else NOT_GIVEN,
+                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                seed=model_settings.get('seed', NOT_GIVEN),
+                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+            )
+        else:
+            return await self.client.chat.completions.create(
+                model=self.model_name,
+                messages=openai_messages,
+                n=1,
+                tools=self.tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                stream=stream,
+                stream_options={'include_usage': True} if stream else NOT_GIVEN,
+                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                seed=model_settings.get('seed', NOT_GIVEN),
+                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+            )
 
    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index 0fa1f923..24445885 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -563,3 +563,38 @@ async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_cal
    await agent.run('Hello')
 
    assert get_mock_chat_completion_kwargs(mock_client)[0]['parallel_tool_calls'] == parallel_tool_calls
+
+async def test_supports_parallel_tools(allow_model_requests: None):
+    """Test that parallel_tool_calls parameter is only included when supports_parallel_tools is True."""
+    c = completion_message(
+        ChatCompletionMessage(
+            content=None,
+            role='assistant',
+            tool_calls=[
+                chat.ChatCompletionMessageToolCall(
+                    id='123',
+                    function=Function(arguments='{"response": [1, 2, 3]}', name='final_result'),
+                    type='function',
+                )
+            ],
+        )
+    )
+
+    # Test with supports_parallel_tools=True (default)
+    mock_client = MockOpenAI.create_mock(c)
+    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    agent = Agent(m, result_type=list[int])
+    await agent.run('Hello')
+
+    kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
+    assert 'parallel_tool_calls' in kwargs
+    assert kwargs['parallel_tool_calls'] is True
+
+    # Test with supports_parallel_tools=False
+    mock_client = MockOpenAI.create_mock(c)
+    m = OpenAIModel('gpt-4o', openai_client=mock_client, supports_parallel_tools=False)
+    agent = Agent(m, result_type=list[int])
+    await agent.run('Hello')
+
+    kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
+    assert 'parallel_tool_calls' not in kwargs
\ No newline at end of file
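
Usage sketch (not part of the diff): with this change, an OpenAI-compatible endpoint that rejects the `parallel_tool_calls` request parameter can be used by passing `supports_parallel_tools=False` to `OpenAIModel`. The base URL, API key, and model name below are placeholders for an assumed local third-party server.

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel

# Hypothetical OpenAI-compatible server that does not accept `parallel_tool_calls`;
# base_url, api_key and model name are placeholders, not real endpoints.
model = OpenAIModel(
    'some-third-party-model',
    base_url='http://localhost:8000/v1',
    api_key='placeholder',
    supports_parallel_tools=False,  # omit parallel_tool_calls from requests
)

agent = Agent(model)
result = agent.run_sync('What is the capital of France?')
print(result.data)
```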