fix: update OllamaChatGenerator with keep_alive parameter (#119)
* update ollama chat generator

* pin

* fix
anakin87 authored Oct 16, 2024
1 parent 259a213 commit 52d7567
Showing 3 changed files with 26 additions and 3 deletions.
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Callable, Dict, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 from haystack import component, default_from_dict
 from haystack.dataclasses import StreamingChunk
@@ -96,6 +96,7 @@ def __init__(
         url: str = "http://localhost:11434",
         generation_kwargs: Optional[Dict[str, Any]] = None,
         timeout: int = 120,
+        keep_alive: Optional[Union[float, str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         tools: Optional[List[Tool]] = None,
     ):
@@ -112,6 +113,14 @@ def __init__(
             [Ollama docs](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
         :param timeout:
             The number of seconds before throwing a timeout error from the Ollama API.
+        :param keep_alive:
+            Controls how long the model stays loaded in memory after the request.
+            If not set, the Ollama default (5 minutes) is used.
+            The value can be:
+            - a duration string (such as "10m" or "24h")
+            - a number of seconds (such as 3600)
+            - any negative number, which keeps the model loaded in memory (e.g. -1 or "-1m")
+            - '0', which unloads the model immediately after the response is generated
         :param streaming_callback:
             A callback function that is called when a new token is received from the stream.
             The callback function accepts StreamingChunk as an argument.
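
For illustration, a minimal sketch of the new parameter in use (assuming OllamaChatGenerator is imported from this package; the model name and URL are placeholders):

    generator = OllamaChatGenerator(
        model="llama2",
        url="http://localhost:11434",
        keep_alive="10m",  # keep the model loaded for 10 minutes after each request
    )
    # Equivalent numeric forms: keep_alive=3600 (one hour), keep_alive=-1
    # (keep loaded indefinitely), keep_alive=0 (unload right after the response).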
@@ -134,6 +143,7 @@ def __init__(
             url=url,
             generation_kwargs=generation_kwargs,
             timeout=timeout,
+            keep_alive=keep_alive,
             streaming_callback=streaming_callback,
         )
 
@@ -238,7 +248,12 @@ def run(
 
         ollama_messages = [_convert_message_to_ollama_format(msg) for msg in messages]
         response = self._client.chat(
-            model=self.model, messages=ollama_messages, tools=ollama_tools, stream=stream, options=generation_kwargs
+            model=self.model,
+            messages=ollama_messages,
+            tools=ollama_tools,
+            stream=stream,
+            keep_alive=self.keep_alive,
+            options=generation_kwargs,
         )
 
         if stream:
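
The stored value is forwarded on every chat call. As a hedged sketch, the roughly equivalent direct call with the ollama Python client (host and message are placeholders, and the keep_alive semantics are assumed to match the docstring above):

    import ollama

    client = ollama.Client(host="http://localhost:11434")
    response = client.chat(
        model="llama2",
        messages=[{"role": "user", "content": "Hello!"}],
        keep_alive="10m",  # same value the generator would pass through
    )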
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -59,7 +59,7 @@ extra-dependencies = [
   "fastapi",
   # Tools support
   "jsonschema",
-  "ollama-haystack>=1.0.0",
+  "ollama-haystack>=1.1.0",
   # LLMMetadataExtractor dependencies
   "amazon-bedrock-haystack>=1.0.2",
   "google-vertex-haystack>=2.0.0",
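To pick up keep_alive support in an existing environment, the updated pin above suggests an upgrade along the lines of:

    pip install --upgrade "ollama-haystack>=1.1.0"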
8 changes: 8 additions & 0 deletions test/components/generators/ollama/test_chat_generator.py
@@ -105,13 +105,15 @@ def test_init_default(self):
         assert component.timeout == 120
         assert component.streaming_callback is None
         assert component.tools is None
+        assert component.keep_alive is None
 
     def test_init(self, tools):
         component = OllamaChatGenerator(
             model="llama2",
             url="http://my-custom-endpoint:11434",
             generation_kwargs={"temperature": 0.5},
             timeout=5,
+            keep_alive="10m",
             streaming_callback=print_streaming_chunk,
             tools=tools,
         )
@@ -120,6 +122,7 @@ def test_init(self, tools):
         assert component.url == "http://my-custom-endpoint:11434"
         assert component.generation_kwargs == {"temperature": 0.5}
         assert component.timeout == 5
+        assert component.keep_alive == "10m"
         assert component.streaming_callback is print_streaming_chunk
         assert component.tools == tools
 
@@ -143,6 +146,7 @@ def test_to_dict(self):
             url="custom_url",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
             tools=[tool],
+            keep_alive="5m",
         )
         data = component.to_dict()
         assert data == {
@@ -152,6 +156,7 @@ def test_to_dict(self):
                 "model": "llama2",
                 "url": "custom_url",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
+                "keep_alive": "5m",
                 "generation_kwargs": {
                     "max_tokens": 10,
                     "some_test_param": "test-params",
@@ -185,6 +190,7 @@ def test_from_dict(self):
                 "timeout": 120,
                 "model": "llama2",
                 "url": "custom_url",
+                "keep_alive": "5m",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                 "generation_kwargs": {
                     "max_tokens": 10,
@@ -208,6 +214,7 @@ def test_from_dict(self):
         assert component.model == "llama2"
         assert component.streaming_callback is print_streaming_chunk
         assert component.url == "custom_url"
+        assert component.keep_alive == "5m"
         assert component.generation_kwargs == {
             "max_tokens": 10,
             "some_test_param": "test-params",
@@ -303,6 +310,7 @@ def test_run(self, mock_client):
             stream=False,
             tools=None,
             options={},
+            keep_alive=None,
         )
 
         assert "replies" in result
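
Taken together with the to_dict/from_dict tests above, keep_alive should survive a serialization round trip; a minimal sketch (assuming from_dict is exposed as a classmethod, as those tests imply):

    data = component.to_dict()
    restored = OllamaChatGenerator.from_dict(data)
    assert restored.keep_alive == component.keep_alive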