Add Prompt Driver extra params field #1291

Merged 3 commits on Oct 28, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BaseFileManagerDriver.load_artifact()` & `BaseFileManagerDriver.save_artifact()` for loading & saving artifacts as files.
- Events `BaseChunkEvent`, `TextChunkEvent`, `ActionChunkEvent`.
- `wrapt` dependency for more robust decorators.
- `BasePromptDriver.extra_params` for passing extra parameters not explicitly declared by the Driver.

### Changed

@@ -32,6 +33,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- If `EventListener.handler` is None, the event will be published to the `event_listener_driver` as-is.
- **BREAKING**: Moved `griptape.common.observable.observable` to `griptape.common.decorators.observable`.
- **BREAKING**: `AnthropicDriversConfig` no longer bundles `VoyageAiEmbeddingDriver`.
- **BREAKING**: Removed `HuggingFaceHubPromptDriver.params`, use `HuggingFaceHubPromptDriver.extra_params` instead.
- **BREAKING**: Removed `HuggingFacePipelinePromptDriver.params`, use `HuggingFacePipelinePromptDriver.extra_params` instead.
- Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`.
- `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
- `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
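
A minimal sketch of the migration described by the BREAKING entries above; the model id and token are placeholders, and the top-level import path is assumed from the package's public drivers module:

    # Before this change, the Hugging Face drivers took arbitrary kwargs via `params`:
    # HuggingFaceHubPromptDriver(model="...", api_token="...", params={"do_sample": True})

    # After this change, every Prompt Driver accepts `extra_params` instead:
    from griptape.drivers import HuggingFaceHubPromptDriver

    driver = HuggingFaceHubPromptDriver(
        model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model id
        api_token="hf_...",                    # placeholder token
        extra_params={"do_sample": True, "top_p": 0.9},  # merged into the text_generation() call
    )
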
1 change: 1 addition & 0 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@@ -117,6 +117,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if prompt_stack.tools and self.use_native_tools
else {}
),
**self.extra_params,
}

def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
@@ -105,6 +105,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"eos_token_id": self.tokenizer.tokenizer.eos_token_id,
"stop_strings": self.tokenizer.stop_sequences,
"return_full_text": False,
**self.extra_params,
}

def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
1 change: 1 addition & 0 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
@@ -124,6 +124,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
else {}
),
**({"system": system_message} if system_message else {}),
**self.extra_params,
}

def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
2 changes: 2 additions & 0 deletions griptape/drivers/prompt/base_prompt_driver.py
@@ -45,6 +45,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
stream: Whether to stream the completion or not. `CompletionChunkEvent`s will be published to the `Structure` if one is provided.
use_native_tools: Whether to use LLM's native function calling capabilities. Must be supported by the model.
extra_params: Extra parameters to pass to the model.
"""

temperature: float = field(default=0.1, metadata={"serializable": True})
@@ -54,6 +55,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

def before_run(self, prompt_stack: PromptStack) -> None:
EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))
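
Because `extra_params` is declared on `BasePromptDriver`, every concrete driver inherits it. A hedged usage sketch, assuming the top-level drivers import path and that `seed` and `logit_bias` remain valid Chat Completions parameters:

    from griptape.drivers import OpenAiChatPromptDriver

    # Neither `seed` nor `logit_bias` is a declared field on the driver, but the
    # underlying API accepts them; `extra_params` forwards them through `_base_params()`.
    driver = OpenAiChatPromptDriver(
        model="gpt-4o",
        extra_params={"seed": 42, "logit_bias": {"50256": -100}},
    )
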
1 change: 1 addition & 0 deletions griptape/drivers/prompt/cohere_prompt_driver.py
@@ -128,6 +128,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
else {}
),
**({"preamble": preamble} if preamble else {}),
**self.extra_params,
Member: question for all, should extra_params be able to overwrite all other parameters?

Member: my suspicion is yes, but it could lead to weird behavior

Member (Author): While it could allow users to shoot themselves in the foot, I don't think we should block behavior.

}

def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
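
The review question above reduces to Python's dict-literal semantics: because `**self.extra_params` is spread last in `_base_params`, its keys override anything assembled earlier. A standalone illustration with made-up keys:

    defaults = {"model": "command-r", "temperature": 0.1}
    extra_params = {"temperature": 0.7, "seed": 7}  # user-supplied overrides

    # Later spreads win, so extra_params takes precedence over the defaults.
    merged = {**defaults, **extra_params}
    print(merged)  # {'model': 'command-r', 'temperature': 0.7, 'seed': 7}
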
1 change: 1 addition & 0 deletions griptape/drivers/prompt/google_prompt_driver.py
@@ -145,6 +145,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
**self.extra_params,
},
),
**(
13 changes: 9 additions & 4 deletions griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@@ -27,15 +27,13 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
Attributes:
api_token: Hugging Face Hub API token.
use_gpu: Use GPU during model run.
params: Custom model run parameters.
model: Hugging Face Hub model name.
client: Custom `InferenceApi`.
tokenizer: Custom `HuggingFaceTokenizer`.
"""

api_token: str = field(kw_only=True, metadata={"serializable": True})
max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
tokenizer: HuggingFaceTokenizer = field(
default=Factory(
@@ -56,7 +54,7 @@ def client(self) -> InferenceClient:
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
prompt = self.prompt_stack_to_string(prompt_stack)
full_params = {"return_full_text": False, "max_new_tokens": self.max_tokens, **self.params}
full_params = self._base_params(prompt_stack)
logger.debug((prompt, full_params))

response = self.client.text_generation(
@@ -76,7 +74,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
prompt = self.prompt_stack_to_string(prompt_stack)
full_params = {"return_full_text": False, "max_new_tokens": self.max_tokens, "stream": True, **self.params}
full_params = {**self._base_params(prompt_stack), "stream": True}
logger.debug((prompt, full_params))

response = self.client.text_generation(prompt, **full_params)
@@ -95,6 +93,13 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))

def _base_params(self, prompt_stack: PromptStack) -> dict:
return {
"return_full_text": False,
"max_new_tokens": self.max_tokens,
**self.extra_params,
}

def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
messages = []
for message in prompt_stack.messages:
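
One ordering detail worth noting in the streaming path above: "stream": True is merged after `_base_params()`, so a `stream` key in `extra_params` cannot switch streaming off. A standalone sketch of that merge order:

    base = {"return_full_text": False, "max_new_tokens": 250, "stream": False}  # pretend extra_params set stream
    full_params = {**base, "stream": True}
    assert full_params["stream"] is True  # the driver's own value wins
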
21 changes: 11 additions & 10 deletions griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@@ -26,13 +26,11 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver):
"""Hugging Face Pipeline Prompt Driver.

Attributes:
params: Custom model run parameters.
model: Hugging Face Hub model name.
"""

max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
tokenizer: HuggingFaceTokenizer = field(
default=Factory(
lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
@@ -56,20 +54,15 @@ def pipeline(self) -> TextGenerationPipeline:
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
messages = self._prompt_stack_to_messages(prompt_stack)
full_params = self._base_params(prompt_stack)
logger.debug(
(
messages,
{"max_new_tokens": self.max_tokens, "temperature": self.temperature, "do_sample": True, **self.params},
full_params,
)
)

result = self.pipeline(
messages,
max_new_tokens=self.max_tokens,
temperature=self.temperature,
do_sample=True,
**self.params,
)
result = self.pipeline(messages, **full_params)
logger.debug(result)

if isinstance(result, list):
@@ -96,6 +89,14 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))

def _base_params(self, prompt_stack: PromptStack) -> dict:
return {
"max_new_tokens": self.max_tokens,
"temperature": self.temperature,
"do_sample": True,
**self.extra_params,
}

def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
messages = []

1 change: 1 addition & 0 deletions griptape/drivers/prompt/ollama_prompt_driver.py
@@ -116,6 +116,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
and not self.stream # Tool calling is only supported when not streaming
else {}
),
**self.extra_params,
}

def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
1 change: 1 addition & 0 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
@@ -154,6 +154,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
**({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
**({"stream_options": {"include_usage": True}} if self.stream else {}),
**self.extra_params,
}

if self.response_format is not None:
@@ -57,6 +57,7 @@ def test_to_dict(self, config):
"type": "AmazonBedrockPromptDriver",
"tool_choice": {"auto": {}},
"use_native_tools": True,
"extra_params": {},
},
"vector_store_driver": {
"embedding_driver": {
@@ -117,6 +118,7 @@ def test_to_dict_with_values(self, config_with_values):
"type": "AmazonBedrockPromptDriver",
"tool_choice": {"auto": {}},
"use_native_tools": True,
"extra_params": {},
},
"vector_store_driver": {
"embedding_driver": {
@@ -25,6 +25,7 @@ def test_to_dict(self, config):
"top_p": 0.999,
"top_k": 250,
"use_native_tools": True,
"extra_params": {},
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
"image_query_driver": {
@@ -35,6 +35,7 @@ def test_to_dict(self, config):
"stream": False,
"user": "",
"use_native_tools": True,
"extra_params": {},
},
"conversation_memory_driver": {
"type": "LocalConversationMemoryDriver",
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_cohere_drivers_config.py
@@ -27,6 +27,7 @@ def test_to_dict(self, config):
"model": "command-r",
"force_single_step": False,
"use_native_tools": True,
"extra_params": {},
},
"embedding_driver": {
"type": "CohereEmbeddingDriver",
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_drivers_config.py
@@ -18,6 +18,7 @@ def test_to_dict(self, config):
"max_tokens": None,
"stream": False,
"use_native_tools": False,
"extra_params": {},
},
"conversation_memory_driver": {
"type": "LocalConversationMemoryDriver",
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_google_drivers_config.py
@@ -25,6 +25,7 @@ def test_to_dict(self, config):
"top_k": None,
"tool_choice": "auto",
"use_native_tools": True,
"extra_params": {},
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
"image_query_driver": {"type": "DummyImageQueryDriver"},
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_openai_driver_config.py
@@ -27,6 +27,7 @@ def test_to_dict(self, config):
"stream": False,
"user": "",
"use_native_tools": True,
"extra_params": {},
},
"conversation_memory_driver": {
"type": "LocalConversationMemoryDriver",
10 changes: 8 additions & 2 deletions tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py
@@ -359,7 +359,9 @@ def messages(self):
@pytest.mark.parametrize("use_native_tools", [True, False])
def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools):
# Given
driver = AmazonBedrockPromptDriver(model="ai21.j2", use_native_tools=use_native_tools)
driver = AmazonBedrockPromptDriver(
model="ai21.j2", use_native_tools=use_native_tools, extra_params={"foo": "bar"}
)

# When
message = driver.try_run(prompt_stack)
Expand All @@ -376,6 +378,7 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools):
if use_native_tools
else {}
),
foo="bar",
)
assert isinstance(message.value[0], TextArtifact)
assert message.value[0].value == "model-output"
@@ -390,7 +393,9 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools):
@pytest.mark.parametrize("use_native_tools", [True, False])
def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools):
# Given
driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True, use_native_tools=use_native_tools)
driver = AmazonBedrockPromptDriver(
model="ai21.j2", stream=True, use_native_tools=use_native_tools, extra_params={"foo": "bar"}
)

# When
stream = driver.try_stream(prompt_stack)
@@ -408,6 +413,7 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_
if prompt_stack.tools and use_native_tools
else {}
),
foo="bar",
)

event = next(stream)
@@ -37,7 +37,7 @@ def test_init(self):

def test_try_run(self, mock_client):
# Given
driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model")
driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model", extra_params={"foo": "bar"})
prompt_stack = PromptStack()
prompt_stack.add_user_message("prompt-stack")

@@ -61,6 +61,7 @@ def test_try_run(self, mock_client):
"eos_token_id": 1,
"stop_strings": [],
"return_full_text": False,
"foo": "bar",
},
}
),
@@ -91,6 +92,7 @@ def test_try_run(self, mock_client):
"eos_token_id": 1,
"stop_strings": [],
"return_full_text": False,
"foo": "bar",
},
}
),
12 changes: 10 additions & 2 deletions tests/unit/drivers/prompt/test_anthropic_prompt_driver.py
@@ -345,7 +345,9 @@ def test_init(self):
@pytest.mark.parametrize("use_native_tools", [True, False])
def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools):
# Given
driver = AnthropicPromptDriver(model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools)
driver = AnthropicPromptDriver(
model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, extra_params={"foo": "bar"}
)

# When
message = driver.try_run(prompt_stack)
Expand All @@ -361,6 +363,7 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools):
top_k=250,
**{"system": "system-input"} if prompt_stack.system_messages else {},
**{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {},
foo="bar",
)
assert isinstance(message.value[0], TextArtifact)
assert message.value[0].value == "model-output"
Expand All @@ -376,7 +379,11 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools):
def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools):
# Given
driver = AnthropicPromptDriver(
model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools
model="claude-3-haiku",
api_key="api-key",
stream=True,
use_native_tools=use_native_tools,
extra_params={"foo": "bar"},
)

# When
Expand All @@ -395,6 +402,7 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na
top_k=250,
**{"system": "system-input"} if prompt_stack.system_messages else {},
**{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {},
foo="bar",
)
assert event.usage.input_tokens == 5

@@ -74,6 +74,7 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_
azure_deployment="deployment-id",
model="gpt-4",
use_native_tools=use_native_tools,
extra_params={"foo": "bar"},
)

# When
Expand All @@ -86,6 +87,7 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_
user=driver.user,
messages=messages,
**{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {},
foo="bar",
)
assert isinstance(message.value[0], TextArtifact)
assert message.value[0].value == "model-output"
@@ -104,6 +106,7 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack,
model="gpt-4",
stream=True,
use_native_tools=use_native_tools,
extra_params={"foo": "bar"},
)

# When
Expand All @@ -118,6 +121,7 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack,
stream=True,
messages=messages,
**{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {},
foo="bar",
)

assert isinstance(event.content, TextDeltaMessageContent)