Add native function calling support to Ollama Prompt Driver #1027

Merged 1 commit on Jul 29, 2024
2 changes: 1 addition & 1 deletion CHANGELOG.md
```diff
@@ -6,7 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

 ## Unreleased
 ### Added
-- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`.
+- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, `OllamaPromptDriver`, and `CoherePromptDriver`.
 - `OllamaEmbeddingDriver` for generating embeddings with Ollama.
 - `GriptapeCloudKnowledgeBaseVectorStoreDriver` to query Griptape Cloud Knowledge Bases.
 - `GriptapeCloudEventListenerDriver.api_key` defaults to the value in the `GT_CLOUD_API_KEY` environment variable.
```
7 changes: 5 additions & 2 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -282,21 +282,24 @@
This driver requires the `drivers-prompt-ollama` [extra](../index.md#extras).

The [OllamaPromptDriver](../../reference/griptape/drivers/prompt/ollama_prompt_driver.md) connects to the [Ollama Chat Completion API](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion).
This driver uses [Ollama tool calling](https://ollama.com/blog/tool-support) when using [Tools](../tools/index.md).

```diff
 from griptape.config import StructureConfig
 from griptape.drivers import OllamaPromptDriver
+from griptape.tools import Calculator
 from griptape.structures import Agent


 agent = Agent(
     config=StructureConfig(
         prompt_driver=OllamaPromptDriver(
-            model="llama3",
+            model="llama3.1",
         ),
     ),
+    tools=[Calculator()],
 )
-agent.run("What color is the sky at different times of the day?")
+agent.run("What is (192 + 12) ^ 4")
```
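
Beyond the `Agent` example above, here is a minimal sketch of exercising the driver directly, without a structure. It assumes a local Ollama server with `llama3.1` pulled; the `PromptStack` construction shown is illustrative of the framework API, not part of this diff:

```python
from griptape.common import PromptStack
from griptape.drivers import OllamaPromptDriver

driver = OllamaPromptDriver(model="llama3.1")

# Build a one-turn conversation and run it through the driver directly.
prompt_stack = PromptStack()
prompt_stack.add_user_message("What is (192 + 12) ^ 4?")

message = driver.try_run(prompt_stack)
print(message.to_text())
```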

### Hugging Face Hub
158 changes: 137 additions & 21 deletions griptape/drivers/prompt/ollama_prompt_driver.py
```diff
@@ -1,18 +1,22 @@
 from __future__ import annotations

 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional

 from attrs import Factory, define, field

-from griptape.artifacts import TextArtifact
+from griptape.artifacts import ActionArtifact, TextArtifact
 from griptape.common import (
+    ActionCallMessageContent,
+    ActionResultMessageContent,
+    BaseMessageContent,
     DeltaMessage,
     ImageMessageContent,
     Message,
     PromptStack,
     TextDeltaMessageContent,
     TextMessageContent,
+    ToolAction,
     observable,
 )
 from griptape.drivers import BasePromptDriver
@@ -23,6 +27,7 @@
     from ollama import Client

     from griptape.tokenizers.base_tokenizer import BaseTokenizer
+    from griptape.tools import BaseTool


 @define
@@ -61,14 +66,15 @@
         ),
         kw_only=True,
     )
+    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})

     @observable
     def try_run(self, prompt_stack: PromptStack) -> Message:
         response = self.client.chat(**self._base_params(prompt_stack))

         if isinstance(response, dict):
             return Message(
-                content=[TextMessageContent(TextArtifact(value=response["message"]["content"]))],
+                content=self.__to_prompt_stack_message_content(response),
                 role=Message.ASSISTANT_ROLE,
             )
         else:
```
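
`use_native_tools` defaults to `True`, so tool-equipped structures pick up native calling automatically. To opt a driver out of it (for example, for a model without tool support), the flag can be disabled per driver; a sketch:

```python
driver = OllamaPromptDriver(model="llama3.1", use_native_tools=False)
```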
```diff
@@ -87,24 +93,134 @@
     def _base_params(self, prompt_stack: PromptStack) -> dict:
         messages = self._prompt_stack_to_messages(prompt_stack)

-        return {"messages": messages, "model": self.model, "options": self.options}
+        return {
+            "messages": messages,
+            "model": self.model,
+            "options": self.options,
+            **(
+                {"tools": self.__to_ollama_tools(prompt_stack.tools)}
+                if prompt_stack.tools
+                and self.use_native_tools
+                and not self.stream  # Tool calling is only supported when not streaming
+                else {}
+            ),
+        }
```
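
When `prompt_stack.tools` is non-empty, `use_native_tools` is on, and `stream` is off, the dict handed to `client.chat(**params)` gains a `tools` key; roughly (values illustrative, not taken from the diff):

```python
# Hypothetical payload for a single user turn with one tool attached:
params = {
    "messages": [{"role": "user", "content": "What is (192 + 12) ^ 4"}],
    "model": "llama3.1",
    "options": {},   # model options configured on the driver
    "tools": [...],  # one spec per tool activity; see __to_ollama_tools below
}
```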

```diff
     def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
-        return [
-            {
-                "role": message.role,
-                "content": message.to_text(),
-                **(
-                    {
-                        "images": [
-                            content.artifact.base64
-                            for content in message.content
-                            if isinstance(content, ImageMessageContent)
-                        ],
-                    }
-                    if any(isinstance(content, ImageMessageContent) for content in message.content)
-                    else {}
-                ),
-            }
-            for message in prompt_stack.messages
-        ]
+        ollama_messages = []
+        for message in prompt_stack.messages:
+            action_result_contents = message.get_content_type(ActionResultMessageContent)
+
+            # Function calls need to be handled separately from the rest of the message content
+            if action_result_contents:
+                ollama_messages.extend(
+                    [
+                        {
+                            "role": self.__to_ollama_role(message, action_result_content),
+                            "content": self.__to_ollama_message_content(action_result_content),
+                        }
+                        for action_result_content in action_result_contents
+                    ],
+                )
+
+                text_contents = message.get_content_type(TextMessageContent)
+                if text_contents:
+                    ollama_messages.append({"role": self.__to_ollama_role(message), "content": message.to_text()})
+            else:
+                ollama_message: dict[str, Any] = {
+                    "role": self.__to_ollama_role(message),
+                    "content": message.to_text(),
+                }
+
+                action_call_contents = message.get_content_type(ActionCallMessageContent)
+                if action_call_contents:
+                    ollama_message["tool_calls"] = [
+                        self.__to_ollama_message_content(action_call_content)
+                        for action_call_content in action_call_contents
+                    ]
+
+                image_contents = message.get_content_type(ImageMessageContent)
+                if image_contents:
+                    ollama_message["images"] = [
+                        self.__to_ollama_message_content(image_content) for image_content in image_contents
+                    ]
+
+                ollama_messages.append(ollama_message)
+
+        return ollama_messages
```
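
The upshot: each action result becomes its own `tool`-role message, while calls and images ride along on the message that produced them. A hypothetical converted exchange (tag, arguments, and result values are illustrative):

```python
ollama_messages = [
    {"role": "user", "content": "What is (192 + 12) ^ 4"},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "type": "function",
                "id": "call-1",  # hypothetical ToolAction tag
                "function": {
                    "name": "Calculator_calculate",
                    "arguments": {"values": {"expression": "(192 + 12) ** 4"}},
                },
            }
        ],
    },
    {"role": "tool", "content": "1731891456"},
]
```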

```diff
+    def __to_ollama_message_content(self, content: BaseMessageContent) -> str | dict:
+        if isinstance(content, TextMessageContent):
+            return content.artifact.to_text()
+        elif isinstance(content, ImageMessageContent):
+            return content.artifact.base64
+        elif isinstance(content, ActionCallMessageContent):
+            action = content.artifact.value
+
+            return {
+                "type": "function",
+                "id": action.tag,
+                "function": {"name": action.to_native_tool_name(), "arguments": action.input},
+            }
+        elif isinstance(content, ActionResultMessageContent):
+            return content.artifact.to_text()
+        else:
+            raise ValueError(f"Unsupported content type: {type(content)}")
+
+    def __to_ollama_tools(self, tools: list[BaseTool]) -> list[dict]:
+        ollama_tools = []
+
+        for tool in tools:
+            for activity in tool.activities():
+                ollama_tool = {
+                    "function": {
+                        "name": tool.to_native_tool_name(activity),
+                        "description": tool.activity_description(activity),
+                    },
+                    "type": "function",
+                }
+
+                activity_schema = tool.activity_schema(activity)
+                if activity_schema is not None:
+                    ollama_tool["function"]["parameters"] = activity_schema.json_schema("Parameters Schema")[
+                        "properties"
+                    ]["values"]
+
+                ollama_tools.append(ollama_tool)
+        return ollama_tools
```
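
Note the `["properties"]["values"]` indexing: griptape activity schemas nest their parameters under a top-level `values` property, so this unwraps just that inner object for Ollama. One resulting entry would look roughly like this (the Calculator names, description, and fields here are assumptions for illustration):

```python
{
    "type": "function",
    "function": {
        "name": "Calculator_calculate",
        "description": "Can be used for making simple calculations",
        # Inner "values" schema from the activity, unwrapped as above.
        "parameters": {
            "type": "object",
            "properties": {"expression": {"type": "string"}},
            "required": ["expression"],
        },
    },
}
```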

```diff
+    def __to_ollama_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
+        if message.is_system():
+            return "system"
+        elif message.is_assistant():
+            return "assistant"
+        else:
+            if isinstance(message_content, ActionResultMessageContent):
+                return "tool"
+            else:
+                return "user"
+
+    def __to_prompt_stack_message_content(self, response: dict) -> list[BaseMessageContent]:
+        content = []
+        message = response["message"]
+
+        if "content" in message and message["content"]:
+            content.append(TextMessageContent(TextArtifact(response["message"]["content"])))
+        if "tool_calls" in message:
+            content.extend(
+                [
+                    ActionCallMessageContent(
+                        ActionArtifact(
+                            ToolAction(
+                                tag=tool_call["function"]["name"],
+                                name=ToolAction.from_native_tool_name(tool_call["function"]["name"])[0],
+                                path=ToolAction.from_native_tool_name(tool_call["function"]["name"])[1],
+                                input=tool_call["function"]["arguments"],
+                            ),
+                        ),
+                    )
+                    for tool_call in message["tool_calls"]
+                ],
+            )
+
+        return content
```
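
Going the other way, a response carrying tool calls is unpacked into `ActionCallMessageContent`s, with `ToolAction.from_native_tool_name` splitting the `Tool_activity` name back into its name and path parts. A sketch against a hypothetical response:

```python
response = {
    "message": {
        "role": "assistant",
        "content": "",  # falsy, so no TextMessageContent is emitted
        "tool_calls": [
            {
                "function": {
                    "name": "Calculator_calculate",
                    "arguments": {"values": {"expression": "(192 + 12) ** 4"}},
                }
            }
        ],
    }
}
# __to_prompt_stack_message_content(response) would yield a single
# ActionCallMessageContent wrapping:
#   ToolAction(tag="Calculator_calculate", name="Calculator",
#              path="calculate", input={"values": {...}})
```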
12 changes: 8 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
```diff
@@ -51,7 +51,7 @@ voyageai = {version = "^0.2.1", optional = true}
 elevenlabs = {version = "^1.1.2", optional = true}
 qdrant-client = { version = "^1.10.1", optional = true }
 pusher = {version = "^3.3.2", optional = true}
-ollama = {version = "^0.2.1", optional = true}
+ollama = {version = "^0.3.0", optional = true}
 duckduckgo-search = {version = "^6.1.12", optional = true}
 sqlalchemy = {version = "^2.0.31", optional = true}
 opentelemetry-sdk = {version = "^1.25.0", optional = true}
```