Skip to content

Commit

Permalink
Add Prompt Driver extra params field
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 25, 2024
1 parent 3542d06 commit 0f3e9c2
Show file tree
Hide file tree
Showing 28 changed files with 137 additions and 33 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BaseFileManagerDriver.load_artifact()` & `BaseFileManagerDriver.save_artifact()` for loading & saving artifacts as files.
- Events `BaseChunkEvent`, `TextChunkEvent`, `ActionChunkEvent`.
- `wrapt` dependency for more robust decorators.
- `BasePromptDriver.extra_params` for passing extra parameters not explicitly declared by the Driver.

### Changed

Expand All @@ -32,6 +33,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- If `EventListener.handler` is None, the event will be published to the `event_listener_driver` as-is.
- **BREAKING**: Moved `griptape.common.observable.observable` to `griptape.common.decorators.observable`.
- **BREAKING**: `AnthropicDriversConfig` no longer bundles `VoyageAiEmbeddingDriver`.
- **BREAKING**: Removed `HuggingFaceHubPromptDriver.params`, use `HuggingFaceHubPromptDriver.extra_params` instead.
- **BREAKING**: Removed `HuggingFacePipelinePromptDriver.params`, use `HuggingFacePipelinePromptDriver.extra_params` instead.
- Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`.
- `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
- `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
Expand Down
1 change: 1 addition & 0 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if prompt_stack.tools and self.use_native_tools
else {}
),
**self.extra_params,
}

def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"eos_token_id": self.tokenizer.tokenizer.eos_token_id,
"stop_strings": self.tokenizer.stop_sequences,
"return_full_text": False,
**self.extra_params,
}

def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
Expand Down
1 change: 1 addition & 0 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
else {}
),
**({"system": system_message} if system_message else {}),
**self.extra_params,
}

def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
Expand Down
2 changes: 2 additions & 0 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
tokenizer: An instance of `BaseTokenizer` to when calculating tokens.
stream: Whether to stream the completion or not. `CompletionChunkEvent`s will be published to the `Structure` if one is provided.
use_native_tools: Whether to use LLM's native function calling capabilities. Must be supported by the model.
extra_params: Extra parameters to pass to the model.
"""

temperature: float = field(default=0.1, metadata={"serializable": True})
Expand All @@ -54,6 +55,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

def before_run(self, prompt_stack: PromptStack) -> None:
EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))
Expand Down
1 change: 1 addition & 0 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
else {}
),
**({"preamble": preamble} if preamble else {}),
**self.extra_params,
}

def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
Expand Down
1 change: 1 addition & 0 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
**self.extra_params,
},
),
**(
Expand Down
13 changes: 9 additions & 4 deletions griptape/drivers/prompt/huggingface_hub_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
Attributes:
api_token: Hugging Face Hub API token.
use_gpu: Use GPU during model run.
params: Custom model run parameters.
model: Hugging Face Hub model name.
client: Custom `InferenceApi`.
tokenizer: Custom `HuggingFaceTokenizer`.
"""

api_token: str = field(kw_only=True, metadata={"serializable": True})
max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
tokenizer: HuggingFaceTokenizer = field(
default=Factory(
Expand All @@ -56,7 +54,7 @@ def client(self) -> InferenceClient:
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
prompt = self.prompt_stack_to_string(prompt_stack)
full_params = {"return_full_text": False, "max_new_tokens": self.max_tokens, **self.params}
full_params = self._base_params(prompt_stack)
logger.debug((prompt, full_params))

response = self.client.text_generation(
Expand All @@ -76,7 +74,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
prompt = self.prompt_stack_to_string(prompt_stack)
full_params = {"return_full_text": False, "max_new_tokens": self.max_tokens, "stream": True, **self.params}
full_params = {**self._base_params(prompt_stack), "stream": True}
logger.debug((prompt, full_params))

response = self.client.text_generation(prompt, **full_params)
Expand All @@ -95,6 +93,13 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))

def _base_params(self, prompt_stack: PromptStack) -> dict:
return {
"return_full_text": False,
"max_new_tokens": self.max_tokens,
**self.extra_params,
}

def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
messages = []
for message in prompt_stack.messages:
Expand Down
21 changes: 11 additions & 10 deletions griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,11 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver):
"""Hugging Face Pipeline Prompt Driver.
Attributes:
params: Custom model run parameters.
model: Hugging Face Hub model name.
"""

max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
tokenizer: HuggingFaceTokenizer = field(
default=Factory(
lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
Expand All @@ -56,20 +54,15 @@ def pipeline(self) -> TextGenerationPipeline:
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
messages = self._prompt_stack_to_messages(prompt_stack)
full_params = self._base_params(prompt_stack)
logger.debug(
(
messages,
{"max_new_tokens": self.max_tokens, "temperature": self.temperature, "do_sample": True, **self.params},
full_params,
)
)

result = self.pipeline(
messages,
max_new_tokens=self.max_tokens,
temperature=self.temperature,
do_sample=True,
**self.params,
)
result = self.pipeline(messages, **full_params)
logger.debug(result)

if isinstance(result, list):
Expand All @@ -96,6 +89,14 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))

def _base_params(self, prompt_stack: PromptStack) -> dict:
return {
"max_new_tokens": self.max_tokens,
"temperature": self.temperature,
"do_sample": True,
**self.extra_params,
}

def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
messages = []

Expand Down
1 change: 1 addition & 0 deletions griptape/drivers/prompt/ollama_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
and not self.stream # Tool calling is only supported when not streaming
else {}
),
**self.extra_params,
}

def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
Expand Down
1 change: 1 addition & 0 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
**({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
**({"stream_options": {"include_usage": True}} if self.stream else {}),
**self.extra_params,
}

if self.response_format is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_to_dict(self, config):
"type": "AmazonBedrockPromptDriver",
"tool_choice": {"auto": {}},
"use_native_tools": True,
"extra_params": {},
},
"vector_store_driver": {
"embedding_driver": {
Expand Down Expand Up @@ -117,6 +118,7 @@ def test_to_dict_with_values(self, config_with_values):
"type": "AmazonBedrockPromptDriver",
"tool_choice": {"auto": {}},
"use_native_tools": True,
"extra_params": {},
},
"vector_store_driver": {
"embedding_driver": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_to_dict(self, config):
"top_p": 0.999,
"top_k": 250,
"use_native_tools": True,
"extra_params": {},
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
"image_query_driver": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_to_dict(self, config):
"stream": False,
"user": "",
"use_native_tools": True,
"extra_params": {},
},
"conversation_memory_driver": {
"type": "LocalConversationMemoryDriver",
Expand Down
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_cohere_drivers_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_to_dict(self, config):
"model": "command-r",
"force_single_step": False,
"use_native_tools": True,
"extra_params": {},
},
"embedding_driver": {
"type": "CohereEmbeddingDriver",
Expand Down
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_drivers_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def test_to_dict(self, config):
"max_tokens": None,
"stream": False,
"use_native_tools": False,
"extra_params": {},
},
"conversation_memory_driver": {
"type": "LocalConversationMemoryDriver",
Expand Down
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_google_drivers_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_to_dict(self, config):
"top_k": None,
"tool_choice": "auto",
"use_native_tools": True,
"extra_params": {},
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
"image_query_driver": {"type": "DummyImageQueryDriver"},
Expand Down
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_openai_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_to_dict(self, config):
"stream": False,
"user": "",
"use_native_tools": True,
"extra_params": {},
},
"conversation_memory_driver": {
"type": "LocalConversationMemoryDriver",
Expand Down
10 changes: 8 additions & 2 deletions tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,9 @@ def messages(self):
@pytest.mark.parametrize("use_native_tools", [True, False])
def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools):
# Given
driver = AmazonBedrockPromptDriver(model="ai21.j2", use_native_tools=use_native_tools)
driver = AmazonBedrockPromptDriver(
model="ai21.j2", use_native_tools=use_native_tools, extra_params={"foo": "bar"}
)

# When
message = driver.try_run(prompt_stack)
Expand All @@ -376,6 +378,7 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools):
if use_native_tools
else {}
),
foo="bar",
)
assert isinstance(message.value[0], TextArtifact)
assert message.value[0].value == "model-output"
Expand All @@ -390,7 +393,9 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools):
@pytest.mark.parametrize("use_native_tools", [True, False])
def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools):
# Given
driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True, use_native_tools=use_native_tools)
driver = AmazonBedrockPromptDriver(
model="ai21.j2", stream=True, use_native_tools=use_native_tools, extra_params={"foo": "bar"}
)

# When
stream = driver.try_stream(prompt_stack)
Expand All @@ -408,6 +413,7 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_
if prompt_stack.tools and use_native_tools
else {}
),
foo="bar",
)

event = next(stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_init(self):

def test_try_run(self, mock_client):
# Given
driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model")
driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model", extra_params={"foo": "bar"})
prompt_stack = PromptStack()
prompt_stack.add_user_message("prompt-stack")

Expand All @@ -61,6 +61,7 @@ def test_try_run(self, mock_client):
"eos_token_id": 1,
"stop_strings": [],
"return_full_text": False,
"foo": "bar",
},
}
),
Expand Down Expand Up @@ -91,6 +92,7 @@ def test_try_run(self, mock_client):
"eos_token_id": 1,
"stop_strings": [],
"return_full_text": False,
"foo": "bar",
},
}
),
Expand Down
12 changes: 10 additions & 2 deletions tests/unit/drivers/prompt/test_anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,9 @@ def test_init(self):
@pytest.mark.parametrize("use_native_tools", [True, False])
def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools):
# Given
driver = AnthropicPromptDriver(model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools)
driver = AnthropicPromptDriver(
model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, extra_params={"foo": "bar"}
)

# When
message = driver.try_run(prompt_stack)
Expand All @@ -361,6 +363,7 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools):
top_k=250,
**{"system": "system-input"} if prompt_stack.system_messages else {},
**{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {},
foo="bar",
)
assert isinstance(message.value[0], TextArtifact)
assert message.value[0].value == "model-output"
Expand All @@ -376,7 +379,11 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools):
def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools):
# Given
driver = AnthropicPromptDriver(
model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools
model="claude-3-haiku",
api_key="api-key",
stream=True,
use_native_tools=use_native_tools,
extra_params={"foo": "bar"},
)

# When
Expand All @@ -395,6 +402,7 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na
top_k=250,
**{"system": "system-input"} if prompt_stack.system_messages else {},
**{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {},
foo="bar",
)
assert event.usage.input_tokens == 5

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_
azure_deployment="deployment-id",
model="gpt-4",
use_native_tools=use_native_tools,
extra_params={"foo": "bar"},
)

# When
Expand All @@ -86,6 +87,7 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_
user=driver.user,
messages=messages,
**{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {},
foo="bar",
)
assert isinstance(message.value[0], TextArtifact)
assert message.value[0].value == "model-output"
Expand All @@ -104,6 +106,7 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack,
model="gpt-4",
stream=True,
use_native_tools=use_native_tools,
extra_params={"foo": "bar"},
)

# When
Expand All @@ -118,6 +121,7 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack,
stream=True,
messages=messages,
**{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {},
foo="bar",
)

assert isinstance(event.content, TextDeltaMessageContent)
Expand Down
Loading

0 comments on commit 0f3e9c2

Please sign in to comment.