Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
leila-messallem committed Dec 16, 2024
1 parent 6aea7fa commit 4c62827
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 8 deletions.
4 changes: 3 additions & 1 deletion src/neo4j_graphrag/generation/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ def build_query(
chat_history=chat_history
)
summary = self.llm.invoke(
input=summarization_prompt, system_instruction=summarization_prompt.SYSTEM_MESSAGE).content
input=summarization_prompt,
system_instruction=summarization_prompt.SYSTEM_MESSAGE,
).content
return ConversationTemplate().format(
summary=summary, current_query=query_text
)
Expand Down
5 changes: 4 additions & 1 deletion src/neo4j_graphrag/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ def __init__(

@abstractmethod
def invoke(
self, input: str, chat_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None
self,
input: str,
chat_history: Optional[list[dict[str, str]]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Sends a text input to the LLM and retrieves a response.
Expand Down
16 changes: 13 additions & 3 deletions src/neo4j_graphrag/llm/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,17 @@ def __init__(
super().__init__(model_name, model_params, system_instruction)

def get_messages(
self, input: str, chat_history: Optional[list[Any]] = None, system_instruction: Optional[str] = None
self,
input: str,
chat_history: Optional[list[Any]] = None,
system_instruction: Optional[str] = None,
) -> Iterable[ChatCompletionMessageParam]:
messages = []
system_message = system_instruction if system_instruction is not None else self.system_instruction
system_message = (
system_instruction
if system_instruction is not None
else self.system_instruction
)
if system_message:
messages.append(SystemMessage(content=system_message).model_dump())
if chat_history:
Expand All @@ -77,7 +84,10 @@ def get_messages(
return messages

def invoke(
self, input: str, chat_history: Optional[list[Any]] = None, system_instruction: Optional[str] = None
self,
input: str,
chat_history: Optional[list[Any]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Sends a text input to the OpenAI chat completion model
and returns the response's content.
Expand Down
15 changes: 12 additions & 3 deletions src/neo4j_graphrag/llm/vertexai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ def get_messages(
return messages

def invoke(
self, input: str, chat_history: Optional[list[Any]] = None, system_instruction: Optional[str] = None
self,
input: str,
chat_history: Optional[list[Any]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Sends text to the LLM and returns a response.
Expand All @@ -116,9 +119,15 @@ def invoke(
Returns:
LLMResponse: The response from the LLM.
"""
system_message = system_instruction if system_instruction is not None else self.system_instruction
system_message = (
system_instruction
if system_instruction is not None
else self.system_instruction
)
self.model = GenerativeModel(
model_name=self.model_name, system_instruction=[system_message], **self.model_params
model_name=self.model_name,
system_instruction=[system_message],
**self.model_params,
)
try:
messages = self.get_messages(input, chat_history)
Expand Down

0 comments on commit 4c62827

Please sign in to comment.