Skip to content

Commit

Permalink
Flip condition on reasoning effort
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Feb 10, 2025
1 parent f377fde commit d6dd5c3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 14 deletions.
6 changes: 5 additions & 1 deletion griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"user": self.user,
"seed": self.seed,
**({"modalities": self.modalities} if self.modalities and not self.is_reasoning_model else {}),
**({"reasoning_effort": self.reasoning_effort} if self.model in ("o1", "o3-mini") else {}),
**(
{"reasoning_effort": self.reasoning_effort}
if self.is_reasoning_model and self.model != "o1-mini"
else {}
),
**({"temperature": self.temperature} if not self.is_reasoning_model else {}),
**({"audio": self.audio} if "audio" in self.modalities else {}),
**({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
Expand Down
18 changes: 5 additions & 13 deletions tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def test_init(self):

@pytest.mark.parametrize("use_native_tools", [True, False])
@pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"])
@pytest.mark.parametrize("model", ["gpt-4o", "o1", "o3"])
@pytest.mark.parametrize("model", ["gpt-4o", "o1", "o3", "o3-mini"])
@pytest.mark.parametrize("modalities", [["text"], ["text", "audio"], ["audio"]])
def test_try_run(
self,
Expand Down Expand Up @@ -496,7 +496,7 @@ def test_try_run(
**{
"reasoning_effort": driver.reasoning_effort,
}
if driver.model in ("o1", "o3-mini")
if driver.is_reasoning_model and model != "o1-mini"
else {},
**{
"temperature": driver.temperature,
Expand Down Expand Up @@ -631,7 +631,7 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create,

@pytest.mark.parametrize("use_native_tools", [True, False])
@pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"])
@pytest.mark.parametrize("model", ["gpt-4o", "o1", "o3"])
@pytest.mark.parametrize("model", ["gpt-4o", "o1", "o3", "o3-mini"])
@pytest.mark.parametrize("modalities", [["text"], ["text", "audio"], ["audio"]])
def test_try_stream_run(
self,
Expand Down Expand Up @@ -671,16 +671,8 @@ def test_try_stream_run(
}
if "audio" in driver.modalities
else {},
**{
"modalities": driver.modalities,
}
if not driver.is_reasoning_model
else {},
**{
"reasoning_effort": driver.reasoning_effort,
}
if driver.model in ("o1", "o3-mini")
else {},
**{"modalities": driver.modalities} if not driver.is_reasoning_model else {},
**{"reasoning_effort": driver.reasoning_effort} if driver.is_reasoning_model and model != "o1-mini" else {},
**{
"temperature": driver.temperature,
}
Expand Down

0 comments on commit d6dd5c3

Please sign in to comment.