diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a06a20d5..74ad21294 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added -- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`. +- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, `OllamaPromptDriver`, and `CoherePromptDriver`. - `OllamaEmbeddingDriver` for generating embeddings with Ollama. - `GriptapeCloudKnowledgeBaseVectorStoreDriver` to query Griptape Cloud Knowledge Bases. - `GriptapeCloudEventListenerDriver.api_key` defaults to the value in the `GT_CLOUD_API_KEY` environment variable. diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 8be1a7c3d..ab749bf7c 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -282,21 +282,24 @@ agent.run( This driver requires the `drivers-prompt-ollama` [extra](../index.md#extras). The [OllamaPromptDriver](../../reference/griptape/drivers/prompt/ollama_prompt_driver.md) connects to the [Ollama Chat Completion API](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion). +This driver uses [Ollama tool calling](https://ollama.com/blog/tool-support) when using [Tools](../tools/index.md). ```python from griptape.config import StructureConfig from griptape.drivers import OllamaPromptDriver +from griptape.tools import Calculator from griptape.structures import Agent agent = Agent( config=StructureConfig( prompt_driver=OllamaPromptDriver( - model="llama3", + model="llama3.1", ), ), + tools=[Calculator()], ) -agent.run("What color is the sky at different times of the day?") +agent.run("What is (192 + 12) ^ 4") ``` ### Hugging Face Hub diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index ea4f8b344..70d4ce89a 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -1,18 +1,22 @@ from __future__ import annotations from collections.abc import Iterator -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from attrs import Factory, define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import ActionArtifact, TextArtifact from griptape.common import ( + ActionCallMessageContent, + ActionResultMessageContent, + BaseMessageContent, DeltaMessage, ImageMessageContent, Message, PromptStack, TextDeltaMessageContent, TextMessageContent, + ToolAction, observable, ) from griptape.drivers import BasePromptDriver @@ -23,6 +27,7 @@ from ollama import Client from griptape.tokenizers.base_tokenizer import BaseTokenizer + from griptape.tools import BaseTool @define @@ -61,6 +66,7 @@ class OllamaPromptDriver(BasePromptDriver): ), kw_only=True, ) + use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) @observable def try_run(self, prompt_stack: PromptStack) -> Message: @@ -68,7 +74,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: if isinstance(response, dict): return Message( - content=[TextMessageContent(TextArtifact(value=response["message"]["content"]))], + content=self.__to_prompt_stack_message_content(response), role=Message.ASSISTANT_ROLE, ) else: @@ -87,24 +93,134 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: def _base_params(self, prompt_stack: PromptStack) -> dict: messages = self._prompt_stack_to_messages(prompt_stack) - return {"messages": messages, "model": self.model, "options": self.options} + return { + "messages": messages, + "model": self.model, + "options": self.options, + **( + {"tools": self.__to_ollama_tools(prompt_stack.tools)} + if prompt_stack.tools + and self.use_native_tools + and not self.stream # Tool calling is only supported when not streaming + else {} + ), + } def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: - return [ - { - "role": message.role, - "content": message.to_text(), - **( - { - "images": [ - content.artifact.base64 - for content in message.content - if isinstance(content, ImageMessageContent) - ], - } - if any(isinstance(content, ImageMessageContent) for content in message.content) - else {} - ), + ollama_messages = [] + for message in prompt_stack.messages: + action_result_contents = message.get_content_type(ActionResultMessageContent) + + # Function calls need to be handled separately from the rest of the message content + if action_result_contents: + ollama_messages.extend( + [ + { + "role": self.__to_ollama_role(message, action_result_content), + "content": self.__to_ollama_message_content(action_result_content), + } + for action_result_content in action_result_contents + ], + ) + + text_contents = message.get_content_type(TextMessageContent) + if text_contents: + ollama_messages.append({"role": self.__to_ollama_role(message), "content": message.to_text()}) + else: + ollama_message: dict[str, Any] = { + "role": self.__to_ollama_role(message), + "content": message.to_text(), + } + + action_call_contents = message.get_content_type(ActionCallMessageContent) + if action_call_contents: + ollama_message["tool_calls"] = [ + self.__to_ollama_message_content(action_call_content) + for action_call_content in action_call_contents + ] + + image_contents = message.get_content_type(ImageMessageContent) + if image_contents: + ollama_message["images"] = [ + self.__to_ollama_message_content(image_content) for image_content in image_contents + ] + + ollama_messages.append(ollama_message) + + return ollama_messages + + def __to_ollama_message_content(self, content: BaseMessageContent) -> str | dict: + if isinstance(content, TextMessageContent): + return content.artifact.to_text() + elif isinstance(content, ImageMessageContent): + return content.artifact.base64 + elif isinstance(content, ActionCallMessageContent): + action = content.artifact.value + + return { + "type": "function", + "id": action.tag, + "function": {"name": action.to_native_tool_name(), "arguments": action.input}, } - for message in prompt_stack.messages - ] + elif isinstance(content, ActionResultMessageContent): + return content.artifact.to_text() + else: + raise ValueError(f"Unsupported content type: {type(content)}") + + def __to_ollama_tools(self, tools: list[BaseTool]) -> list[dict]: + ollama_tools = [] + + for tool in tools: + for activity in tool.activities(): + ollama_tool = { + "function": { + "name": tool.to_native_tool_name(activity), + "description": tool.activity_description(activity), + }, + "type": "function", + } + + activity_schema = tool.activity_schema(activity) + if activity_schema is not None: + ollama_tool["function"]["parameters"] = activity_schema.json_schema("Parameters Schema")[ + "properties" + ]["values"] + + ollama_tools.append(ollama_tool) + return ollama_tools + + def __to_ollama_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str: + if message.is_system(): + return "system" + elif message.is_assistant(): + return "assistant" + else: + if isinstance(message_content, ActionResultMessageContent): + return "tool" + else: + return "user" + + def __to_prompt_stack_message_content(self, response: dict) -> list[BaseMessageContent]: + content = [] + message = response["message"] + + if "content" in message and message["content"]: + content.append(TextMessageContent(TextArtifact(response["message"]["content"]))) + if "tool_calls" in message: + content.extend( + [ + ActionCallMessageContent( + ActionArtifact( + ToolAction( + tag=tool_call["function"]["name"], + name=ToolAction.from_native_tool_name(tool_call["function"]["name"])[0], + path=ToolAction.from_native_tool_name(tool_call["function"]["name"])[1], + input=tool_call["function"]["arguments"], + ), + ), + ) + for tool_call in message["tool_calls"] + ], + ) + + return content diff --git a/poetry.lock b/poetry.lock index 58fff8c4e..02c878fd5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2534,9 +2534,13 @@ files = [ {file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"}, {file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"}, + {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"}, + {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"}, + {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"}, {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"}, + {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"}, {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"}, {file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"}, {file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"}, @@ -3573,13 +3577,13 @@ files = [ [[package]] name = "ollama" -version = "0.2.1" +version = "0.3.0" description = "The official Python client for Ollama." optional = true python-versions = "<4.0,>=3.8" files = [ - {file = "ollama-0.2.1-py3-none-any.whl", hash = "sha256:b6e2414921c94f573a903d1069d682ba2fb2607070ea9e19ca4a7872f2a460ec"}, - {file = "ollama-0.2.1.tar.gz", hash = "sha256:fa316baa9a81eac3beb4affb0a17deb3008fdd6ed05b123c26306cfbe4c349b6"}, + {file = "ollama-0.3.0-py3-none-any.whl", hash = "sha256:cd7010c4e2a37d7f08f36cd35c4592b14f1ec0d1bf3df10342cd47963d81ad7a"}, + {file = "ollama-0.3.0.tar.gz", hash = "sha256:6ff493a2945ba76cdd6b7912a1cd79a45cfd9ba9120d14adeb63b2b5a7f353da"}, ] [package.dependencies] @@ -6828,4 +6832,4 @@ loaders-sql = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "20511683fde939102f4d9e331fd9ecd064ad6cc490a5bf2f6c3f4366b4e43447" +content-hash = "2b4b54981fadfbeb7fb06d2b481cdaa7b5b7c9238efd811419cda21340897ce6" diff --git a/pyproject.toml b/pyproject.toml index 8bb0ca5d4..1fb379829 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ voyageai = {version = "^0.2.1", optional = true} elevenlabs = {version = "^1.1.2", optional = true} qdrant-client = { version = "^1.10.1", optional = true } pusher = {version = "^3.3.2", optional = true} -ollama = {version = "^0.2.1", optional = true} +ollama = {version = "^0.3.0", optional = true} duckduckgo-search = {version = "^6.1.12", optional = true} sqlalchemy = {version = "^2.0.31", optional = true} opentelemetry-sdk = {version = "^1.25.0", optional = true} diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index e51da368a..797880fdc 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -1,17 +1,111 @@ import pytest -from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact -from griptape.common import PromptStack -from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent +from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.common import PromptStack, TextDeltaMessageContent, ToolAction from griptape.drivers import OllamaPromptDriver +from tests.mocks.mock_tool.tool import MockTool class TestOllamaPromptDriver: + OLLAMA_TOOLS = [ + { + "function": { + "description": "test description: foo", + "name": "MockTool_test", + "parameters": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + }, + }, + "type": "function", + }, + { + "function": { + "description": "test description: foo", + "name": "MockTool_test_error", + "parameters": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + }, + }, + "type": "function", + }, + { + "function": { + "description": "test description: foo", + "name": "MockTool_test_exception", + "parameters": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + }, + }, + "type": "function", + }, + { + "function": { + "description": "test description", + "name": "MockTool_test_list_output", + }, + "type": "function", + }, + { + "function": { + "description": "test description", + "name": "MockTool_test_no_schema", + }, + "type": "function", + }, + { + "function": { + "description": "test description: foo", + "name": "MockTool_test_str_output", + "parameters": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + }, + }, + "type": "function", + }, + { + "function": { + "description": "test description", + "name": "MockTool_test_without_default_memory", + "parameters": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + }, + }, + "type": "function", + }, + ] + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("ollama.Client") - mock_client.return_value.chat.return_value = {"message": {"content": "model-output"}} + mock_client.return_value.chat.return_value = { + "message": { + "content": "model-output", + "tool_calls": [ + { + "function": { + "name": "MockTool_test", + "arguments": {"foo": "bar"}, + } + } + ], + }, + } return mock_client @@ -22,12 +116,10 @@ def mock_stream_client(self, mocker): return mock_stream_client - def test_init(self): - assert OllamaPromptDriver(model="llama") - - def test_try_run(self, mock_client): - # Given + @pytest.fixture() + def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.tools = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") prompt_stack.add_user_message( @@ -36,26 +128,87 @@ def test_try_run(self, mock_client): ) ) prompt_stack.add_assistant_message("assistant-input") - driver = OllamaPromptDriver(model="llama") - expected_messages = [ + prompt_stack.add_assistant_message( + ListArtifact( + [ + TextArtifact(""), + ActionArtifact(ToolAction(tag="MockTool_test", name="MockTool", path="test", input={"foo": "bar"})), + ] + ) + ) + prompt_stack.add_user_message( + ListArtifact( + [ + TextArtifact("keep-going"), + ActionArtifact( + ToolAction( + tag="MockTool_test", + name="MockTool", + path="test", + input={"foo": "bar"}, + output=TextArtifact("tool-output"), + ) + ), + ] + ) + ) + return prompt_stack + + @pytest.fixture() + def messages(self): + return [ {"role": "system", "content": "system-input"}, {"role": "user", "content": "user-input"}, - {"role": "user", "content": "user-input", "images": ["aW1hZ2UtZGF0YQ=="]}, + { + "role": "user", + "content": "user-input", + "images": ["aW1hZ2UtZGF0YQ=="], + }, {"role": "assistant", "content": "assistant-input"}, + { + "content": "", + "role": "assistant", + "tool_calls": [ + { + "function": {"arguments": {"foo": "bar"}, "name": "MockTool_test"}, + "id": "MockTool_test", + "type": "function", + } + ], + }, + {"content": "tool-output", "role": "tool"}, + {"content": "keep-going", "role": "user"}, ] + def test_init(self): + assert OllamaPromptDriver(model="llama") + + @pytest.mark.parametrize("use_native_tools", [True]) + def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + # Given + driver = OllamaPromptDriver(model="llama") + # When message = driver.try_run(prompt_stack) # Then mock_client.return_value.chat.assert_called_once_with( - messages=expected_messages, + messages=messages, model=driver.model, - options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, + options={ + "temperature": driver.temperature, + "stop": [], + "num_predict": driver.max_tokens, + }, + **{"tools": self.OLLAMA_TOOLS} if use_native_tools else {}, ) - assert message.value == "model-output" - assert message.usage.input_tokens is None - assert message.usage.output_tokens is None + assert isinstance(message.value[0], TextArtifact) + assert message.value[0].value == "model-output" + assert isinstance(message.value[1], ActionArtifact) + assert message.value[1].value.tag == "MockTool_test" + assert message.value[1].value.name == "MockTool" + assert message.value[1].value.path == "test" + assert message.value[1].value.input == {"foo": "bar"} def test_try_run_bad_response(self, mock_client): # Given