diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a68e193..816f8a00 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,14 @@ ## Next +### Added +- Support for conversations with message history, including a new `message_history` parameter for LLM interactions. +- Ability to include system instructions and override them for specific invocations. +- Summarization of chat history to enhance query embedding and context handling. + +### Changed +- Updated LLM implementations to handle message history consistently across providers. + ## 1.3.0 ### Added diff --git a/docs/README.md b/docs/README.md index fd8f5ffd..d014faa4 100644 --- a/docs/README.md +++ b/docs/README.md @@ -2,9 +2,11 @@ Building the docs requires Python 3.8.1+ -Ensure the dev dependencies in `pyproject.toml` are installed. +1. Ensure the dev dependencies in `pyproject.toml` are installed. -From the root directory, run the Makefile: +2. Add your changes to the appropriate `.rst` source file in `docs/source` directory. + +3. From the root directory, run the Makefile: ``` make -C docs html diff --git a/docs/source/types.rst b/docs/source/types.rst index 253994ad..adf3c9b6 100644 --- a/docs/source/types.rst +++ b/docs/source/types.rst @@ -28,6 +28,12 @@ LLMResponse .. autoclass:: neo4j_graphrag.llm.types.LLMResponse +LLMMessage +=========== + +.. autoclass:: neo4j_graphrag.llm.types.LLMMessage + + RagResultModel ============== diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py index e035048c..322d8d23 100644 --- a/examples/customize/llms/custom_llm.py +++ b/examples/customize/llms/custom_llm.py @@ -1,21 +1,34 @@ import random import string -from typing import Any +from typing import Any, Optional from neo4j_graphrag.llm import LLMInterface, LLMResponse +from neo4j_graphrag.llm.types import LLMMessage class CustomLLM(LLMInterface): - def __init__(self, model_name: str, **kwargs: Any): + def __init__( + self, model_name: str, system_instruction: Optional[str] = None, **kwargs: Any + ): super().__init__(model_name, **kwargs) - def invoke(self, input: str) -> LLMResponse: + def invoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: content: str = ( self.model_name + ": " + "".join(random.choices(string.ascii_letters, k=30)) ) return LLMResponse(content=content) - async def ainvoke(self, input: str) -> LLMResponse: + async def ainvoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: raise NotImplementedError() diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index 3d65fcb5..48a864e4 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -27,6 +27,7 @@ from neo4j_graphrag.generation.prompts import RagTemplate from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel from neo4j_graphrag.llm import LLMInterface +from neo4j_graphrag.llm.types import LLMMessage from neo4j_graphrag.retrievers.base import Retriever from neo4j_graphrag.types import RetrieverResult @@ -83,6 +84,7 @@ def __init__( def search( self, query_text: str = "", + message_history: Optional[list[LLMMessage]] = None, examples: str = "", retriever_config: Optional[dict[str, Any]] = None, return_context: bool | None = None, @@ -99,14 +101,15 @@ def search( Args: - query_text (str): The user question + query_text (str): The 
user question. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. examples (str): Examples added to the LLM prompt. - retriever_config (Optional[dict]): Parameters passed to the retriever + retriever_config (Optional[dict]): Parameters passed to the retriever. search method; e.g.: top_k - return_context (bool): Whether to append the retriever result to the final result (default: False) + return_context (bool): Whether to append the retriever result to the final result (default: False). Returns: - RagResultModel: The LLM-generated answer + RagResultModel: The LLM-generated answer. """ if return_context is None: @@ -124,9 +127,9 @@ def search( ) except ValidationError as e: raise SearchValidationError(e.errors()) - query_text = validated_data.query_text + query = self.build_query(validated_data.query_text, message_history) retriever_result: RetrieverResult = self.retriever.search( - query_text=query_text, **validated_data.retriever_config + query_text=query, **validated_data.retriever_config ) context = "\n".join(item.content for item in retriever_result.items) prompt = self.prompt_template.format( @@ -134,8 +137,44 @@ def search( ) logger.debug(f"RAG: retriever_result={retriever_result}") logger.debug(f"RAG: prompt={prompt}") - answer = self.llm.invoke(prompt) + answer = self.llm.invoke(prompt, message_history) result: dict[str, Any] = {"answer": answer.content} if return_context: result["retriever_result"] = retriever_result return RagResultModel(**result) + + def build_query( + self, query_text: str, message_history: Optional[list[LLMMessage]] = None + ) -> str: + summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words." + if message_history: + summarization_prompt = self.chat_summary_prompt( + message_history=message_history + ) + summary = self.llm.invoke( + input=summarization_prompt, + system_instruction=summary_system_message, + ).content + return self.conversation_prompt(summary=summary, current_query=query_text) + return query_text + + def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str: + message_list = [ + ": ".join([f"{value}" for _, value in message.items()]) + for message in message_history + ] + history = "\n".join(message_list) + return f""" +Summarize the message history: + +{history} +""" + + def conversation_prompt(self, summary: str, current_query: str) -> str: + return f""" +Message Summary: +{summary} + +Current Query: +{current_query} +""" diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index e8f551cb..04d6555e 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -13,11 +13,22 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Optional +from typing import Any, Iterable, Optional, TYPE_CHECKING, cast + +from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.types import LLMResponse +from neo4j_graphrag.llm.types import ( + BaseMessage, + LLMMessage, + LLMResponse, + MessageList, + UserMessage, +) + +if TYPE_CHECKING: + from anthropic.types.message_param import MessageParam class AnthropicLLM(LLMInterface): @@ -26,6 +37,7 @@ class AnthropicLLM(LLMInterface): Args: model_name (str, optional): Name of the LLM to use. Defaults to "gemini-1.5-flash-001". 
model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None. + system_instruction: Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None. **kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None. Raises: @@ -49,6 +61,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ): try: @@ -58,55 +71,86 @@ def __init__( """Could not import Anthropic Python client. Please install it with `pip install "neo4j-graphrag[anthropic]"`.""" ) - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, system_instruction) self.anthropic = anthropic self.client = anthropic.Anthropic(**kwargs) self.async_client = anthropic.AsyncAnthropic(**kwargs) - def invoke(self, input: str) -> LLMResponse: + def get_messages( + self, input: str, message_history: Optional[list[LLMMessage]] = None + ) -> Iterable[MessageParam]: + messages: list[dict[str, str]] = [] + if message_history: + try: + MessageList(messages=cast(list[BaseMessage], message_history)) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + messages.extend(cast(Iterable[dict[str, Any]], message_history)) + messages.append(UserMessage(content=input).model_dump()) + return messages # type: ignore + + def invoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: LLMResponse: The response from the LLM. """ try: + messages = self.get_messages(input, message_history) + system_message = ( + system_instruction + if system_instruction is not None + else self.system_instruction + ) response = self.client.messages.create( model=self.model_name, - messages=[ - { - "role": "user", - "content": input, - } - ], + system=system_message, # type: ignore + messages=messages, **self.model_params, ) - return LLMResponse(content=response.content) + return LLMResponse(content=response.content) # type: ignore except self.anthropic.APIError as e: raise LLMGenerationError(e) - async def ainvoke(self, input: str) -> LLMResponse: + async def ainvoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: LLMResponse: The response from the LLM. 
""" try: + messages = self.get_messages(input, message_history) + system_message = ( + system_instruction + if system_instruction is not None + else self.system_instruction + ) response = await self.async_client.messages.create( model=self.model_name, - messages=[ - { - "role": "user", - "content": input, - } - ], + system=system_message, # type: ignore + messages=messages, **self.model_params, ) - return LLMResponse(content=response.content) + return LLMResponse(content=response.content) # type: ignore except self.anthropic.APIError as e: raise LLMGenerationError(e) diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 3d98423b..eab3eb4f 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from typing import Any, Optional -from .types import LLMResponse +from .types import LLMMessage, LLMResponse class LLMInterface(ABC): @@ -26,6 +26,7 @@ class LLMInterface(ABC): Args: model_name (str): The name of the language model. model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None. + system_instruction: Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None. **kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None. """ @@ -33,17 +34,26 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ): self.model_name = model_name self.model_params = model_params or {} + self.system_instruction = system_instruction @abstractmethod - def invoke(self, input: str) -> LLMResponse: + def invoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: """Sends a text input to the LLM and retrieves a response. Args: - input (str): Text sent to the LLM + input (str): Text sent to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: LLMResponse: The response from the LLM. @@ -53,11 +63,18 @@ def invoke(self, input: str) -> LLMResponse: """ @abstractmethod - async def ainvoke(self, input: str) -> LLMResponse: + async def ainvoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: """Asynchronously sends a text input to the LLM and retrieves a response. Args: - input (str): Text sent to the LLM + input (str): Text sent to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: LLMResponse: The response from the LLM. diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index aeddafd3..63e54aaa 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -14,11 +14,22 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Iterable, Optional, cast +from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.types import LLMResponse +from neo4j_graphrag.llm.types import ( + BaseMessage, + LLMMessage, + LLMResponse, + MessageList, + SystemMessage, + UserMessage, +) + +if TYPE_CHECKING: + from cohere import ChatMessages class CohereLLM(LLMInterface): @@ -27,6 +38,7 @@ class CohereLLM(LLMInterface): Args: model_name (str, optional): Name of the LLM to use. Defaults to "gemini-1.5-flash-001". model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None. + system_instruction: Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None. **kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None. Raises: @@ -46,9 +58,9 @@ def __init__( self, model_name: str = "", model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ) -> None: - super().__init__(model_name, model_params) try: import cohere except ImportError: @@ -56,49 +68,88 @@ def __init__( """Could not import cohere python client. Please install it with `pip install "neo4j-graphrag[cohere]"`.""" ) - + super().__init__(model_name, model_params, system_instruction) self.cohere = cohere self.cohere_api_error = cohere.core.api_error.ApiError - self.client = cohere.Client(**kwargs) - self.async_client = cohere.AsyncClient(**kwargs) + self.client = cohere.ClientV2(**kwargs) + self.async_client = cohere.AsyncClientV2(**kwargs) - def invoke(self, input: str) -> LLMResponse: + def get_messages( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> ChatMessages: + messages = [] + system_message = ( + system_instruction + if system_instruction is not None + else self.system_instruction + ) + if system_message: + messages.append(SystemMessage(content=system_message).model_dump()) + if message_history: + try: + MessageList(messages=cast(list[BaseMessage], message_history)) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + messages.extend(cast(Iterable[dict[str, Any]], message_history)) + messages.append(UserMessage(content=input).model_dump()) + return messages # type: ignore + + def invoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: LLMResponse: The response from the LLM. 
""" try: + messages = self.get_messages(input, message_history, system_instruction) res = self.client.chat( - message=input, + messages=messages, model=self.model_name, ) except self.cohere_api_error as e: raise LLMGenerationError(e) return LLMResponse( - content=res.text, + content=res.message.content[0].text, ) - async def ainvoke(self, input: str) -> LLMResponse: + async def ainvoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: LLMResponse: The response from the LLM. """ try: - res = await self.async_client.chat( - message=input, + messages = self.get_messages(input, message_history, system_instruction) + res = self.async_client.chat( + messages=messages, model=self.model_name, ) except self.cohere_api_error as e: raise LLMGenerationError(e) return LLMResponse( - content=res.text, + content=res.message.content[0].text, ) diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index b9254ffe..a3c84759 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -15,21 +15,23 @@ from __future__ import annotations import os -from typing import Any, Optional, Union - -from ..exceptions import LLMGenerationError -from .base import LLMInterface -from .types import LLMResponse +from typing import Any, Iterable, Optional, cast +from pydantic import ValidationError + +from neo4j_graphrag.exceptions import LLMGenerationError +from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.llm.types import ( + BaseMessage, + LLMMessage, + LLMResponse, + MessageList, + SystemMessage, + UserMessage, +) try: - from mistralai import Mistral - from mistralai.models.assistantmessage import AssistantMessage + from mistralai import Mistral, Messages from mistralai.models.sdkerror import SDKError - from mistralai.models.systemmessage import SystemMessage - from mistralai.models.toolmessage import ToolMessage - from mistralai.models.usermessage import UserMessage - - MessageType = Union[AssistantMessage, SystemMessage, ToolMessage, UserMessage] except ImportError: Mistral = None # type: ignore SDKError = None # type: ignore @@ -40,6 +42,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ): """ @@ -48,6 +51,7 @@ def __init__( model_name (str): model_params (str): Parameters like temperature and such that will be passed to the chat completions endpoint + system_instruction: Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None. kwargs: All other parameters will be passed to the Mistral client. """ @@ -56,21 +60,48 @@ def __init__( """Could not import Mistral Python client. 
Please install it with `pip install "neo4j-graphrag[mistralai]"`.""" ) - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, system_instruction) api_key = kwargs.pop("api_key", None) if api_key is None: api_key = os.getenv("MISTRAL_API_KEY", "") self.client = Mistral(api_key=api_key, **kwargs) - def get_messages(self, input: str) -> list[MessageType]: - return [UserMessage(content=input)] - - def invoke(self, input: str) -> LLMResponse: + def get_messages( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> list[Messages]: + messages = [] + system_message = ( + system_instruction + if system_instruction is not None + else self.system_instruction + ) + if system_message: + messages.append(SystemMessage(content=system_message).model_dump()) + if message_history: + try: + MessageList(messages=cast(list[BaseMessage], message_history)) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + messages.extend(cast(Iterable[dict[str, Any]], message_history)) + messages.append(UserMessage(content=input).model_dump()) + return cast(list[Messages], messages) + + def invoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: """Sends a text input to the Mistral chat completion model and returns the response's content. Args: - input (str): Text sent to the LLM + input (str): Text sent to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: LLMResponse: The response from MistralAI. @@ -79,9 +110,10 @@ def invoke(self, input: str) -> LLMResponse: LLMGenerationError: If anything goes wrong. """ try: + messages = self.get_messages(input, message_history, system_instruction) response = self.client.chat.complete( model=self.model_name, - messages=self.get_messages(input), + messages=messages, **self.model_params, ) content: str = "" @@ -93,12 +125,19 @@ def invoke(self, input: str) -> LLMResponse: except SDKError as e: raise LLMGenerationError(e) - async def ainvoke(self, input: str) -> LLMResponse: + async def ainvoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: """Asynchronously sends a text input to the MistralAI chat completion model and returns the response's content. Args: - input (str): Text sent to the LLM + input (str): Text sent to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: LLMResponse: The response from MistralAI. @@ -107,9 +146,10 @@ async def ainvoke(self, input: str) -> LLMResponse: LLMGenerationError: If anything goes wrong. 
""" try: + messages = self.get_messages(input, message_history, system_instruction) response = await self.client.chat.complete_async( model=self.model_name, - messages=self.get_messages(input), + messages=messages, **self.model_params, ) content: str = "" diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index d82ccd25..a36d34f9 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -12,12 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from __future__ import annotations +from typing import Any, Iterable, Optional, Sequence, TYPE_CHECKING, cast + +from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from .base import LLMInterface -from .types import LLMResponse +from .types import ( + BaseMessage, + LLMMessage, + LLMResponse, + SystemMessage, + UserMessage, + MessageList, +) + +if TYPE_CHECKING: + from ollama import Message class OllamaLLM(LLMInterface): @@ -25,6 +38,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ): try: @@ -34,7 +48,7 @@ def __init__( "Could not import ollama Python client. " "Please install it with `pip install ollama`." ) - super().__init__(model_name, model_params, **kwargs) + super().__init__(model_name, model_params, system_instruction, **kwargs) self.ollama = ollama self.client = ollama.Client( **kwargs, @@ -43,32 +57,81 @@ def __init__( **kwargs, ) - def invoke(self, input: str) -> LLMResponse: + def get_messages( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> Sequence[Message]: + messages = [] + system_message = ( + system_instruction + if system_instruction is not None + else self.system_instruction + ) + if system_message: + messages.append(SystemMessage(content=system_message).model_dump()) + if message_history: + try: + MessageList(messages=cast(list[BaseMessage], message_history)) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + messages.extend(cast(Iterable[dict[str, Any]], message_history)) + messages.append(UserMessage(content=input).model_dump()) + return messages # type: ignore + + def invoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: + """Sends text to the LLM and returns a response. + + Args: + input (str): The text to send to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. + + Returns: + LLMResponse: The response from the LLM. 
+        """
         try:
             response = self.client.chat(
                 model=self.model_name,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": input,
-                    },
-                ],
+                messages=self.get_messages(input, message_history, system_instruction),
+                options=self.model_params,
             )
             content = response.message.content or ""
             return LLMResponse(content=content)
         except self.ollama.ResponseError as e:
             raise LLMGenerationError(e)
 
-    async def ainvoke(self, input: str) -> LLMResponse:
+    async def ainvoke(
+        self,
+        input: str,
+        message_history: Optional[list[LLMMessage]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
+        """Asynchronously sends a text input to the Ollama chat
+        completion model and returns the response's content.
+
+        Args:
+            input (str): Text sent to the LLM.
+            message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
+
+        Returns:
+            LLMResponse: The response from Ollama.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
         try:
             response = await self.async_client.chat(
                 model=self.model_name,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": input,
-                    },
-                ],
+                messages=self.get_messages(input, message_history, system_instruction),
+                options=self.model_params,
             )
             content = response.message.content or ""
             return LLMResponse(content=content)
diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py
index 9de5071f..b3d99411 100644
--- a/src/neo4j_graphrag/llm/openai_llm.py
+++ b/src/neo4j_graphrag/llm/openai_llm.py
@@ -15,11 +15,20 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Any, Iterable, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
+
+from pydantic import ValidationError
 
 from ..exceptions import LLMGenerationError
 from .base import LLMInterface
-from .types import LLMResponse
+from .types import (
+    BaseMessage,
+    LLMMessage,
+    LLMResponse,
+    SystemMessage,
+    UserMessage,
+    MessageList,
+)
 
 if TYPE_CHECKING:
     import openai
@@ -36,6 +45,7 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
     ):
         """
         Base class for OpenAI LLM.
 
         Args:
             model_name (str):
-            model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
+            model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None.
+            system_instruction (Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None.
""" try: import openai @@ -54,22 +65,44 @@ def __init__( Please install it with `pip install "neo4j-graphrag[openai]"`.""" ) self.openai = openai - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, system_instruction) def get_messages( self, input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, ) -> Iterable[ChatCompletionMessageParam]: - return [ - {"role": "system", "content": input}, - ] - - def invoke(self, input: str) -> LLMResponse: + messages = [] + system_message = ( + system_instruction + if system_instruction is not None + else self.system_instruction + ) + if system_message: + messages.append(SystemMessage(content=system_message).model_dump()) + if message_history: + try: + MessageList(messages=cast(list[BaseMessage], message_history)) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + messages.extend(cast(Iterable[dict[str, Any]], message_history)) + messages.append(UserMessage(content=input).model_dump()) + return messages # type: ignore + + def invoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: """Sends a text input to the OpenAI chat completion model and returns the response's content. Args: - input (str): Text sent to the LLM + input (str): Text sent to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: LLMResponse: The response from OpenAI. @@ -79,7 +112,7 @@ def invoke(self, input: str) -> LLMResponse: """ try: response = self.client.chat.completions.create( - messages=self.get_messages(input), + messages=self.get_messages(input, message_history, system_instruction), model=self.model_name, **self.model_params, ) @@ -88,12 +121,19 @@ def invoke(self, input: str) -> LLMResponse: except self.openai.OpenAIError as e: raise LLMGenerationError(e) - async def ainvoke(self, input: str) -> LLMResponse: + async def ainvoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: """Asynchronously sends a text input to the OpenAI chat completion model and returns the response's content. Args: - input (str): Text sent to the LLM + input (str): Text sent to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: LLMResponse: The response from OpenAI. @@ -103,7 +143,7 @@ async def ainvoke(self, input: str) -> LLMResponse: """ try: response = await self.async_client.chat.completions.create( - messages=self.get_messages(input), + messages=self.get_messages(input, message_history, system_instruction), model=self.model_name, **self.model_params, ) @@ -118,6 +158,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ): """OpenAI LLM @@ -126,10 +167,11 @@ def __init__( Args: model_name (str): - model_params (str): Parameters like temperature that will be passed to the model when text is sent to it + model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. 
+ system_instruction: Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None. kwargs: All other parameters will be passed to the openai.OpenAI init. """ - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, system_instruction) self.client = self.openai.OpenAI(**kwargs) self.async_client = self.openai.AsyncOpenAI(**kwargs) @@ -139,6 +181,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ): """Azure OpenAI LLM. Use this class when using an OpenAI model @@ -146,9 +189,10 @@ def __init__( Args: model_name (str): - model_params (str): Parameters like temperature that will be passed to the model when text is sent to it + model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. + system_instruction: Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None. kwargs: All other parameters will be passed to the openai.OpenAI init. """ - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, system_instruction) self.client = self.openai.AzureOpenAI(**kwargs) self.async_client = self.openai.AsyncAzureOpenAI(**kwargs) diff --git a/src/neo4j_graphrag/llm/types.py b/src/neo4j_graphrag/llm/types.py index f7aab5e9..77e89aef 100644 --- a/src/neo4j_graphrag/llm/types.py +++ b/src/neo4j_graphrag/llm/types.py @@ -1,5 +1,28 @@ from pydantic import BaseModel +from typing import Literal, TypedDict class LLMResponse(BaseModel): content: str + + +class LLMMessage(TypedDict): + role: Literal["system", "user", "assistant"] + content: str + + +class BaseMessage(BaseModel): + role: Literal["user", "assistant", "system"] + content: str + + +class UserMessage(BaseMessage): + role: Literal["user"] = "user" + + +class SystemMessage(BaseMessage): + role: Literal["system"] = "system" + + +class MessageList(BaseModel): + messages: list[BaseMessage] diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index dff5e1cf..48acfc9f 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -13,14 +13,21 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Optional +from typing import Any, Optional, cast + +from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.types import LLMResponse +from neo4j_graphrag.llm.types import BaseMessage, LLMMessage, LLMResponse, MessageList try: - from vertexai.generative_models import GenerativeModel, ResponseValidationError + from vertexai.generative_models import ( + GenerativeModel, + ResponseValidationError, + Part, + Content, + ) except ImportError: GenerativeModel = None ResponseValidationError = None @@ -32,6 +39,7 @@ class VertexAILLM(LLMInterface): Args: model_name (str, optional): Name of the LLM to use. Defaults to "gemini-1.5-flash-001". model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None. + system_instruction: Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None. 
**kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None. Raises: @@ -55,6 +63,7 @@ def __init__( self, model_name: str = "gemini-1.5-flash-001", model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ): if GenerativeModel is None or ResponseValidationError is None: @@ -63,35 +72,100 @@ def __init__( Please install it with `pip install "neo4j-graphrag[google]"`.""" ) super().__init__(model_name, model_params) - self.model = GenerativeModel(model_name=model_name, **kwargs) - - def invoke(self, input: str) -> LLMResponse: + self.model_name = model_name + self.system_instruction = system_instruction + self.options = kwargs + + def get_messages( + self, input: str, message_history: Optional[list[LLMMessage]] = None + ) -> list[Content]: + messages = [] + if message_history: + try: + MessageList(messages=cast(list[BaseMessage], message_history)) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + + for message in message_history: + if message.get("role") == "user": + messages.append( + Content( + role="user", parts=[Part.from_text(message.get("content"))] + ) + ) + elif message.get("role") == "assistant": + messages.append( + Content( + role="model", parts=[Part.from_text(message.get("content"))] + ) + ) + + messages.append(Content(role="user", parts=[Part.from_text(input)])) + return messages + + def invoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: LLMResponse: The response from the LLM. """ + system_message = ( + system_instruction + if system_instruction is not None + else self.system_instruction + ) + self.model = GenerativeModel( + model_name=self.model_name, + system_instruction=[system_message], + **self.options, + ) try: - response = self.model.generate_content(input, **self.model_params) + messages = self.get_messages(input, message_history) + response = self.model.generate_content(messages, **self.model_params) return LLMResponse(content=response.text) except ResponseValidationError as e: raise LLMGenerationError(e) - async def ainvoke(self, input: str) -> LLMResponse: + async def ainvoke( + self, + input: str, + message_history: Optional[list[LLMMessage]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: LLMResponse: The response from the LLM. 
""" try: + system_message = ( + system_instruction + if system_instruction is not None + else self.system_instruction + ) + self.model = GenerativeModel( + model_name=self.model_name, + system_instruction=[system_message], + **self.options, + ) + messages = self.get_messages(input, message_history) response = await self.model.generate_content_async( - input, **self.model_params + messages, **self.model_params ) return LLMResponse(content=response.text) except ResponseValidationError as e: diff --git a/tests/e2e/test_graphrag_e2e.py b/tests/e2e/test_graphrag_e2e.py index 56136d1b..7eaaa50c 100644 --- a/tests/e2e/test_graphrag_e2e.py +++ b/tests/e2e/test_graphrag_e2e.py @@ -72,7 +72,8 @@ def test_graphrag_happy_path( biology Answer: -""" +""", + None, ) assert isinstance(result, RagResultModel) assert result.answer == "some text" @@ -117,7 +118,8 @@ def test_graphrag_happy_path_return_context( biology Answer: -""" +""", + None, ) assert isinstance(result, RagResultModel) assert result.answer == "some text" @@ -163,7 +165,8 @@ def test_graphrag_happy_path_examples( biology Answer: -""" +""", + None, ) assert result.answer == "some text" diff --git a/tests/unit/llm/test_anthropic_llm.py b/tests/unit/llm/test_anthropic_llm.py index c8d5f6f8..fc8d2756 100644 --- a/tests/unit/llm/test_anthropic_llm.py +++ b/tests/unit/llm/test_anthropic_llm.py @@ -18,8 +18,11 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch import anthropic + import pytest +from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM +from neo4j_graphrag.llm.types import LLMResponse @pytest.fixture @@ -49,9 +52,120 @@ def test_anthropic_invoke_happy_path(mock_anthropic: Mock) -> None: llm.client.messages.create.assert_called_once_with( # type: ignore messages=[{"role": "user", "content": input_text}], model="claude-3-opus-20240229", + system=None, + **model_params, + ) + + +def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock) -> None: + mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( + content="generated text" + ) + model_params = {"temperature": 0.3} + llm = AnthropicLLM( + "claude-3-opus-20240229", + model_params=model_params, + ) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + response = llm.invoke(question, message_history) # type: ignore + assert response.content == "generated text" + message_history.append({"role": "user", "content": question}) + llm.client.messages.create.assert_called_once_with( # type: ignore[attr-defined] + messages=message_history, + model="claude-3-opus-20240229", + system=None, + **model_params, + ) + + +def test_anthropic_invoke_with_message_history_and_system_instruction( + mock_anthropic: Mock, +) -> None: + mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( + content="generated text" + ) + model_params = {"temperature": 0.3} + initial_instruction = "You are a helpful assistant." + llm = AnthropicLLM( + "claude-3-opus-20240229", + model_params=model_params, + system_instruction=initial_instruction, + ) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" 
+ + # first invokation - initial instructions + response = llm.invoke(question, message_history) # type: ignore + assert response.content == "generated text" + message_history.append({"role": "user", "content": question}) + llm.client.messages.create.assert_called_once_with( # type: ignore[attr-defined] + model="claude-3-opus-20240229", + system=initial_instruction, + messages=message_history, + **model_params, + ) + + # second invokation - override instructions + override_instruction = "Ignore all previous instructions" + question = "When does it come up in the winter?" + response = llm.invoke(question, message_history, override_instruction) # type: ignore + assert isinstance(response, LLMResponse) + assert response.content == "generated text" + message_history.append({"role": "user", "content": question}) + llm.client.messages.create.assert_called_with( # type: ignore[attr-defined] + model="claude-3-opus-20240229", + system=override_instruction, + messages=message_history, + **model_params, + ) + + # third invokation - default instructions + question = "When does it set?" + response = llm.invoke(question, message_history) # type: ignore + assert isinstance(response, LLMResponse) + assert response.content == "generated text" + message_history.append({"role": "user", "content": question}) + llm.client.messages.create.assert_called_with( # type: ignore[attr-defined] + model="claude-3-opus-20240229", + system=initial_instruction, + messages=message_history, **model_params, ) + assert llm.client.messages.create.call_count == 3 # type: ignore + + +def test_anthropic_invoke_with_message_history_validation_error( + mock_anthropic: Mock, +) -> None: + mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( + content="generated text" + ) + model_params = {"temperature": 0.3} + system_instruction = "You are a helpful assistant." + llm = AnthropicLLM( + "claude-3-opus-20240229", + model_params=model_params, + system_instruction=system_instruction, + ) + message_history = [ + {"role": "human", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" 
+ + with pytest.raises(LLMGenerationError) as exc_info: + llm.invoke(question, message_history) # type: ignore + assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value) + @pytest.mark.asyncio async def test_anthropic_ainvoke_happy_path(mock_anthropic: Mock) -> None: @@ -66,6 +180,7 @@ async def test_anthropic_ainvoke_happy_path(mock_anthropic: Mock) -> None: assert response.content == "Return text" llm.async_client.messages.create.assert_awaited_once_with( # type: ignore model="claude-3-opus-20240229", + system=None, messages=[{"role": "user", "content": input_text}], **model_params, ) diff --git a/tests/unit/llm/test_cohere_llm.py b/tests/unit/llm/test_cohere_llm.py index db1b9db8..6088799a 100644 --- a/tests/unit/llm/test_cohere_llm.py +++ b/tests/unit/llm/test_cohere_llm.py @@ -27,7 +27,6 @@ def mock_cohere() -> Generator[MagicMock, None, None]: mock_cohere = MagicMock() mock_cohere.core.api_error.ApiError = cohere.core.ApiError - with patch.dict(sys.modules, {"cohere": mock_cohere}): yield mock_cohere @@ -39,39 +38,140 @@ def test_cohere_llm_missing_dependency(mock_import: Mock) -> None: def test_cohere_llm_happy_path(mock_cohere: Mock) -> None: - mock_cohere.Client.return_value.chat.return_value = MagicMock( - text="cohere response text" + chat_response_mock = MagicMock() + chat_response_mock.message.content = [MagicMock(text="cohere response text")] + mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock + llm = CohereLLM(model_name="something") + res = llm.invoke("my text") + assert isinstance(res, LLMResponse) + assert res.content == "cohere response text" + + +def test_cohere_llm_invoke_with_message_history_happy_path(mock_cohere: Mock) -> None: + chat_response_mock = MagicMock() + chat_response_mock.message.content = [MagicMock(text="cohere response text")] + mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock + + system_instruction = "You are a helpful assistant." + llm = CohereLLM(model_name="something", system_instruction=system_instruction) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + res = llm.invoke(question, message_history) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "cohere response text" + messages = [{"role": "system", "content": system_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.assert_called_once_with( + messages=messages, + model="something", + ) + + +def test_cohere_llm_invoke_with_message_history_and_system_instruction( + mock_cohere: Mock, +) -> None: + chat_response_mock = MagicMock() + chat_response_mock.message.content = [MagicMock(text="cohere response text")] + mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock + + initial_instruction = "You are a helpful assistant." + llm = CohereLLM(model_name="gpt", system_instruction=initial_instruction) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" 
+ + # first invokation - initial instructions + res = llm.invoke(question, message_history) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "cohere response text" + messages = [{"role": "system", "content": initial_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.assert_called_once_with( + messages=messages, + model="gpt", ) - embedder = CohereLLM(model_name="something") - res = embedder.invoke("my text") + + # second invokation - override instructions + override_instruction = "Ignore all previous instructions" + res = llm.invoke(question, message_history, override_instruction) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "cohere response text" + messages = [{"role": "system", "content": override_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.assert_called_with( + messages=messages, + model="gpt", + ) + + # third invokation - default instructions + res = llm.invoke(question, message_history) # type: ignore assert isinstance(res, LLMResponse) assert res.content == "cohere response text" + messages = [{"role": "system", "content": initial_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.assert_called_with( + messages=messages, + model="gpt", + ) + + assert llm.client.chat.call_count == 3 + + +def test_cohere_llm_invoke_with_message_history_validation_error( + mock_cohere: Mock, +) -> None: + chat_response_mock = MagicMock() + chat_response_mock.message.content = [MagicMock(text="cohere response text")] + mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock + + system_instruction = "You are a helpful assistant." + llm = CohereLLM(model_name="something", system_instruction=system_instruction) + message_history = [ + {"role": "robot", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" 
+ + with pytest.raises(LLMGenerationError) as exc_info: + llm.invoke(question, message_history) # type: ignore + assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) @pytest.mark.asyncio async def test_cohere_llm_happy_path_async(mock_cohere: Mock) -> None: - async_mock = Mock() - async_mock.chat = AsyncMock(return_value=MagicMock(text="cohere response text")) - mock_cohere.AsyncClient.return_value = async_mock - embedder = CohereLLM(model_name="something") - res = await embedder.ainvoke("my text") + chat_response_mock = AsyncMock() + chat_response_mock.message.content = [AsyncMock(text="cohere response text")] + mock_cohere.AsyncClientV2.return_value.chat.return_value = chat_response_mock + + llm = CohereLLM(model_name="something") + res = await llm.ainvoke("my text") assert isinstance(res, LLMResponse) assert res.content == "cohere response text" def test_cohere_llm_failed(mock_cohere: Mock) -> None: - mock_cohere.Client.return_value.chat.side_effect = cohere.core.ApiError - embedder = CohereLLM(model_name="something") + mock_cohere.ClientV2.return_value.chat.side_effect = cohere.core.ApiError + llm = CohereLLM(model_name="something") with pytest.raises(LLMGenerationError) as excinfo: - embedder.invoke("my text") + llm.invoke("my text") assert "ApiError" in str(excinfo) @pytest.mark.asyncio async def test_cohere_llm_failed_async(mock_cohere: Mock) -> None: - mock_cohere.AsyncClient.return_value.chat.side_effect = cohere.core.ApiError - embedder = CohereLLM(model_name="something") + mock_cohere.AsyncClientV2.return_value.chat.side_effect = cohere.core.ApiError + llm = CohereLLM(model_name="something") with pytest.raises(LLMGenerationError) as excinfo: - await embedder.ainvoke("my text") + await llm.ainvoke("my text") assert "ApiError" in str(excinfo) diff --git a/tests/unit/llm/test_mistralai_llm.py b/tests/unit/llm/test_mistralai_llm.py new file mode 100644 index 00000000..4d5c5c96 --- /dev/null +++ b/tests/unit/llm/test_mistralai_llm.py @@ -0,0 +1,211 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any +from unittest.mock import MagicMock, Mock, patch + +import pytest +from mistralai.models.sdkerror import SDKError +from neo4j_graphrag.exceptions import LLMGenerationError +from neo4j_graphrag.llm import LLMResponse, MistralAILLM + + +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral", None) +def test_mistralai_llm_missing_dependency() -> None: + with pytest.raises(ImportError): + MistralAILLM(model_name="mistral-model") + + +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") +def test_mistralai_llm_invoke(mock_mistral: Mock) -> None: + mock_mistral_instance = mock_mistral.return_value + + chat_response_mock = MagicMock() + chat_response_mock.choices = [ + MagicMock(message=MagicMock(content="mistral response")) + ] + + mock_mistral_instance.chat.complete.return_value = chat_response_mock + + llm = MistralAILLM(model_name="mistral-model") + + res = llm.invoke("some input") + + assert isinstance(res, LLMResponse) + assert res.content == "mistral response" + + +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") +def test_mistralai_llm_invoke_with_message_history(mock_mistral: Mock) -> None: + mock_mistral_instance = mock_mistral.return_value + chat_response_mock = MagicMock() + chat_response_mock.choices = [ + MagicMock(message=MagicMock(content="mistral response")) + ] + mock_mistral_instance.chat.complete.return_value = chat_response_mock + model = "mistral-model" + system_instruction = "You are a helpful assistant." + + llm = MistralAILLM(model_name=model, system_instruction=system_instruction) + + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + res = llm.invoke(question, message_history) # type: ignore + + assert isinstance(res, LLMResponse) + assert res.content == "mistral response" + messages = [{"role": "system", "content": system_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.complete.assert_called_once_with( # type: ignore[attr-defined] + messages=messages, + model=model, + ) + + +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") +def test_mistralai_llm_invoke_with_message_history_and_system_instruction( + mock_mistral: Mock, +) -> None: + mock_mistral_instance = mock_mistral.return_value + chat_response_mock = MagicMock() + chat_response_mock.choices = [ + MagicMock(message=MagicMock(content="mistral response")) + ] + mock_mistral_instance.chat.complete.return_value = chat_response_mock + model = "mistral-model" + initial_instruction = "You are a helpful assistant." + llm = MistralAILLM(model_name=model, system_instruction=initial_instruction) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" 
+ + # first invokation - initial instructions + res = llm.invoke(question, message_history) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "mistral response" + messages = [{"role": "system", "content": initial_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.complete.assert_called_once_with( # type: ignore[attr-defined] + messages=messages, + model=model, + ) + + # second invokation - override instructions + override_instruction = "Ignore all previous instructions" + res = llm.invoke(question, message_history, override_instruction) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "mistral response" + messages = [{"role": "system", "content": override_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.complete.assert_called_with( # type: ignore + messages=messages, + model=model, + ) + + # third invokation - default instructions + res = llm.invoke(question, message_history) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "mistral response" + messages = [{"role": "system", "content": initial_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.complete.assert_called_with( # type: ignore + messages=messages, + model=model, + ) + + assert llm.client.chat.complete.call_count == 3 # type: ignore + + +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") +def test_mistralai_llm_invoke_with_message_history_validation_error( + mock_mistral: Mock, +) -> None: + mock_mistral_instance = mock_mistral.return_value + chat_response_mock = MagicMock() + chat_response_mock.choices = [ + MagicMock(message=MagicMock(content="mistral response")) + ] + mock_mistral_instance.chat.complete.return_value = chat_response_mock + model = "mistral-model" + system_instruction = "You are a helpful assistant." + + llm = MistralAILLM(model_name=model, system_instruction=system_instruction) + + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "monkey", "content": "Usually around 6am."}, + ] + question = "What about next season?" 
+ + with pytest.raises(LLMGenerationError) as exc_info: + llm.invoke(question, message_history) # type: ignore + assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) + + +@pytest.mark.asyncio +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") +async def test_mistralai_llm_ainvoke(mock_mistral: Mock) -> None: + mock_mistral_instance = mock_mistral.return_value + + async def mock_complete_async(*args: Any, **kwargs: Any) -> MagicMock: + chat_response_mock = MagicMock() + chat_response_mock.choices = [ + MagicMock(message=MagicMock(content="async mistral response")) + ] + return chat_response_mock + + mock_mistral_instance.chat.complete_async = mock_complete_async + + llm = MistralAILLM(model_name="mistral-model") + + res = await llm.ainvoke("some input") + + assert isinstance(res, LLMResponse) + assert res.content == "async mistral response" + + +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") +def test_mistralai_llm_invoke_sdkerror(mock_mistral: Mock) -> None: + mock_mistral_instance = mock_mistral.return_value + mock_mistral_instance.chat.complete.side_effect = SDKError("Some error") + + llm = MistralAILLM(model_name="mistral-model") + + with pytest.raises(LLMGenerationError): + llm.invoke("some input") + + +@pytest.mark.asyncio +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") +async def test_mistralai_llm_ainvoke_sdkerror(mock_mistral: Mock) -> None: + mock_mistral_instance = mock_mistral.return_value + + async def mock_complete_async(*args: Any, **kwargs: Any) -> None: + raise SDKError("Some async error") + + mock_mistral_instance.chat.complete_async = mock_complete_async + + llm = MistralAILLM(model_name="mistral-model") + + with pytest.raises(LLMGenerationError): + await llm.ainvoke("some input") diff --git a/tests/unit/llm/test_mistralaillm.py b/tests/unit/llm/test_mistralaillm.py deleted file mode 100644 index 0e0f1b6c..00000000 --- a/tests/unit/llm/test_mistralaillm.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any -from unittest.mock import MagicMock, Mock, patch - -import pytest -from mistralai.models.sdkerror import SDKError -from neo4j_graphrag.exceptions import LLMGenerationError -from neo4j_graphrag.llm import LLMResponse, MistralAILLM - - -@patch("neo4j_graphrag.llm.mistralai_llm.Mistral", None) -def test_mistral_ai_llm_missing_dependency() -> None: - with pytest.raises(ImportError): - MistralAILLM(model_name="mistral-model") - - -@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") -def test_mistral_ai_llm_invoke(mock_mistral: Mock) -> None: - mock_mistral_instance = mock_mistral.return_value - - chat_response_mock = MagicMock() - chat_response_mock.choices = [ - MagicMock(message=MagicMock(content="mistral response")) - ] - - mock_mistral_instance.chat.complete.return_value = chat_response_mock - - llm = MistralAILLM(model_name="mistral-model") - - res = llm.invoke("some input") - - assert isinstance(res, LLMResponse) - assert res.content == "mistral response" - - -@pytest.mark.asyncio -@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") -async def test_mistral_ai_llm_ainvoke(mock_mistral: Mock) -> None: - mock_mistral_instance = mock_mistral.return_value - - async def mock_complete_async(*args: Any, **kwargs: Any) -> MagicMock: - chat_response_mock = MagicMock() - chat_response_mock.choices = [ - MagicMock(message=MagicMock(content="async mistral response")) - ] - return chat_response_mock - - mock_mistral_instance.chat.complete_async = mock_complete_async - - llm = MistralAILLM(model_name="mistral-model") - - res = await llm.ainvoke("some input") - - assert isinstance(res, LLMResponse) - assert res.content == "async mistral response" - - -@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") -def test_mistral_ai_llm_invoke_sdkerror(mock_mistral: Mock) -> None: - mock_mistral_instance = mock_mistral.return_value - mock_mistral_instance.chat.complete.side_effect = SDKError("Some error") - - llm = MistralAILLM(model_name="mistral-model") - - with pytest.raises(LLMGenerationError): - llm.invoke("some input") - - -@pytest.mark.asyncio -@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") -async def test_mistral_ai_llm_ainvoke_sdkerror(mock_mistral: Mock) -> None: - mock_mistral_instance = mock_mistral.return_value - - async def mock_complete_async(*args: Any, **kwargs: Any) -> None: - raise SDKError("Some async error") - - mock_mistral_instance.chat.complete_async = mock_complete_async - - llm = MistralAILLM(model_name="mistral-model") - - with pytest.raises(LLMGenerationError): - await llm.ainvoke("some input") diff --git a/tests/unit/llm/test_ollama_llm.py b/tests/unit/llm/test_ollama_llm.py index f7be308b..deb56f93 100644 --- a/tests/unit/llm/test_ollama_llm.py +++ b/tests/unit/llm/test_ollama_llm.py @@ -12,10 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from unittest.mock import MagicMock, Mock, patch import ollama import pytest +from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm import LLMResponse from neo4j_graphrag.llm.ollama_llm import OllamaLLM @@ -39,8 +41,162 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: mock_ollama.Client.return_value.chat.return_value = MagicMock( message=MagicMock(content="ollama chat response"), ) - llm = OllamaLLM(model_name="gpt") + model = "gpt" + model_params = {"temperature": 0.3} + system_instruction = "You are a helpful assistant." 
+ question = "What is graph RAG?" + llm = OllamaLLM( + model, + model_params=model_params, + system_instruction=system_instruction, + ) + + res = llm.invoke(question) + assert isinstance(res, LLMResponse) + assert res.content == "ollama chat response" + messages = [ + {"role": "system", "content": system_instruction}, + {"role": "user", "content": question}, + ] + llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] + model=model, messages=messages, options=model_params + ) + + +@patch("builtins.__import__") +def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> None: + mock_ollama = get_mock_ollama() + mock_import.return_value = mock_ollama + mock_ollama.Client.return_value.chat.return_value = MagicMock( + message=MagicMock(content="ollama chat response"), + ) + model = "gpt" + model_params = {"temperature": 0.3} + system_instruction = "You are a helpful assistant." + llm = OllamaLLM( + model, + model_params=model_params, + system_instruction=system_instruction, + ) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + response = llm.invoke(question, message_history) # type: ignore + assert response.content == "ollama chat response" + messages = [{"role": "system", "content": system_instruction}] + messages.extend(message_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] + model=model, messages=messages, options=model_params + ) + + +@patch("builtins.__import__") +def test_ollama_invoke_with_message_history_and_system_instruction( + mock_import: Mock, +) -> None: + mock_ollama = get_mock_ollama() + mock_import.return_value = mock_ollama + mock_ollama.Client.return_value.chat.return_value = MagicMock( + message=MagicMock(content="ollama chat response"), + ) + model = "gpt" + model_params = {"temperature": 0.3} + system_instruction = "You are a helpful assistant." + llm = OllamaLLM( + model, + model_params=model_params, + system_instruction=system_instruction, + ) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" 
+
+    # first invocation - initial instructions
+    response = llm.invoke(question, message_history)  # type: ignore
+    assert response.content == "ollama chat response"
+    messages = [{"role": "system", "content": system_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
+        model=model, messages=messages, options=model_params
+    )
+
+    # second invocation - override instructions
+    override_instruction = "Ignore all previous instructions"
+    response = llm.invoke(question, message_history, override_instruction)  # type: ignore
+    assert response.content == "ollama chat response"
+    messages = [{"role": "system", "content": override_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.assert_called_with(  # type: ignore[attr-defined]
+        model=model, messages=messages, options=model_params
+    )
+
+    # third invocation - default instructions
+    response = llm.invoke(question, message_history)  # type: ignore
+    assert response.content == "ollama chat response"
+    messages = [{"role": "system", "content": system_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.assert_called_with(  # type: ignore[attr-defined]
+        model=model, messages=messages, options=model_params
+    )
+
+    assert llm.client.chat.call_count == 3  # type: ignore
+
+
+@patch("builtins.__import__")
+def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock) -> None:
+    mock_ollama = get_mock_ollama()
+    mock_import.return_value = mock_ollama
+    mock_ollama.ResponseError = ollama.ResponseError
+    model = "gpt"
+    model_params = {"temperature": 0.3}
+    system_instruction = "You are a helpful assistant."
+    llm = OllamaLLM(
+        model,
+        model_params=model_params,
+        system_instruction=system_instruction,
+    )
+    message_history = [
+        {"role": "human", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    with pytest.raises(LLMGenerationError) as exc_info:
+        llm.invoke(question, message_history)  # type: ignore
+    assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+@patch("builtins.__import__")
+async def test_ollama_ainvoke_happy_path(mock_import: Mock) -> None:
+    mock_ollama = get_mock_ollama()
+    mock_import.return_value = mock_ollama
+
+    async def mock_chat_async(*args: Any, **kwargs: Any) -> MagicMock:
+        return MagicMock(
+            message=MagicMock(content="ollama chat response"),
+        )
+
+    mock_ollama.AsyncClient.return_value.chat = mock_chat_async
+    model = "gpt"
+    model_params = {"temperature": 0.3}
+    system_instruction = "You are a helpful assistant."
+    question = "What is graph RAG?"
+    llm = OllamaLLM(
+        model,
+        model_params=model_params,
+        system_instruction=system_instruction,
+    )
-    res = llm.invoke("my text")
+    res = await llm.ainvoke(question)
     assert isinstance(res, LLMResponse)
     assert res.content == "ollama chat response"
diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py
index 546d4e39..82a79325 100644
--- a/tests/unit/llm/test_openai_llm.py
+++ b/tests/unit/llm/test_openai_llm.py
@@ -16,6 +16,7 @@
 import openai
 import pytest
+from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm import LLMResponse
 from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM
@@ -46,6 +47,108 @@ def test_openai_llm_happy_path(mock_import: Mock) -> None:
     assert res.content == "openai chat response"
 
 
+@patch("builtins.__import__")
+def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None:
+    mock_openai = get_mock_openai()
+    mock_import.return_value = mock_openai
+    mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
+        choices=[MagicMock(message=MagicMock(content="openai chat response"))],
+    )
+    llm = OpenAILLM(api_key="my key", model_name="gpt")
+    message_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    res = llm.invoke(question, message_history)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "openai chat response"
+    message_history.append({"role": "user", "content": question})
+    llm.client.chat.completions.create.assert_called_once_with(  # type: ignore
+        messages=message_history,
+        model="gpt",
+    )
+
+
+@patch("builtins.__import__")
+def test_openai_llm_with_message_history_and_system_instruction(
+    mock_import: Mock,
+) -> None:
+    mock_openai = get_mock_openai()
+    mock_import.return_value = mock_openai
+    mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
+        choices=[MagicMock(message=MagicMock(content="openai chat response"))],
+    )
+    initial_instruction = "You are a helpful assistant."
+    llm = OpenAILLM(
+        api_key="my key", model_name="gpt", system_instruction=initial_instruction
+    )
+    message_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    # first invocation - initial instructions
+    res = llm.invoke(question, message_history)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "openai chat response"
+    messages = [{"role": "system", "content": initial_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.completions.create.assert_called_once_with(  # type: ignore
+        messages=messages,
+        model="gpt",
+    )
+
+    # second invocation - override instructions
+    override_instruction = "Ignore all previous instructions"
+    res = llm.invoke(question, message_history, override_instruction)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "openai chat response"
+    messages = [{"role": "system", "content": override_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.completions.create.assert_called_with(  # type: ignore
+        messages=messages,
+        model="gpt",
+    )
+
+    # third invocation - default instructions
+    res = llm.invoke(question, message_history)  # type: ignore
+    assert isinstance(res, LLMResponse)
+    assert res.content == "openai chat response"
+    messages = [{"role": "system", "content": initial_instruction}]
+    messages.extend(message_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.completions.create.assert_called_with(  # type: ignore
+        messages=messages,
+        model="gpt",
+    )
+
+    assert llm.client.chat.completions.create.call_count == 3  # type: ignore
+
+
+@patch("builtins.__import__")
+def test_openai_llm_with_message_history_validation_error(mock_import: Mock) -> None:
+    mock_openai = get_mock_openai()
+    mock_import.return_value = mock_openai
+    mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
+        choices=[MagicMock(message=MagicMock(content="openai chat response"))],
+    )
+    llm = OpenAILLM(api_key="my key", model_name="gpt")
+    message_history = [
+        {"role": "human", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    with pytest.raises(LLMGenerationError) as exc_info:
+        llm.invoke(question, message_history)  # type: ignore
+    assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value)
+
+
 @patch("builtins.__import__", side_effect=ImportError)
 def test_azure_openai_llm_missing_dependency(mock_import: Mock) -> None:
     with pytest.raises(ImportError):
@@ -71,3 +174,63 @@ def test_azure_openai_llm_happy_path(mock_import: Mock) -> None:
     res = llm.invoke("my text")
     assert isinstance(res, LLMResponse)
     assert res.content == "openai chat response"
+
+
+@patch("builtins.__import__")
+def test_azure_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None:
+    mock_openai = get_mock_openai()
+    mock_import.return_value = mock_openai
+    mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = (
+        MagicMock(
+            choices=[MagicMock(message=MagicMock(content="openai chat response"))],
+        )
+    )
+    llm = AzureOpenAILLM(
+        model_name="gpt",
+        azure_endpoint="https://test.openai.azure.com/",
+        api_key="my key",
+        api_version="version",
+    )
+
+    message_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+ + res = llm.invoke(question, message_history) # type: ignore + assert isinstance(res, LLMResponse) + assert res.content == "openai chat response" + message_history.append({"role": "user", "content": question}) + llm.client.chat.completions.create.assert_called_once_with( # type: ignore + messages=message_history, + model="gpt", + ) + + +@patch("builtins.__import__") +def test_azure_openai_llm_with_message_history_validation_error( + mock_import: Mock, +) -> None: + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = ( + MagicMock( + choices=[MagicMock(message=MagicMock(content="openai chat response"))], + ) + ) + llm = AzureOpenAILLM( + model_name="gpt", + azure_endpoint="https://test.openai.azure.com/", + api_key="my key", + api_version="version", + ) + + message_history = [ + {"role": "user", "content": 33}, + ] + question = "What about next season?" + + with pytest.raises(LLMGenerationError) as exc_info: + llm.invoke(question, message_history) # type: ignore + assert "Input should be a valid string" in str(exc_info.value) diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index adffeb1d..e0755376 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -13,10 +13,15 @@ # limitations under the License. from __future__ import annotations +from typing import cast +from unittest import mock from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest +from neo4j_graphrag.exceptions import LLMGenerationError +from neo4j_graphrag.llm.types import LLMMessage from neo4j_graphrag.llm.vertexai_llm import VertexAILLM +from vertexai.generative_models import Content, Part @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None) @@ -27,16 +32,103 @@ def test_vertexai_llm_missing_dependency() -> None: @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None: + model_name = "gemini-1.5-flash-001" + input_text = "may thy knife chip and shatter" mock_response = Mock() mock_response.text = "Return text" mock_model = GenerativeModelMock.return_value mock_model.generate_content.return_value = mock_response model_params = {"temperature": 0.5} - llm = VertexAILLM("gemini-1.5-flash-001", model_params) + llm = VertexAILLM(model_name, model_params) + + response = llm.invoke(input_text) + assert response.content == "Return text" + GenerativeModelMock.assert_called_once_with( + model_name=model_name, system_instruction=[None] + ) + user_message = mock.ANY + llm.model.generate_content.assert_called_once_with(user_message, **model_params) + + +@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") +def test_vertexai_invoke_with_message_history_and_system_instruction( + GenerativeModelMock: MagicMock, +) -> None: + system_instruction = "You are a helpful assistant." 
+    model_name = "gemini-1.5-flash-001"
     input_text = "may thy knife chip and shatter"
+    mock_response = Mock()
+    mock_response.text = "Return text"
+    mock_model = GenerativeModelMock.return_value
+    mock_model.generate_content.return_value = mock_response
+    model_params = {"temperature": 0.5}
+    llm = VertexAILLM(model_name, model_params, system_instruction)
+
     response = llm.invoke(input_text)
     assert response.content == "Return text"
-    llm.model.generate_content.assert_called_once_with(input_text, **model_params)
+    GenerativeModelMock.assert_called_once_with(
+        model_name=model_name, system_instruction=[system_instruction]
+    )
+    user_message = mock.ANY
+    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
+
+    message_history = [
+        {"role": "user", "content": "hello!"},
+        {"role": "assistant", "content": "hi."},
+    ]
+    response = llm.invoke(input_text, message_history, "new instructions")  # type:ignore
+    GenerativeModelMock.assert_called_with(
+        model_name=model_name, system_instruction=["new instructions"]
+    )
+    messages = [mock.ANY, mock.ANY, mock.ANY]
+    llm.model.generate_content.assert_called_with(messages, **model_params)
+
+
+@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
+def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None:
+    system_instruction = "You are a helpful assistant."
+    model_name = "gemini-1.5-flash-001"
+    question = "When does it set?"
+    message_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+        {"role": "user", "content": "What about next season?"},
+        {"role": "assistant", "content": "Around 8am."},
+    ]
+    expected_response = [
+        Content(
+            role="user",
+            parts=[Part.from_text("When does the sun come up in the summer?")],
+        ),
+        Content(role="model", parts=[Part.from_text("Usually around 6am.")]),
+        Content(role="user", parts=[Part.from_text("What about next season?")]),
+        Content(role="model", parts=[Part.from_text("Around 8am.")]),
+        Content(role="user", parts=[Part.from_text("When does it set?")]),
+    ]
+
+    llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction)
+    response = llm.get_messages(question, cast(list[LLMMessage], message_history))
+
+    GenerativeModelMock.assert_not_called()
+    assert len(response) == len(expected_response)
+    for actual, expected in zip(response, expected_response):
+        assert actual.role == expected.role
+        assert actual.parts[0].text == expected.parts[0].text
+
+
+@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
+def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock) -> None:
+    system_instruction = "You are a helpful assistant."
+    model_name = "gemini-1.5-flash-001"
+    question = "hi!"
+ message_history = [ + {"role": "model", "content": "hello!"}, + ] + + llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction) + with pytest.raises(LLMGenerationError) as exc_info: + llm.invoke(question, cast(list[LLMMessage], message_history)) + assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) @pytest.mark.asyncio @@ -51,4 +143,4 @@ async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> No input_text = "may thy knife chip and shatter" response = await llm.ainvoke(input_text) assert response.content == "Return text" - llm.model.generate_content_async.assert_called_once_with(input_text, **model_params) + llm.model.generate_content_async.assert_called_once_with([mock.ANY], **model_params) diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py index 58508d71..178d34f7 100644 --- a/tests/unit/test_graphrag.py +++ b/tests/unit/test_graphrag.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock +from unittest.mock import MagicMock, call import pytest from neo4j_graphrag.exceptions import RagInitializationError, SearchValidationError @@ -62,7 +62,8 @@ def test_graphrag_happy_path(retriever_mock: MagicMock, llm: MagicMock) -> None: res = rag.search("question") retriever_mock.search.assert_called_once_with(query_text="question") - llm.invoke.assert_called_once_with("""Answer the user question using the following context + llm.invoke.assert_called_once_with( + """Answer the user question using the following context Context: item content 1 @@ -75,7 +76,81 @@ def test_graphrag_happy_path(retriever_mock: MagicMock, llm: MagicMock) -> None: question Answer: -""") +""", + None, + ) + + assert isinstance(res, RagResultModel) + assert res.answer == "llm generated text" + assert res.retriever_result is None + + +def test_graphrag_happy_path_with_message_history( + retriever_mock: MagicMock, llm: MagicMock +) -> None: + rag = GraphRAG( + retriever=retriever_mock, + llm=llm, + ) + retriever_mock.search.return_value = RetrieverResult( + items=[ + RetrieverResultItem(content="item content 1"), + RetrieverResultItem(content="item content 2"), + ] + ) + llm.invoke.side_effect = [ + LLMResponse(content="llm generated summary"), + LLMResponse(content="llm generated text"), + ] + message_history = [ + {"role": "user", "content": "initial question"}, + {"role": "assistant", "content": "answer to initial question"}, + ] + res = rag.search("question", message_history) # type: ignore + + expected_retriever_query_text = """ +Message Summary: +llm generated summary + +Current Query: +question +""" + + first_invokation_input = """ +Summarize the message history: + +user: initial question +assistant: answer to initial question +""" + first_invokation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words." 
+ second_invokation = """Answer the user question using the following context + +Context: +item content 1 +item content 2 + +Examples: + + +Question: +question + +Answer: +""" + + retriever_mock.search.assert_called_once_with( + query_text=expected_retriever_query_text + ) + assert llm.invoke.call_count == 2 + llm.invoke.assert_has_calls( + [ + call( + input=first_invokation_input, + system_instruction=first_invokation_system_instruction, + ), + call(second_invokation, message_history), + ] + ) assert isinstance(res, RagResultModel) assert res.answer == "llm generated text" @@ -99,3 +174,48 @@ def test_graphrag_search_error(retriever_mock: MagicMock, llm: MagicMock) -> Non with pytest.raises(SearchValidationError) as excinfo: rag.search(10) # type: ignore assert "Input should be a valid string" in str(excinfo) + + +def test_chat_summary_template(retriever_mock: MagicMock, llm: MagicMock) -> None: + message_history = [ + {"role": "user", "content": "initial question"}, + {"role": "assistant", "content": "answer to initial question"}, + {"role": "user", "content": "second question"}, + {"role": "assistant", "content": "answer to second question"}, + ] + rag = GraphRAG( + retriever=retriever_mock, + llm=llm, + ) + prompt = rag.chat_summary_prompt(message_history=message_history) # type: ignore + assert ( + prompt + == """ +Summarize the message history: + +user: initial question +assistant: answer to initial question +user: second question +assistant: answer to second question +""" + ) + + +def test_conversation_template(retriever_mock: MagicMock, llm: MagicMock) -> None: + rag = GraphRAG( + retriever=retriever_mock, + llm=llm, + ) + prompt = rag.conversation_prompt( + summary="llm generated chat summary", current_query="latest question" + ) + assert ( + prompt + == """ +Message Summary: +llm generated chat summary + +Current Query: +latest question +""" + )
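
Taken together, the tests above pin down the intended conversational flow: earlier turns travel as a `message_history` list of role/content dicts, a per-call `system_instruction` can override the one set at construction time, and `GraphRAG.search` first summarizes the history before running retrieval. As a rough usage sketch based only on the calls exercised in these tests (the connection details, model name, index name, and embedder choice below are illustrative assumptions, not part of this change):

```python
import neo4j

from neo4j_graphrag.embeddings import OpenAIEmbeddings
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.retrievers import VectorRetriever

# Hypothetical connection, index, and model details for illustration only;
# an OPENAI_API_KEY in the environment is assumed for the embedder and LLM.
driver = neo4j.GraphDatabase.driver(
    "neo4j://localhost:7687", auth=("neo4j", "password")
)
retriever = VectorRetriever(driver, index_name="my-index", embedder=OpenAIEmbeddings())
llm = OpenAILLM(model_name="gpt-4o", system_instruction="You are a helpful assistant.")

rag = GraphRAG(retriever=retriever, llm=llm)

# Earlier turns are plain role/content dicts, as in the tests above; GraphRAG
# summarizes them to build the retrieval query and forwards them to the LLM.
message_history = [
    {"role": "user", "content": "When does the sun come up in the summer?"},
    {"role": "assistant", "content": "Usually around 6am."},
]
result = rag.search("What about next season?", message_history)
print(result.answer)
```

Passing `message_history` is optional; omitting it keeps the single-turn behaviour asserted in the existing happy-path tests.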