diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index 1f3a0bf1e..9502a187e 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -1,7 +1,8 @@ from typing import Any, Callable, Dict, List, Optional -from haystack import component +from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.utils.callable_serialization import deserialize_callable, serialize_callable from ollama import Client @@ -63,6 +64,39 @@ def __init__( self._client = Client(host=self.url, timeout=self.timeout) + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + model=self.model, + url=self.url, + generation_kwargs=self.generation_kwargs, + timeout=self.timeout, + streaming_callback=callback_name, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + return default_from_dict(cls, data) + def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: return {"role": message.role.value, "content": message.content} diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index 79d70675a..a46758df3 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import pytest +from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, ChatRole from ollama._types import ResponseError @@ -39,6 +40,42 @@ def test_init(self): assert component.generation_kwargs == {"temperature": 0.5} assert component.timeout == 5 + def test_to_dict(self): + component = OllamaChatGenerator( + model="llama2", + streaming_callback=print_streaming_chunk, + url="custom_url", + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator", + "init_parameters": { + "timeout": 120, + "model": "llama2", + "url": "custom_url", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + + def test_from_dict(self): + data = { + "type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator", + "init_parameters": { + "timeout": 120, + "model": "llama2", + "url": "custom_url", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + component = OllamaChatGenerator.from_dict(data) + assert component.model == "llama2" + assert component.streaming_callback is print_streaming_chunk + assert component.url == "custom_url" + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + def test_build_message_from_ollama_response(self): model = "some_model"