Add native function calling support to Ollama Prompt Driver
collindutter committed Jul 29, 2024
1 parent 9f9ac91 commit eef20a8
Showing 6 changed files with 311 additions and 43 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -6,7 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

## Unreleased
### Added
-- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`.
+- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, `OllamaPromptDriver`, and `CoherePromptDriver`.
- `OllamaEmbeddingDriver` for generating embeddings with Ollama.
- `GriptapeCloudKnowledgeBaseVectorStoreDriver` to query Griptape Cloud Knowledge Bases.
- `GriptapeCloudEventListenerDriver.api_key` defaults to the value in the `GT_CLOUD_API_KEY` environment variable.
7 changes: 5 additions & 2 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -282,21 +282,24 @@ agent.run(
This driver requires the `drivers-prompt-ollama` [extra](../index.md#extras).

The [OllamaPromptDriver](../../reference/griptape/drivers/prompt/ollama_prompt_driver.md) connects to the [Ollama Chat Completion API](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion).
+This driver uses [Ollama tool calling](https://ollama.com/blog/tool-support) when using [Tools](../tools/index.md).

```python
from griptape.config import StructureConfig
from griptape.drivers import OllamaPromptDriver
from griptape.tools import Calculator
from griptape.structures import Agent


agent = Agent(
    config=StructureConfig(
        prompt_driver=OllamaPromptDriver(
-            model="llama3",
+            model="llama3.1",
        ),
    ),
    tools=[Calculator()],
)
-agent.run("What color is the sky at different times of the day?")
+agent.run("What is (192 + 12) ^ 4")
```
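
Under the hood, the driver forwards tool definitions through the Ollama chat API's `tools` parameter. A minimal sketch of the equivalent raw call with the `ollama` Python client, assuming a locally running Ollama server; the `calculate` function schema below is illustrative, not the exact schema the driver generates:

```python
# Sketch of the raw tool-calling request the driver translates to.
# Assumes a local Ollama server; the `calculate` tool schema is illustrative.
from ollama import Client

client = Client()
response = client.chat(
    model="llama3.1",
    messages=[{"role": "user", "content": "What is (192 + 12) ^ 4"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "calculate",
                "description": "Evaluate an arithmetic expression",
                "parameters": {
                    "type": "object",
                    "properties": {"expression": {"type": "string"}},
                    "required": ["expression"],
                },
            },
        }
    ],
)
# When the model elects to call the function, the reply's message carries
# a `tool_calls` list instead of (or alongside) plain text content.
print(response["message"].get("tool_calls"))
```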

### Hugging Face Hub
156 changes: 135 additions & 21 deletions griptape/drivers/prompt/ollama_prompt_driver.py
@@ -1,18 +1,22 @@
from __future__ import annotations

from collections.abc import Iterator
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional

from attrs import Factory, define, field

-from griptape.artifacts import TextArtifact
+from griptape.artifacts import ActionArtifact, TextArtifact
from griptape.common import (
+    ActionCallMessageContent,
+    ActionResultMessageContent,
+    BaseMessageContent,
    DeltaMessage,
    ImageMessageContent,
    Message,
    PromptStack,
    TextDeltaMessageContent,
    TextMessageContent,
+    ToolAction,
    observable,
)
from griptape.drivers import BasePromptDriver
@@ -23,6 +27,7 @@
    from ollama import Client

    from griptape.tokenizers.base_tokenizer import BaseTokenizer
+    from griptape.tools import BaseTool


@define
@@ -61,14 +66,15 @@ class OllamaPromptDriver(BasePromptDriver):
        ),
        kw_only=True,
    )
+    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        response = self.client.chat(**self._base_params(prompt_stack))

        if isinstance(response, dict):
            return Message(
-                content=[TextMessageContent(TextArtifact(value=response["message"]["content"]))],
+                content=self.__to_prompt_stack_message_content(response),
                role=Message.ASSISTANT_ROLE,
            )
        else:
@@ -87,24 +93,132 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    def _base_params(self, prompt_stack: PromptStack) -> dict:
        messages = self._prompt_stack_to_messages(prompt_stack)

-        return {"messages": messages, "model": self.model, "options": self.options}
+        return {
+            "messages": messages,
+            "model": self.model,
+            "options": self.options,
+            **(
+                {"tools": self.__to_ollama_tools(prompt_stack.tools)}
+                if prompt_stack.tools and self.use_native_tools
+                else {}
+            ),
+        }

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
-        return [
-            {
-                "role": message.role,
-                "content": message.to_text(),
-                **(
-                    {
-                        "images": [
-                            content.artifact.base64
-                            for content in message.content
-                            if isinstance(content, ImageMessageContent)
-                        ],
-                    }
-                    if any(isinstance(content, ImageMessageContent) for content in message.content)
-                    else {}
-                ),
-            }
-            for message in prompt_stack.messages
-        ]
+        ollama_messages = []
+        for message in prompt_stack.messages:
+            action_result_contents = message.get_content_type(ActionResultMessageContent)
+
+            # Function calls need to be handled separately from the rest of the message content
+            if action_result_contents:
+                ollama_messages.extend(
+                    [
+                        {
+                            "role": self.__to_ollama_role(message, action_result_content),
+                            "content": self.__to_ollama_message_content(action_result_content),
+                        }
+                        for action_result_content in action_result_contents
+                    ],
+                )
+
+                text_contents = message.get_content_type(TextMessageContent)
+                if text_contents:
+                    ollama_messages.append({"role": self.__to_ollama_role(message), "content": message.to_text()})
+            else:
+                ollama_message: dict[str, Any] = {
+                    "role": self.__to_ollama_role(message),
+                    "content": message.to_text(),
+                }
+
+                action_call_contents = message.get_content_type(ActionCallMessageContent)
+                if action_call_contents:
+                    ollama_message["tool_calls"] = [
+                        self.__to_ollama_message_content(action_call_content)
+                        for action_call_content in action_call_contents
+                    ]
+
+                image_contents = message.get_content_type(ImageMessageContent)
+                if image_contents:
+                    ollama_message["images"] = [
+                        self.__to_ollama_message_content(image_content) for image_content in image_contents
+                    ]
+
+                ollama_messages.append(ollama_message)
+
+        return ollama_messages
+
+    def __to_ollama_message_content(self, content: BaseMessageContent) -> str | dict:
+        if isinstance(content, TextMessageContent):
+            return content.artifact.to_text()
+        elif isinstance(content, ImageMessageContent):
+            return content.artifact.base64
+        elif isinstance(content, ActionCallMessageContent):
+            action = content.artifact.value
+
+            return {
+                "type": "function",
+                "id": action.tag,
+                "function": {"name": action.to_native_tool_name(), "arguments": action.input},
+            }
+        elif isinstance(content, ActionResultMessageContent):
+            return content.artifact.to_text()
+        else:
+            raise ValueError(f"Unsupported content type: {type(content)}")
+
+    def __to_ollama_tools(self, tools: list[BaseTool]) -> list[dict]:
+        ollama_tools = []
+
+        for tool in tools:
+            for activity in tool.activities():
+                ollama_tool = {
+                    "function": {
+                        "name": tool.to_native_tool_name(activity),
+                        "description": tool.activity_description(activity),
+                    },
+                    "type": "function",
+                }
+
+                activity_schema = tool.activity_schema(activity)
+                if activity_schema is not None:
+                    ollama_tool["function"]["parameters"] = activity_schema.json_schema("Parameters Schema")[
+                        "properties"
+                    ]["values"]
+
+                ollama_tools.append(ollama_tool)
+        return ollama_tools
+
+    def __to_ollama_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
+        if message.is_system():
+            return "system"
+        elif message.is_assistant():
+            return "assistant"
+        else:
+            if isinstance(message_content, ActionResultMessageContent):
+                return "tool"
+            else:
+                return "user"
+
+    def __to_prompt_stack_message_content(self, response: dict) -> list[BaseMessageContent]:
+        content = []
+        message = response["message"]
+
+        if "content" in message and message["content"]:
+            content.append(TextMessageContent(TextArtifact(response["message"]["content"])))
+        if "tool_calls" in message:
+            content.extend(
+                [
+                    ActionCallMessageContent(
+                        ActionArtifact(
+                            ToolAction(
+                                tag=tool_call["function"]["name"],
+                                name=ToolAction.from_native_tool_name(tool_call["function"]["name"])[0],
+                                path=ToolAction.from_native_tool_name(tool_call["function"]["name"])[1],
+                                input=tool_call["function"]["arguments"],
+                            ),
+                        ),
+                    )
+                    for tool_call in message["tool_calls"]
+                ],
+            )
+
+        return content
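
For reference, a hand-written sketch of the request shape `_base_params` now assembles when the prompt stack carries tools and `use_native_tools` is true, plus the response `tool_call` shape that `__to_prompt_stack_message_content` unpacks. The `Calculator_calculate` native name, description, and parameters schema are illustrative assumptions, not verbatim driver output:

```python
# Illustrative shapes only: the payload _base_params builds, and the
# tool_call dict that __to_prompt_stack_message_content consumes.
# The "<tool>_<activity>" native-name format is an assumption.
request_params = {
    "messages": [{"role": "user", "content": "What is (192 + 12) ^ 4"}],
    "model": "llama3.1",
    "options": {},
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "Calculator_calculate",  # assumed native-name format
                "description": "Evaluate a mathematical expression",  # illustrative
                "parameters": {
                    "type": "object",
                    "properties": {"expression": {"type": "string"}},
                    "required": ["expression"],
                },
            },
        }
    ],
}

# A matching response tool call: the driver turns this into a ToolAction whose
# tag/name/path come from the function name and whose input comes from arguments.
tool_call = {
    "function": {
        "name": "Calculator_calculate",
        "arguments": {"expression": "(192 + 12) ** 4"},
    }
}
```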
12 changes: 8 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -51,7 +51,7 @@ voyageai = {version = "^0.2.1", optional = true}
elevenlabs = {version = "^1.1.2", optional = true}
qdrant-client = { version = "^1.10.1", optional = true }
pusher = {version = "^3.3.2", optional = true}
-ollama = {version = "^0.2.1", optional = true}
+ollama = {version = "^0.3.0", optional = true}
duckduckgo-search = {version = "^6.1.12", optional = true}
sqlalchemy = {version = "^2.0.31", optional = true}
opentelemetry-sdk = {version = "^1.25.0", optional = true}
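
The bump matters because the `tools` parameter on `chat` first shipped in the `ollama` Python client at 0.3.0; a quick sanity check that an environment picked up a compatible client:

```python
# Confirm the installed ollama client is new enough for native tool calling
# (the chat(tools=...) parameter first shipped in ollama-python 0.3.0).
import importlib.metadata

print(importlib.metadata.version("ollama"))  # expect 0.3.0 or newer
```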