From 70a76473d2f6769cac8fa481e25fd9123f4b903e Mon Sep 17 00:00:00 2001
From: Collin Dutter
Date: Mon, 28 Oct 2024 16:31:15 -0700
Subject: [PATCH] Add Prompt Driver extra params field (#1291)

---
 CHANGELOG.md                                  |  3 +++
 .../prompt/amazon_bedrock_prompt_driver.py    |  1 +
 ...mazon_sagemaker_jumpstart_prompt_driver.py |  1 +
 .../drivers/prompt/anthropic_prompt_driver.py |  1 +
 griptape/drivers/prompt/base_prompt_driver.py |  2 ++
 .../drivers/prompt/cohere_prompt_driver.py    |  1 +
 .../drivers/prompt/google_prompt_driver.py    |  1 +
 .../prompt/huggingface_hub_prompt_driver.py   | 13 ++++++++----
 .../huggingface_pipeline_prompt_driver.py     | 21 ++++++++++---------
 .../drivers/prompt/ollama_prompt_driver.py    |  1 +
 .../prompt/openai_chat_prompt_driver.py       |  1 +
 .../test_amazon_bedrock_drivers_config.py     |  2 ++
 .../drivers/test_anthropic_drivers_config.py  |  1 +
 .../test_azure_openai_drivers_config.py       |  1 +
 .../drivers/test_cohere_drivers_config.py     |  1 +
 .../configs/drivers/test_drivers_config.py    |  1 +
 .../drivers/test_google_drivers_config.py     |  1 +
 .../drivers/test_openai_driver_config.py      |  1 +
 .../test_amazon_bedrock_prompt_driver.py      | 10 +++++++--
 ...mazon_sagemaker_jumpstart_prompt_driver.py |  4 +++-
 .../prompt/test_anthropic_prompt_driver.py    | 12 +++++++++--
 .../test_azure_openai_chat_prompt_driver.py   |  4 ++++
 .../prompt/test_cohere_prompt_driver.py       | 14 +++++++++++--
 .../prompt/test_google_prompt_driver.py       | 21 +++++++++++++++----
 .../test_hugging_face_hub_prompt_driver.py    | 19 +++++++++++++++--
 ...est_hugging_face_pipeline_prompt_driver.py | 15 +++++++++++--
 .../prompt/test_ollama_prompt_driver.py       |  6 ++++--
 .../prompt/test_openai_chat_prompt_driver.py  | 11 ++++++++--
 28 files changed, 137 insertions(+), 33 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b37bf7b18..cf7fbf82f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `BaseFileManagerDriver.load_artifact()` & `BaseFileManagerDriver.save_artifact()` for loading & saving artifacts as files.
 - Events `BaseChunkEvent`, `TextChunkEvent`, `ActionChunkEvent`.
 - `wrapt` dependency for more robust decorators.
+- `BasePromptDriver.extra_params` for passing extra parameters not explicitly declared by the Driver.
 
 ### Changed
 
@@ -32,6 +33,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - If `EventListener.handler` is None, the event will be published to the `event_listener_driver` as-is.
 - **BREAKING**: Moved `griptape.common.observable.observable` to `griptape.common.decorators.observable`.
 - **BREAKING**: `AnthropicDriversConfig` no longer bundles `VoyageAiEmbeddingDriver`.
+- **BREAKING**: Removed `HuggingFaceHubPromptDriver.params`, use `HuggingFaceHubPromptDriver.extra_params` instead.
+- **BREAKING**: Removed `HuggingFacePipelinePromptDriver.params`, use `HuggingFacePipelinePromptDriver.extra_params` instead.
 - Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`.
 - `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
 - `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
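A minimal usage sketch of the new field (the model id and `repetition_penalty` below are illustrative assumptions, not values taken from this patch): `extra_params` is merged into the driver's `_base_params()` and forwarded to the underlying client call, so backend options the driver does not declare explicitly can still be passed.

    from griptape.drivers import HuggingFaceHubPromptDriver

    driver = HuggingFaceHubPromptDriver(
        api_token="hf_...",  # assumed: a valid Hugging Face Hub API token
        model="HuggingFaceH4/zephyr-7b-beta",  # assumed: any hosted text-generation model
        # Merged last in _base_params(), then forwarded to client.text_generation().
        extra_params={"repetition_penalty": 1.2},
    )
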
diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 1499b84c5..34459e1c5 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -117,6 +117,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if prompt_stack.tools and self.use_native_tools else {} ), + **self.extra_params, } def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]: diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index 8347f1f17..d98ac9fd4 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -105,6 +105,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "eos_token_id": self.tokenizer.tokenizer.eos_token_id, "stop_strings": self.tokenizer.stop_sequences, "return_full_text": False, + **self.extra_params, } def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 054049fe8..3341006a1 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -124,6 +124,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: else {} ), **({"system": system_message} if system_message else {}), + **self.extra_params, } def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]: diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 9af43f082..1524d7ed9 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -45,6 +45,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: An instance of `BaseTokenizer` to when calculating tokens. stream: Whether to stream the completion or not. `CompletionChunkEvent`s will be published to the `Structure` if one is provided. use_native_tools: Whether to use LLM's native function calling capabilities. Must be supported by the model. + extra_params: Extra parameters to pass to the model. 
""" temperature: float = field(default=0.1, metadata={"serializable": True}) @@ -54,6 +55,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 6bd8fb010..8b42b4083 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -128,6 +128,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: else {} ), **({"preamble": preamble} if preamble else {}), + **self.extra_params, } def __to_cohere_messages(self, messages: list[Message]) -> list[dict]: diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 1634bc613..2a6bdbf6d 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -145,6 +145,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, + **self.extra_params, }, ), **( diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 11ae1c145..c2c45c3ae 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -27,7 +27,6 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): Attributes: api_token: Hugging Face Hub API token. use_gpu: Use GPU during model run. - params: Custom model run parameters. model: Hugging Face Hub model name. client: Custom `InferenceApi`. tokenizer: Custom `HuggingFaceTokenizer`. 
@@ -35,7 +34,6 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): api_token: str = field(kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) - params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) tokenizer: HuggingFaceTokenizer = field( default=Factory( @@ -56,7 +54,7 @@ def client(self) -> InferenceClient: @observable def try_run(self, prompt_stack: PromptStack) -> Message: prompt = self.prompt_stack_to_string(prompt_stack) - full_params = {"return_full_text": False, "max_new_tokens": self.max_tokens, **self.params} + full_params = self._base_params(prompt_stack) logger.debug((prompt, full_params)) response = self.client.text_generation( @@ -76,7 +74,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: @observable def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: prompt = self.prompt_stack_to_string(prompt_stack) - full_params = {"return_full_text": False, "max_new_tokens": self.max_tokens, "stream": True, **self.params} + full_params = {**self._base_params(prompt_stack), "stream": True} logger.debug((prompt, full_params)) response = self.client.text_generation(prompt, **full_params) @@ -95,6 +93,13 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) + def _base_params(self, prompt_stack: PromptStack) -> dict: + return { + "return_full_text": False, + "max_new_tokens": self.max_tokens, + **self.extra_params, + } + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] for message in prompt_stack.messages: diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index 87e20b8ec..a197523df 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -26,13 +26,11 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): """Hugging Face Pipeline Prompt Driver. Attributes: - params: Custom model run parameters. model: Hugging Face Hub model name. 
""" max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) - params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) tokenizer: HuggingFaceTokenizer = field( default=Factory( lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), @@ -56,20 +54,15 @@ def pipeline(self) -> TextGenerationPipeline: @observable def try_run(self, prompt_stack: PromptStack) -> Message: messages = self._prompt_stack_to_messages(prompt_stack) + full_params = self._base_params(prompt_stack) logger.debug( ( messages, - {"max_new_tokens": self.max_tokens, "temperature": self.temperature, "do_sample": True, **self.params}, + full_params, ) ) - result = self.pipeline( - messages, - max_new_tokens=self.max_tokens, - temperature=self.temperature, - do_sample=True, - **self.params, - ) + result = self.pipeline(messages, **full_params) logger.debug(result) if isinstance(result, list): @@ -96,6 +89,14 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) + def _base_params(self, prompt_stack: PromptStack) -> dict: + return { + "max_new_tokens": self.max_tokens, + "temperature": self.temperature, + "do_sample": True, + **self.extra_params, + } + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 01db026ff..ca6813c23 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -116,6 +116,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: and not self.stream # Tool calling is only supported when not streaming else {} ), + **self.extra_params, } def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 8e87be7d5..8a1098b4a 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -154,6 +154,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}), **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}), **({"stream_options": {"include_usage": True}} if self.stream else {}), + **self.extra_params, } if self.response_format is not None: diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index b061e5b67..bdde495de 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -57,6 +57,7 @@ def test_to_dict(self, config): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "extra_params": {}, }, "vector_store_driver": { "embedding_driver": { @@ -117,6 +118,7 @@ def test_to_dict_with_values(self, config_with_values): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "extra_params": {}, }, "vector_store_driver": { "embedding_driver": { diff --git 
a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index bfc5b06f9..bd232283f 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -25,6 +25,7 @@ def test_to_dict(self, config): "top_p": 0.999, "top_k": 250, "use_native_tools": True, + "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, "image_query_driver": { diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 6c6d49483..a4af1692f 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -35,6 +35,7 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, + "extra_params": {}, }, "conversation_memory_driver": { "type": "LocalConversationMemoryDriver", diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index b828fef41..0032b6e7d 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -27,6 +27,7 @@ def test_to_dict(self, config): "model": "command-r", "force_single_step": False, "use_native_tools": True, + "extra_params": {}, }, "embedding_driver": { "type": "CohereEmbeddingDriver", diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index 74055f4e5..a1138769b 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,6 +18,7 @@ def test_to_dict(self, config): "max_tokens": None, "stream": False, "use_native_tools": False, + "extra_params": {}, }, "conversation_memory_driver": { "type": "LocalConversationMemoryDriver", diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index ab695369e..8eacda7c6 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -25,6 +25,7 @@ def test_to_dict(self, config): "top_k": None, "tool_choice": "auto", "use_native_tools": True, + "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, "image_query_driver": {"type": "DummyImageQueryDriver"}, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index a3cca9608..09ceccfdc 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -27,6 +27,7 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, + "extra_params": {}, }, "conversation_memory_driver": { "type": "LocalConversationMemoryDriver", diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index 40b0a8a0e..6d0dd757e 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -359,7 +359,9 @@ def messages(self): @pytest.mark.parametrize("use_native_tools", [True, False]) def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): # Given - driver = 
AmazonBedrockPromptDriver(model="ai21.j2", use_native_tools=use_native_tools) + driver = AmazonBedrockPromptDriver( + model="ai21.j2", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + ) # When message = driver.try_run(prompt_stack) @@ -376,6 +378,7 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): if use_native_tools else {} ), + foo="bar", ) assert isinstance(message.value[0], TextArtifact) assert message.value[0].value == "model-output" @@ -390,7 +393,9 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): @pytest.mark.parametrize("use_native_tools", [True, False]) def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools): # Given - driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True, use_native_tools=use_native_tools) + driver = AmazonBedrockPromptDriver( + model="ai21.j2", stream=True, use_native_tools=use_native_tools, extra_params={"foo": "bar"} + ) # When stream = driver.try_stream(prompt_stack) @@ -408,6 +413,7 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_ if prompt_stack.tools and use_native_tools else {} ), + foo="bar", ) event = next(stream) diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py index c894524f5..c7b0682c2 100644 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py @@ -37,7 +37,7 @@ def test_init(self): def test_try_run(self, mock_client): # Given - driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") + driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model", extra_params={"foo": "bar"}) prompt_stack = PromptStack() prompt_stack.add_user_message("prompt-stack") @@ -61,6 +61,7 @@ def test_try_run(self, mock_client): "eos_token_id": 1, "stop_strings": [], "return_full_text": False, + "foo": "bar", }, } ), @@ -91,6 +92,7 @@ def test_try_run(self, mock_client): "eos_token_id": 1, "stop_strings": [], "return_full_text": False, + "foo": "bar", }, } ), diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 40c983f7d..2b84b5a17 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -345,7 +345,9 @@ def test_init(self): @pytest.mark.parametrize("use_native_tools", [True, False]) def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): # Given - driver = AnthropicPromptDriver(model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools) + driver = AnthropicPromptDriver( + model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + ) # When message = driver.try_run(prompt_stack) @@ -361,6 +363,7 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, **{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + foo="bar", ) assert isinstance(message.value[0], TextArtifact) assert message.value[0].value == "model-output" @@ -376,7 +379,11 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): def test_try_stream_run(self, 
mock_stream_client, prompt_stack, messages, use_native_tools): # Given driver = AnthropicPromptDriver( - model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools + model="claude-3-haiku", + api_key="api-key", + stream=True, + use_native_tools=use_native_tools, + extra_params={"foo": "bar"}, ) # When @@ -395,6 +402,7 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, **{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + foo="bar", ) assert event.usage.input_tokens == 5 diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index dc0b54b0a..f9ac6bd59 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -74,6 +74,7 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_ azure_deployment="deployment-id", model="gpt-4", use_native_tools=use_native_tools, + extra_params={"foo": "bar"}, ) # When @@ -86,6 +87,7 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_ user=driver.user, messages=messages, **{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + foo="bar", ) assert isinstance(message.value[0], TextArtifact) assert message.value[0].value == "model-output" @@ -104,6 +106,7 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, model="gpt-4", stream=True, use_native_tools=use_native_tools, + extra_params={"foo": "bar"}, ) # When @@ -118,6 +121,7 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, stream=True, messages=messages, **{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + foo="bar", ) assert isinstance(event.content, TextDeltaMessageContent) diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index e110d9469..a42a899f1 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -136,7 +136,9 @@ def test_init(self): @pytest.mark.parametrize("use_native_tools", [True, False]) def test_try_run(self, mock_client, prompt_stack, use_native_tools): # Given - driver = CoherePromptDriver(model="command", api_key="api-key", use_native_tools=use_native_tools) + driver = CoherePromptDriver( + model="command", api_key="api-key", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + ) # When message = driver.try_run(prompt_stack) @@ -171,6 +173,7 @@ def test_try_run(self, mock_client, prompt_stack, use_native_tools): ], stop_sequences=[], temperature=0.1, + foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -187,7 +190,13 @@ def test_try_run(self, mock_client, prompt_stack, use_native_tools): @pytest.mark.parametrize("use_native_tools", [True, False]) def test_try_stream_run(self, mock_stream_client, prompt_stack, use_native_tools): # Given - driver = CoherePromptDriver(model="command", api_key="api-key", stream=True, use_native_tools=use_native_tools) + driver = CoherePromptDriver( + model="command", + api_key="api-key", + stream=True, + use_native_tools=use_native_tools, + extra_params={"foo": "bar"}, + ) # When stream = 
driver.try_stream(prompt_stack) @@ -223,6 +232,7 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, use_native_tools ], stop_sequences=[], temperature=0.1, + foo="bar", ) assert isinstance(event.content, TextDeltaMessageContent) diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 776664eb1..72cf51d03 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -169,7 +169,12 @@ def test_init(self): def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools): # Given driver = GooglePromptDriver( - model="gemini-pro", api_key="api-key", top_p=0.5, top_k=50, use_native_tools=use_native_tools + model="gemini-pro", + api_key="api-key", + top_p=0.5, + top_k=50, + use_native_tools=use_native_tools, + extra_params={"max_output_tokens": 10}, ) # When @@ -185,7 +190,9 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native call_args = mock_generative_model.return_value.generate_content.call_args assert messages == call_args.args[0] generation_config = call_args.kwargs["generation_config"] - assert generation_config == GenerationConfig(temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[]) + assert generation_config == GenerationConfig( + temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[], max_output_tokens=10 + ) if use_native_tools: tool_declarations = call_args.kwargs["tools"] assert [ @@ -206,7 +213,13 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, use_native_tools): # Given driver = GooglePromptDriver( - model="gemini-pro", api_key="api-key", stream=True, top_p=0.5, top_k=50, use_native_tools=use_native_tools + model="gemini-pro", + api_key="api-key", + stream=True, + top_p=0.5, + top_k=50, + use_native_tools=use_native_tools, + extra_params={"max_output_tokens": 10}, ) # When @@ -225,7 +238,7 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, assert messages == call_args.args[0] assert call_args.kwargs["stream"] is True assert call_args.kwargs["generation_config"] == GenerationConfig( - temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[] + temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[], max_output_tokens=10 ) if use_native_tools: tool_declarations = call_args.kwargs["tools"] diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 1a4e1b25b..4b7aa4d13 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -47,25 +47,40 @@ def test_init(self): def test_try_run(self, prompt_stack, mock_client): # Given - driver = HuggingFaceHubPromptDriver(api_token="api-token", model="repo-id") + driver = HuggingFaceHubPromptDriver(api_token="api-token", model="repo-id", extra_params={"foo": "bar"}) # When message = driver.try_run(prompt_stack) # Then + mock_client.text_generation.assert_called_once_with( + "foo\n\nUser: bar", + return_full_text=False, + max_new_tokens=250, + foo="bar", + ) assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 def test_try_stream(self, prompt_stack, mock_client_stream): # Given - driver = HuggingFaceHubPromptDriver(api_token="api-token", 
model="repo-id", stream=True) + driver = HuggingFaceHubPromptDriver( + api_token="api-token", model="repo-id", stream=True, extra_params={"foo": "bar"} + ) # When stream = driver.try_stream(prompt_stack) event = next(stream) # Then + mock_client_stream.text_generation.assert_called_once_with( + "foo\n\nUser: bar", + return_full_text=False, + max_new_tokens=250, + foo="bar", + stream=True, + ) assert isinstance(event.content, TextDeltaMessageContent) assert event.content.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index 0ece6c976..ac607afc3 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -33,17 +33,28 @@ def prompt_stack(self): prompt_stack.add_assistant_message("assistant-input") return prompt_stack + @pytest.fixture() + def messages(self): + return [ + {"role": "system", "content": "system-input"}, + {"role": "user", "content": "user-input"}, + {"role": "assistant", "content": "assistant-input"}, + ] + def test_init(self): assert HuggingFacePipelinePromptDriver(model="gpt2", max_tokens=42) - def test_try_run(self, prompt_stack): + def test_try_run(self, prompt_stack, messages, mock_pipeline): # Given - driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) + driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, extra_params={"foo": "bar"}) # When message = driver.try_run(prompt_stack) # Then + mock_pipeline.return_value.assert_called_once_with( + messages, max_new_tokens=42, temperature=0.1, do_sample=True, foo="bar" + ) assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index e4e9c4712..1ee075809 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -205,7 +205,7 @@ def test_init(self): @pytest.mark.parametrize("use_native_tools", [True]) def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): # Given - driver = OllamaPromptDriver(model="llama") + driver = OllamaPromptDriver(model="llama", extra_params={"foo": "bar"}) # When message = driver.try_run(prompt_stack) @@ -220,6 +220,7 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): "num_predict": driver.max_tokens, }, **{"tools": self.OLLAMA_TOOLS} if use_native_tools else {}, + foo="bar", ) assert isinstance(message.value[0], TextArtifact) assert message.value[0].value == "model-output" @@ -256,7 +257,7 @@ def test_try_stream_run(self, mock_stream_client): {"role": "user", "content": "user-input", "images": ["aW1hZ2UtZGF0YQ=="]}, {"role": "assistant", "content": "assistant-input"}, ] - driver = OllamaPromptDriver(model="llama", stream=True) + driver = OllamaPromptDriver(model="llama", stream=True, extra_params={"foo": "bar"}) # When text_artifact = next(driver.try_stream(prompt_stack)) @@ -267,6 +268,7 @@ def test_try_stream_run(self, mock_stream_client): model=driver.model, options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, stream=True, + foo="bar", ) if isinstance(text_artifact, TextDeltaMessageContent): assert text_artifact.text == "model-output" diff --git 
a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 35f31a2a2..f61df782e 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -343,7 +343,9 @@ def test_init(self): def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_native_tools): # Given driver = OpenAiChatPromptDriver( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, use_native_tools=use_native_tools + model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, + use_native_tools=use_native_tools, + extra_params={"foo": "bar"}, ) # When @@ -357,6 +359,7 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_ messages=messages, seed=driver.seed, **{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + foo="bar", ) assert isinstance(message.value[0], TextArtifact) assert message.value[0].value == "model-output" @@ -439,7 +442,10 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create, def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages, use_native_tools): # Given driver = OpenAiChatPromptDriver( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True, use_native_tools=use_native_tools + model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, + stream=True, + use_native_tools=use_native_tools, + extra_params={"foo": "bar"}, ) # When @@ -456,6 +462,7 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, seed=driver.seed, stream_options={"include_usage": True}, **{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + foo="bar", ) assert isinstance(event.content, TextDeltaMessageContent)
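A note on merge order: the `_base_params()` changes spread `**self.extra_params` last (or, for `GooglePromptDriver`, last within the `GenerationConfig` kwargs), so extra keys are added to, and on overlap generally override, the defaults the driver builds itself. A standalone sketch of that dict-merge behavior, with made-up values:

    defaults = {"temperature": 0.1, "max_new_tokens": 250, "return_full_text": False}
    extra_params = {"max_new_tokens": 1024, "repetition_penalty": 1.2}

    # Same pattern as _base_params(): later spreads win, so extra_params
    # both adds new keys and overrides overlapping defaults.
    full_params = {**defaults, **extra_params}
    assert full_params["max_new_tokens"] == 1024
    assert full_params["repetition_penalty"] == 1.2

Because `GooglePromptDriver` nests the spread inside `GenerationConfig`, only generation-config fields fit there, which is why its updated tests pass `extra_params={"max_output_tokens": 10}` and assert on the resulting `GenerationConfig`.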