Skip to content

Commit

Permalink
Ollama
Browse files Browse the repository at this point in the history
* plus added the `options` parameter to the ollama `chat` call
  • Loading branch information
leila-messallem committed Dec 13, 2024
1 parent 6288907 commit a362fd3
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 34 deletions.
58 changes: 38 additions & 20 deletions src/neo4j_graphrag/llm/ollama_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, Iterable, Optional

from pydantic import ValidationError

from neo4j_graphrag.exceptions import LLMGenerationError

from .base import LLMInterface
from .types import LLMResponse
from .types import LLMResponse, SystemMessage, UserMessage, MessageList

try:
import ollama
from ollama import Message
except ImportError:
ollama = None


class OllamaLLM(LLMInterface):
def __init__(
self,
model_name: str,
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
**kwargs: Any,
):
try:
import ollama
except ImportError:
if ollama is None:
raise ImportError(
"Could not import ollama Python client. "
"Please install it with `pip install ollama`."
)
super().__init__(model_name, model_params, **kwargs)
super().__init__(model_name, model_params, system_instruction, **kwargs)
self.ollama = ollama
self.client = ollama.Client(
**kwargs,
Expand All @@ -43,32 +50,43 @@ def __init__(
**kwargs,
)

def invoke(self, input: str) -> LLMResponse:
def get_messages(
self, input: str, chat_history: Optional[list[Any]] = None
) -> Iterable[Message]:
messages = []
if self.system_instruction:
messages.append(SystemMessage(content=self.system_instruction).model_dump())
if chat_history:
try:
MessageList(messages=chat_history)
except ValidationError as e:
raise LLMGenerationError(e.errors()) from e
messages.extend(chat_history)
messages.append(UserMessage(content=input).model_dump())
return messages

def invoke(
self, input: str, chat_history: Optional[list[Any]] = None
) -> LLMResponse:
try:
response = self.client.chat(
model=self.model_name,
messages=[
{
"role": "user",
"content": input,
},
],
messages=self.get_messages(input, chat_history),
options=self.model_params,
)
content = response.message.content or ""
return LLMResponse(content=content)
except self.ollama.ResponseError as e:
raise LLMGenerationError(e)

async def ainvoke(self, input: str) -> LLMResponse:
async def ainvoke(
self, input: str, chat_history: Optional[list[Any]] = None
) -> LLMResponse:
try:
response = await self.async_client.chat(
model=self.model_name,
messages=[
{
"role": "user",
"content": input,
},
],
messages=self.get_messages(input, chat_history),
options=self.model_params,
)
content = response.message.content or ""
return LLMResponse(content=content)
Expand Down
117 changes: 103 additions & 14 deletions tests/unit/llm/test_ollama_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,124 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from unittest.mock import MagicMock, Mock, patch

import ollama
import pytest
from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm import LLMResponse
from neo4j_graphrag.llm.ollama_llm import OllamaLLM


def get_mock_ollama() -> MagicMock:
mock = MagicMock()
mock.ResponseError = ollama.ResponseError
return mock


@patch("builtins.__import__", side_effect=ImportError)
def test_ollama_llm_missing_dependency(mock_import: Mock) -> None:
@patch("neo4j_graphrag.llm.ollama_llm.ollama", None)
def test_ollama_llm_missing_dependency() -> None:
with pytest.raises(ImportError):
OllamaLLM(model_name="gpt-4o")


@patch("builtins.__import__")
def test_ollama_llm_happy_path(mock_import: Mock) -> None:
mock_ollama = get_mock_ollama()
mock_import.return_value = mock_ollama
@patch("neo4j_graphrag.llm.ollama_llm.ollama")
def test_ollama_llm_happy_path(mock_ollama: Mock) -> None:
mock_ollama.Client.return_value.chat.return_value = MagicMock(
message=MagicMock(content="ollama chat response"),
)
model = "gpt"
model_params = {"temperature": 0.3}
system_instruction = "You are a helpful assistant."
question = "What is graph RAG?"
llm = OllamaLLM(
model,
model_params=model_params,
system_instruction=system_instruction,
)

res = llm.invoke(question)
assert isinstance(res, LLMResponse)
assert res.content == "ollama chat response"
messages = [
{"role": "system", "content": system_instruction},
{"role": "user", "content": question},
]
llm.client.chat.assert_called_once_with(
model=model, messages=messages, options=model_params
)


@patch("neo4j_graphrag.llm.ollama_llm.ollama")
def test_ollama_invoke_with_chat_history_happy_path(mock_ollama: Mock) -> None:
mock_ollama.Client.return_value.chat.return_value = MagicMock(
message=MagicMock(content="ollama chat response"),
)
model = "gpt"
model_params = {"temperature": 0.3}
system_instruction = "You are a helpful assistant."
llm = OllamaLLM(
model,
model_params=model_params,
system_instruction=system_instruction,
)
chat_history = [
{"role": "user", "content": "When does the sun come up in the summer?"},
{"role": "assistant", "content": "Usually around 6am."},
]
question = "What about next season?"

response = llm.invoke(question, chat_history)
assert response.content == "ollama chat response"
messages = [{"role": "system", "content": system_instruction}]
messages.extend(chat_history)
messages.append({"role": "user", "content": question})
llm.client.chat.assert_called_once_with(
model=model, messages=messages, options=model_params
)


@patch("neo4j_graphrag.llm.ollama_llm.ollama")
def test_ollama_invoke_with_chat_history_validation_error(
mock_ollama: Mock,
) -> None:
mock_ollama.Client.return_value.chat.return_value = MagicMock(
message=MagicMock(content="ollama chat response"),
)
llm = OllamaLLM(model_name="gpt")
mock_ollama.ResponseError = ollama.ResponseError
model = "gpt"
model_params = {"temperature": 0.3}
system_instruction = "You are a helpful assistant."
llm = OllamaLLM(
model,
model_params=model_params,
system_instruction=system_instruction,
)
chat_history = [
{"role": "human", "content": "When does the sun come up in the summer?"},
{"role": "assistant", "content": "Usually around 6am."},
]
question = "What about next season?"

with pytest.raises(LLMGenerationError) as exc_info:
llm.invoke(question, chat_history)
assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value)


@pytest.mark.asyncio
@patch("neo4j_graphrag.llm.ollama_llm.ollama")
async def test_ollama_ainvoke_happy_path(mock_ollama: Mock) -> None:
async def mock_chat_async(*args: Any, **kwargs: Any) -> MagicMock:
return MagicMock(
message=MagicMock(content="ollama chat response"),
)

mock_ollama.AsyncClient.return_value.chat = mock_chat_async
model = "gpt"
model_params = {"temperature": 0.3}
system_instruction = "You are a helpful assistant."
question = "What is graph RAG?"
llm = OllamaLLM(
model,
model_params=model_params,
system_instruction=system_instruction,
)

res = llm.invoke("my text")
res = await llm.ainvoke(question)
assert isinstance(res, LLMResponse)
assert res.content == "ollama chat response"

0 comments on commit a362fd3

Please sign in to comment.