Skip to content

Commit

Permalink
Move logic from driver to task, remove flag
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jan 3, 2025
1 parent 0925b38 commit c9bcefa
Show file tree
Hide file tree
Showing 41 changed files with 315 additions and 504 deletions.
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- `BaseVectorStoreDriver.query_vector` for querying vector stores with vectors.
- Structured Output support for all Prompt Drivers.
- `PromptTask.output_schema` for setting an output schema to be used with Structured Output.
- `Agent.output_schema` for setting an output schema to be used on the Agent's Prompt Task.
- `BasePromptDriver.structured_output_strategy` for changing the Structured Output strategy between `native`, `tool`, and `rule`.

## [1.1.1] - 2025-01-03

Expand All @@ -31,8 +35,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Structure.run_stream()` for streaming Events from a Structure as an iterator.
- Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`.
- Validators to `Agent` initialization.
- `BasePromptDriver.use_structured_output` for enabling or disabling structured output.
- `BasePromptDriver.structured_output_strategy` for changing the structured output strategy between `native` and `tool`.

### Changed

Expand Down
25 changes: 11 additions & 14 deletions docs/griptape-framework/drivers/prompt-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,27 @@ You can pass images to the Driver if the model supports it:

## Structured Output

Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems.
Some LLMs provide functionality often referred to as "Structured Output".
This means instructing the LLM to output data in a particular format, usually JSON.
This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems.

Structured output can be enabled or disabled for a Prompt Driver by setting the [use_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_structured_output).

If `use_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of:

- `native`: The Driver will use the LLM's structured output functionality provided by the API.
- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md) and try to force the LLM to use a Tool.

Each Driver may have a different default setting depending on the LLM provider's capabilities.
!!! warning
Each Driver may have a different default setting depending on the LLM provider's capabilities.

### Prompt Task

The easiest way to get started with structured output is by using a [PromptTask](../structures/tasks.md#prompt)'s [output_schema](../../reference/griptape/tasks/prompt_task.md#griptape.tasks.PromptTask.output_schema) parameter.

You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of:

- `native`: The Driver will use the LLM's structured output functionality provided by the API.
- `tool`: The Task will add a special tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and the Driver will try to force the LLM to use the Tool.
- `rule`: The Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. This strategy does not guarantee that the LLM will output JSON and should only be used as a last resort.

```python
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py"
```

If `use_structured_output=False`, the Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt.

!!! warning
Not every LLM supports `use_structured_output` or all `structured_output_strategy` options.

## Prompt Drivers

Griptape offers the following Prompt Drivers for interacting with LLMs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
PromptTask(
prompt_driver=OpenAiChatPromptDriver(
model="gpt-4o",
use_structured_output=True, # optional
structured_output_strategy="native", # optional
),
output_schema=schema.Schema(
Expand Down
17 changes: 6 additions & 11 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any

from attrs import Attribute, Factory, define, field
from schema import Schema
Expand Down Expand Up @@ -41,6 +41,7 @@
import boto3

from griptape.common import PromptStack
from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
from griptape.tools import BaseTool

logger = logging.getLogger(Defaults.logging_config.logger_name)
Expand All @@ -55,17 +56,16 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
structured_output_strategy: StructuredOutputStrategy = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value == "native":
raise ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.")
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

Expand Down Expand Up @@ -134,12 +134,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"toolChoice": self.tool_choice,
}

if (
prompt_stack.output_schema is not None
and self.use_structured_output
and self.structured_output_strategy == "tool"
):
self._add_structured_output_tool_if_absent(prompt_stack)
if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
params["toolConfig"]["toolChoice"] = {"any": {}}

params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import boto3

from griptape.common import PromptStack
from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy

logger = logging.getLogger(Defaults.logging_config.logger_name)

Expand All @@ -39,8 +40,18 @@ class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
),
kw_only=True,
)
structured_output_strategy: StructuredOutputStrategy = field(
default="rule", kw_only=True, metadata={"serializable": True}
)
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value != "rule":
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

@lazy_property()
def client(self) -> Any:
return self.session.client("sagemaker-runtime")
Expand Down
17 changes: 6 additions & 11 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Optional

from attrs import Attribute, Factory, define, field
from schema import Schema
Expand Down Expand Up @@ -42,6 +42,7 @@
from anthropic import Client
from anthropic.types import ContentBlock, ContentBlockDeltaEvent, ContentBlockStartEvent

from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
from griptape.tools.base_tool import BaseTool


Expand All @@ -68,8 +69,7 @@ class AnthropicPromptDriver(BasePromptDriver):
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
structured_output_strategy: StructuredOutputStrategy = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
Expand All @@ -80,9 +80,9 @@ def client(self) -> Client:
return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value == "native":
raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.")
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

Expand Down Expand Up @@ -136,12 +136,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if prompt_stack.tools and self.use_native_tools:
params["tool_choice"] = self.tool_choice

if (
prompt_stack.output_schema is not None
and self.use_structured_output
and self.structured_output_strategy == "tool"
):
self._add_structured_output_tool_if_absent(prompt_stack)
if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
params["tool_choice"] = {"type": "any"}

params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)
Expand Down
17 changes: 4 additions & 13 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

from griptape.tokenizers import BaseTokenizer

StructuredOutputStrategy = Literal["native", "tool", "rule"]


@define(kw_only=True)
class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
Expand All @@ -56,9 +58,8 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="native", kw_only=True, metadata={"serializable": True}
structured_output_strategy: StructuredOutputStrategy = field(
default="rule", kw_only=True, metadata={"serializable": True}
)
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

Expand Down Expand Up @@ -126,16 +127,6 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ...
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...

def _add_structured_output_tool_if_absent(self, prompt_stack: PromptStack) -> None:
from griptape.tools.structured_output.tool import StructuredOutputTool

if prompt_stack.output_schema is None:
raise ValueError("PromptStack must have an output schema to use structured output.")

structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)

def __process_run(self, prompt_stack: PromptStack) -> Message:
return self.try_run(prompt_stack)

Expand Down
15 changes: 5 additions & 10 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ class CoherePromptDriver(BasePromptDriver):
model: str = field(metadata={"serializable": True})
force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
tokenizer: BaseTokenizer = field(
default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
Expand Down Expand Up @@ -112,15 +111,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if prompt_stack.output_schema is not None and self.use_structured_output:
if self.structured_output_strategy == "native":
params["response_format"] = {
"type": "json_object",
"schema": prompt_stack.output_schema.json_schema("Output"),
}
elif self.structured_output_strategy == "tool":
# TODO: Implement tool choice once supported
self._add_structured_output_tool_if_absent(prompt_stack)
if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
params["response_format"] = {
"type": "json_object",
"schema": prompt_stack.output_schema.json_schema("Output"),
}

if prompt_stack.tools and self.use_native_tools:
params["tools"] = self.__to_cohere_tools(prompt_stack.tools)
Expand Down
17 changes: 6 additions & 11 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
import logging
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Optional

from attrs import Attribute, Factory, define, field
from schema import Schema
Expand Down Expand Up @@ -37,6 +37,7 @@
from google.generativeai.protos import Part
from google.generativeai.types import ContentDict, ContentsType, GenerateContentResponse

from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
from griptape.tools import BaseTool

logger = logging.getLogger(Defaults.logging_config.logger_name)
Expand All @@ -63,17 +64,16 @@ class GooglePromptDriver(BasePromptDriver):
top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
structured_output_strategy: StructuredOutputStrategy = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})
_client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value == "native":
raise ValueError("GooglePromptDriver does not support `native` structured output mode.")
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

Expand Down Expand Up @@ -164,13 +164,8 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if prompt_stack.tools and self.use_native_tools:
params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}

if (
prompt_stack.output_schema is not None
and self.use_structured_output
and self.structured_output_strategy == "tool"
):
if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
params["tool_config"]["function_calling_config"]["mode"] = "auto"
self._add_structured_output_tool_if_absent(prompt_stack)

params["tools"] = self.__to_google_tools(prompt_stack.tools)

Expand Down
13 changes: 7 additions & 6 deletions griptape/drivers/prompt/huggingface_hub_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING

from attrs import Attribute, Factory, define, field

Expand All @@ -17,6 +17,8 @@

from huggingface_hub import InferenceClient

from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy

logger = logging.getLogger(Defaults.logging_config.logger_name)


Expand All @@ -35,8 +37,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
api_token: str = field(kw_only=True, metadata={"serializable": True})
max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
structured_output_strategy: StructuredOutputStrategy = field(
default="native", kw_only=True, metadata={"serializable": True}
)
tokenizer: HuggingFaceTokenizer = field(
Expand All @@ -56,9 +57,9 @@ def client(self) -> InferenceClient:
)

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value == "tool":
raise ValueError("HuggingFaceHubPromptDriver does not support `tool` structured output mode.")
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

Expand Down Expand Up @@ -121,7 +122,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if prompt_stack.output_schema and self.use_structured_output and self.structured_output_strategy == "native":
if prompt_stack.output_schema and self.structured_output_strategy == "native":
# https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
output_schema = prompt_stack.output_schema.json_schema("Output Schema")
# Grammar does not support $schema and $id
Expand Down
Loading

0 comments on commit c9bcefa

Please sign in to comment.