Skip to content

Commit

Permalink
Rename method for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jan 3, 2025
1 parent 07b7757 commit 24e997f
Show file tree
Hide file tree
Showing 8 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
and self.use_structured_output
and self.structured_output_strategy == "tool"
):
self._add_structured_output_tool(prompt_stack)
self._add_structured_output_tool_if_absent(prompt_stack)
params["toolConfig"]["toolChoice"] = {"any": {}}

params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)
Expand Down
2 changes: 1 addition & 1 deletion griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
and self.use_structured_output
and self.structured_output_strategy == "tool"
):
self._add_structured_output_tool(prompt_stack)
self._add_structured_output_tool_if_absent(prompt_stack)
params["tool_choice"] = {"type": "any"}

params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)
Expand Down
2 changes: 1 addition & 1 deletion griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ...
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...

def _add_structured_output_tool(self, prompt_stack: PromptStack) -> None:
def _add_structured_output_tool_if_absent(self, prompt_stack: PromptStack) -> None:
from griptape.tools.structured_output.tool import StructuredOutputTool

if prompt_stack.output_schema is None:
Expand Down
2 changes: 1 addition & 1 deletion griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
}
elif self.structured_output_strategy == "tool":
# TODO: Implement tool choice once supported
self._add_structured_output_tool(prompt_stack)
self._add_structured_output_tool_if_absent(prompt_stack)

if prompt_stack.tools and self.use_native_tools:
params["tools"] = self.__to_cohere_tools(prompt_stack.tools)
Expand Down
2 changes: 1 addition & 1 deletion griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
and self.structured_output_strategy == "tool"
):
params["tool_config"]["function_calling_config"]["mode"] = "auto"
self._add_structured_output_tool(prompt_stack)
self._add_structured_output_tool_if_absent(prompt_stack)

params["tools"] = self.__to_google_tools(prompt_stack.tools)

Expand Down
2 changes: 1 addition & 1 deletion griptape/drivers/prompt/ollama_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
params["format"] = prompt_stack.output_schema.json_schema("Output")
elif self.structured_output_strategy == "tool":
# TODO: Implement tool choice once supported
self._add_structured_output_tool(prompt_stack)
self._add_structured_output_tool_if_absent(prompt_stack)

# Tool calling is only supported when not streaming
if prompt_stack.tools and self.use_native_tools and not self.stream:
Expand Down
2 changes: 1 addition & 1 deletion griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
}
elif self.structured_output_strategy == "tool" and self.use_native_tools:
params["tool_choice"] = "required"
self._add_structured_output_tool(prompt_stack)
self._add_structured_output_tool_if_absent(prompt_stack)

if self.response_format is not None:
if self.response_format == {"type": "json_object"}:
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/drivers/prompt/test_base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ def test__add_structured_output_tool(self):
prompt_stack = PromptStack()

with pytest.raises(ValueError, match="PromptStack must have an output schema to use structured output."):
mock_prompt_driver._add_structured_output_tool(prompt_stack)
mock_prompt_driver._add_structured_output_tool_if_absent(prompt_stack)

prompt_stack.output_schema = Schema({"foo": str})

mock_prompt_driver._add_structured_output_tool(prompt_stack)
mock_prompt_driver._add_structured_output_tool_if_absent(prompt_stack)
# Ensure it doesn't get added twice
mock_prompt_driver._add_structured_output_tool(prompt_stack)
mock_prompt_driver._add_structured_output_tool_if_absent(prompt_stack)
assert len(prompt_stack.tools) == 1
assert isinstance(prompt_stack.tools[0], StructuredOutputTool)
assert prompt_stack.tools[0].output_schema is prompt_stack.output_schema

0 comments on commit 24e997f

Please sign in to comment.