Skip to content

Commit

Permalink
fix: Ollama Chat Generator - add missing to_dict and from_dict methods (#1110)
Browse files Browse the repository at this point in the history

* add missing to_dict/from_dict and tests

* linting
  • Loading branch information
anakin87 authored Sep 26, 2024
1 parent 0f560c2 commit dee3e77
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Any, Callable, Dict, List, Optional

from haystack import component
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable

from ollama import Client

Expand Down Expand Up @@ -63,6 +64,39 @@ def __init__(

self._client = Client(host=self.url, timeout=self.timeout)

def to_dict(self) -> Dict[str, Any]:
    """
    Serializes the component to a dictionary.

    :returns:
        Dictionary with serialized data.
    """
    # The streaming callback is a callable, so it is stored as its dotted
    # import path (or None when no callback was configured).
    if self.streaming_callback is not None:
        callback_name = serialize_callable(self.streaming_callback)
    else:
        callback_name = None
    return default_to_dict(
        self,
        model=self.model,
        url=self.url,
        generation_kwargs=self.generation_kwargs,
        timeout=self.timeout,
        streaming_callback=callback_name,
    )

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator":
    """
    Deserializes the component from a dictionary.

    :param data:
        Dictionary to deserialize from.
    :returns:
        Deserialized component.
    """
    # Resolve the serialized callback path back into a callable before
    # handing the data off to the generic deserializer. NOTE: this writes
    # into data["init_parameters"] in place, matching the upstream
    # haystack deserialization convention.
    callback_path = data.get("init_parameters", {}).get("streaming_callback")
    if callback_path:
        data["init_parameters"]["streaming_callback"] = deserialize_callable(callback_path)
    return default_from_dict(cls, data)

def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]:
return {"role": message.role.value, "content": message.content}

Expand Down
37 changes: 37 additions & 0 deletions integrations/ollama/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import Mock

import pytest
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, ChatRole
from ollama._types import ResponseError

Expand Down Expand Up @@ -39,6 +40,42 @@ def test_init(self):
assert component.generation_kwargs == {"temperature": 0.5}
assert component.timeout == 5

def test_to_dict(self):
    # Serialization must capture the model, url, generation kwargs, the
    # default timeout, and the callback's dotted import path.
    generator = OllamaChatGenerator(
        model="llama2",
        streaming_callback=print_streaming_chunk,
        url="custom_url",
        generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
    )
    expected = {
        "type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
        "init_parameters": {
            "timeout": 120,
            "model": "llama2",
            "url": "custom_url",
            "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
            "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
        },
    }
    assert generator.to_dict() == expected

def test_from_dict(self):
    # Deserialization must restore every init parameter, resolving the
    # serialized callback path back into the actual callable.
    init_parameters = {
        "timeout": 120,
        "model": "llama2",
        "url": "custom_url",
        "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
        "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
    }
    serialized = {
        "type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
        "init_parameters": init_parameters,
    }
    restored = OllamaChatGenerator.from_dict(serialized)
    assert restored.model == "llama2"
    assert restored.streaming_callback is print_streaming_chunk
    assert restored.url == "custom_url"
    assert restored.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

def test_build_message_from_ollama_response(self):
model = "some_model"

Expand Down

0 comments on commit dee3e77

Please sign in to comment.