Fix ollama response parsing (#1406)
collindutter authored Dec 9, 2024
1 parent 3606156 commit b9d09e0
Showing 2 changed files with 6 additions and 19 deletions.
15 changes: 6 additions & 9 deletions griptape/drivers/prompt/ollama_prompt_driver.py
@@ -29,7 +29,7 @@
 logger = logging.getLogger(Defaults.logging_config.logger_name)

 if TYPE_CHECKING:
-    from ollama import Client
+    from ollama import ChatResponse, Client

     from griptape.tokenizers.base_tokenizer import BaseTokenizer
     from griptape.tools import BaseTool
@@ -81,13 +81,10 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
         response = self.client.chat(**params)
         logger.debug(response)

-        if isinstance(response, dict):
-            return Message(
-                content=self.__to_prompt_stack_message_content(response),
-                role=Message.ASSISTANT_ROLE,
-            )
-        else:
-            raise Exception("invalid model response")
+        return Message(
+            content=self.__to_prompt_stack_message_content(response),
+            role=Message.ASSISTANT_ROLE,
+        )

     @observable
     def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
@@ -213,7 +210,7 @@ def __to_ollama_role(self, message: Message, message_content: Optional[BaseMessa
         else:
             return "user"

-    def __to_prompt_stack_message_content(self, response: dict) -> list[BaseMessageContent]:
+    def __to_prompt_stack_message_content(self, response: ChatResponse) -> list[BaseMessageContent]:
         content = []
         message = response["message"]

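Note on the driver change above: newer releases of the ollama Python client (typed responses landed around ollama-python 0.4, shortly before this commit) return a ChatResponse Pydantic model from client.chat() rather than a plain dict, so the isinstance(response, dict) guard would reject perfectly valid responses. A minimal sketch of the behavior the new code relies on; the model name and prompt are illustrative, not from this commit:

    from ollama import Client

    client = Client()  # defaults to the local Ollama server
    response = client.chat(model="llama3.1", messages=[{"role": "user", "content": "hi"}])

    # ChatResponse supports dict-style access alongside attribute access,
    # which is why __to_prompt_stack_message_content can keep using response["message"].
    print(response["message"]["content"])
    print(response.message.content)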
10 changes: 0 additions & 10 deletions tests/unit/drivers/prompt/test_ollama_prompt_driver.py
@@ -230,16 +230,6 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools):
         assert message.value[1].value.path == "test"
         assert message.value[1].value.input == {"foo": "bar"}

-    def test_try_run_bad_response(self, mock_client):
-        # Given
-        prompt_stack = PromptStack()
-        driver = OllamaPromptDriver(model="llama")
-        mock_client.return_value.chat.return_value = "bad-response"
-
-        # When/Then
-        with pytest.raises(Exception, match="invalid model response"):
-            driver.try_run(prompt_stack)
-
     def test_try_stream_run(self, mock_stream_client):
         # Given
         prompt_stack = PromptStack()
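On the removed test: with the invalid-response branch gone from try_run, there is no longer a code path for it to exercise, and the remaining test_try_run covers the typed-response parsing. For orientation, a rough usage sketch of the driver assembled from names that appear in this diff; the import paths and the add_user_message call are assumptions about the surrounding griptape API, and a running local Ollama server is assumed:

    from griptape.common import PromptStack
    from griptape.drivers.prompt.ollama_prompt_driver import OllamaPromptDriver

    driver = OllamaPromptDriver(model="llama")
    prompt_stack = PromptStack()
    prompt_stack.add_user_message("Why is the sky blue?")

    # try_run returns a griptape Message built from the ChatResponse content.
    message = driver.try_run(prompt_stack)
    print(message.to_text())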
