Skip to content

Commit

Permalink
Move logic from driver to task, remove flag
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jan 3, 2025
1 parent 24e997f commit 5ea1534
Show file tree
Hide file tree
Showing 39 changed files with 189 additions and 353 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Structure.run_stream()` for streaming Events from a Structure as an iterator.
- Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`.
- Validators to `Agent` initialization.
- `BasePromptDriver.use_structured_output` for enabling or disabling structured output.
- `BasePromptDriver.structured_output_strategy` for changing the structured output strategy between `native`, `tool`, and `rule`.

### Changed
Expand Down
8 changes: 2 additions & 6 deletions docs/griptape-framework/drivers/prompt-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ You can pass images to the Driver if the model supports it:

Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems.

Structured output can be enabled or disabled for a Prompt Driver by setting the [use_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_structured_output).

If `use_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of:
You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of:

- `native`: The Driver will use the LLM's structured output functionality provided by the API.
- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and try to force the LLM to use it.
Expand All @@ -46,10 +44,8 @@ The easiest way to get started with structured output is by using a [PromptTask]
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py"
```

If `use_structured_output=False`, the Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt.

!!! warning
Not every LLM supports `use_structured_output` or all `structured_output_strategy` options.
Not every LLM supports all `structured_output_strategy` options.

## Prompt Drivers

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
PromptTask(
prompt_driver=OpenAiChatPromptDriver(
model="gpt-4o",
use_structured_output=True, # optional
structured_output_strategy="native", # optional
),
output_schema=schema.Schema(
Expand Down
17 changes: 6 additions & 11 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any

from attrs import Attribute, Factory, define, field
from schema import Schema
Expand Down Expand Up @@ -41,6 +41,7 @@
import boto3

from griptape.common import PromptStack
from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
from griptape.tools import BaseTool

logger = logging.getLogger(Defaults.logging_config.logger_name)
Expand All @@ -55,17 +56,16 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
structured_output_strategy: StructuredOutputStrategy = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value == "native":
raise ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.")
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

Expand Down Expand Up @@ -134,12 +134,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"toolChoice": self.tool_choice,
}

if (
prompt_stack.output_schema is not None
and self.use_structured_output
and self.structured_output_strategy == "tool"
):
self._add_structured_output_tool_if_absent(prompt_stack)
if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
params["toolConfig"]["toolChoice"] = {"any": {}}

params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import boto3

from griptape.common import PromptStack
from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy

logger = logging.getLogger(Defaults.logging_config.logger_name)

Expand All @@ -39,8 +40,18 @@ class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
),
kw_only=True,
)
structured_output_strategy: StructuredOutputStrategy = field(
default="rule", kw_only=True, metadata={"serializable": True}
)
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value != "rule":
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

@lazy_property()
def client(self) -> Any:
return self.session.client("sagemaker-runtime")
Expand Down
17 changes: 6 additions & 11 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Optional

from attrs import Attribute, Factory, define, field
from schema import Schema
Expand Down Expand Up @@ -42,6 +42,7 @@
from anthropic import Client
from anthropic.types import ContentBlock, ContentBlockDeltaEvent, ContentBlockStartEvent

from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
from griptape.tools.base_tool import BaseTool


Expand All @@ -68,8 +69,7 @@ class AnthropicPromptDriver(BasePromptDriver):
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
structured_output_strategy: StructuredOutputStrategy = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
Expand All @@ -80,9 +80,9 @@ def client(self) -> Client:
return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value == "native":
raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.")
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

Expand Down Expand Up @@ -136,12 +136,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if prompt_stack.tools and self.use_native_tools:
params["tool_choice"] = self.tool_choice

if (
prompt_stack.output_schema is not None
and self.use_structured_output
and self.structured_output_strategy == "tool"
):
self._add_structured_output_tool_if_absent(prompt_stack)
if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
params["tool_choice"] = {"type": "any"}

params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)
Expand Down
17 changes: 4 additions & 13 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

from griptape.tokenizers import BaseTokenizer

StructuredOutputStrategy = Literal["native", "tool", "rule"]


@define(kw_only=True)
class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
Expand All @@ -56,9 +58,8 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="native", kw_only=True, metadata={"serializable": True}
structured_output_strategy: StructuredOutputStrategy = field(
default="rule", kw_only=True, metadata={"serializable": True}
)
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

Expand Down Expand Up @@ -126,16 +127,6 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ...
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...

def _add_structured_output_tool_if_absent(self, prompt_stack: PromptStack) -> None:
from griptape.tools.structured_output.tool import StructuredOutputTool

if prompt_stack.output_schema is None:
raise ValueError("PromptStack must have an output schema to use structured output.")

structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)

def __process_run(self, prompt_stack: PromptStack) -> Message:
return self.try_run(prompt_stack)

Expand Down
15 changes: 5 additions & 10 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ class CoherePromptDriver(BasePromptDriver):
model: str = field(metadata={"serializable": True})
force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
tokenizer: BaseTokenizer = field(
default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
Expand Down Expand Up @@ -112,15 +111,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if prompt_stack.output_schema is not None and self.use_structured_output:
if self.structured_output_strategy == "native":
params["response_format"] = {
"type": "json_object",
"schema": prompt_stack.output_schema.json_schema("Output"),
}
elif self.structured_output_strategy == "tool":
# TODO: Implement tool choice once supported
self._add_structured_output_tool_if_absent(prompt_stack)
if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
params["response_format"] = {
"type": "json_object",
"schema": prompt_stack.output_schema.json_schema("Output"),
}

if prompt_stack.tools and self.use_native_tools:
params["tools"] = self.__to_cohere_tools(prompt_stack.tools)
Expand Down
17 changes: 6 additions & 11 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
import logging
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Optional

from attrs import Attribute, Factory, define, field
from schema import Schema
Expand Down Expand Up @@ -37,6 +37,7 @@
from google.generativeai.protos import Part
from google.generativeai.types import ContentDict, ContentsType, GenerateContentResponse

from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
from griptape.tools import BaseTool

logger = logging.getLogger(Defaults.logging_config.logger_name)
Expand All @@ -63,17 +64,16 @@ class GooglePromptDriver(BasePromptDriver):
top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
structured_output_strategy: StructuredOutputStrategy = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})
_client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value == "native":
raise ValueError("GooglePromptDriver does not support `native` structured output mode.")
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

Expand Down Expand Up @@ -164,13 +164,8 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if prompt_stack.tools and self.use_native_tools:
params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}

if (
prompt_stack.output_schema is not None
and self.use_structured_output
and self.structured_output_strategy == "tool"
):
if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
params["tool_config"]["function_calling_config"]["mode"] = "auto"
self._add_structured_output_tool_if_absent(prompt_stack)

params["tools"] = self.__to_google_tools(prompt_stack.tools)

Expand Down
7 changes: 3 additions & 4 deletions griptape/drivers/prompt/huggingface_hub_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
api_token: str = field(kw_only=True, metadata={"serializable": True})
max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="native", kw_only=True, metadata={"serializable": True}
)
Expand All @@ -56,9 +55,9 @@ def client(self) -> InferenceClient:
)

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value == "tool":
raise ValueError("HuggingFaceHubPromptDriver does not support `tool` structured output mode.")
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

Expand Down Expand Up @@ -121,7 +120,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if prompt_stack.output_schema and self.use_structured_output and self.structured_output_strategy == "native":
if prompt_stack.output_schema and self.structured_output_strategy == "native":
# https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
output_schema = prompt_stack.output_schema.json_schema("Output Schema")
# Grammar does not support $schema and $id
Expand Down
14 changes: 13 additions & 1 deletion griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from typing import TYPE_CHECKING

from attrs import Factory, define, field
from attrs import Attribute, Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.common import DeltaMessage, Message, PromptStack, TextMessageContent, observable
Expand All @@ -18,6 +18,8 @@

from transformers import TextGenerationPipeline

from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy

logger = logging.getLogger(Defaults.logging_config.logger_name)


Expand All @@ -38,10 +40,20 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver):
),
kw_only=True,
)
structured_output_strategy: StructuredOutputStrategy = field(
default="rule", kw_only=True, metadata={"serializable": True}
)
_pipeline: TextGenerationPipeline = field(
default=None, kw_only=True, alias="pipeline", metadata={"serializable": False}
)

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value != "rule":
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

@lazy_property()
def pipeline(self) -> TextGenerationPipeline:
return import_optional_dependency("transformers").pipeline(
Expand Down
9 changes: 2 additions & 7 deletions griptape/drivers/prompt/ollama_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ class OllamaPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
Expand Down Expand Up @@ -110,12 +109,8 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if prompt_stack.output_schema is not None and self.use_structured_output:
if self.structured_output_strategy == "native":
params["format"] = prompt_stack.output_schema.json_schema("Output")
elif self.structured_output_strategy == "tool":
# TODO: Implement tool choice once supported
self._add_structured_output_tool_if_absent(prompt_stack)
if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
params["format"] = prompt_stack.output_schema.json_schema("Output")

# Tool calling is only supported when not streaming
if prompt_stack.tools and self.use_native_tools and not self.stream:
Expand Down
Loading

0 comments on commit 5ea1534

Please sign in to comment.