Skip to content

Commit

Permalink
Enhance langgraph integration to preserve metadata (explodinggradients#1878)

Browse files Browse the repository at this point in the history

# Enhance langgraph integration to preserve AI metadata

## Description
This PR updates the `langgraph.py` integration to ensure that metadata
is preserved. This enhancement is crucial for multi-agent scenarios
where identifying the source AI is important for evaluation.

## Changes
- Updated `langgraph.py` to ensure AI names and metadata are preserved.

## Motivation and Context
In the current implementation, metadata such as the name assigned to an
AI is not saved. In the era of multi-agent systems, it is essential to
have information about which AI made a particular statement for accurate
evaluation. This update addresses this issue by preserving the necessary
metadata.

### Example Code and Output
```python
import json
from typing import List, Union

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

import ragas.messages as r
from ragas.integrations.langgraph import convert_message_with_metadata

def test_convert_message_with_metadata():
    """Smoke-check that converted messages carry content plus metadata.

    Builds one HumanMessage and one AIMessage (the latter with a tool
    call in ``additional_kwargs``), converts them, and prints each
    converted message's ``content`` and ``metadata``.
    """
    # HumanMessage/AIMessage are already imported at module level above;
    # the redundant function-local re-import was removed.
    human_message = HumanMessage(
        content="Hello", name="me", additional_kwargs={"key1": "value1"}
    )
    ai_message = AIMessage(
        content="Hi",
        name="ai_1",
        additional_kwargs={
            "tool_calls": [
                {"function": {"name": "tool1", "arguments": '{"arg1": "val1"}'}}
            ]
        },
    )

    # NOTE(review): the diff in this commit only defines
    # convert_to_ragas_messages(messages, metadata=False); confirm that
    # convert_message_with_metadata actually exists, or call
    # convert_to_ragas_messages(..., metadata=True) instead.
    converted_messages = convert_message_with_metadata([human_message, ai_message])

    for msg in converted_messages:
        print(f"Content: {msg.content}, Metadata: {msg.metadata}")

if __name__ == "__main__":
    test_convert_message_with_metadata()
```

```
Output
Content: Hello, Metadata: {'additional_kwargs': {'key1': 'value1'}, 'response_metadata': {}, 'type': 'human', 'name': 'me', 'id': None, 'example': False}
Content: Hi, Metadata: {'additional_kwargs': {'tool_calls': [{'function': {'name': 'tool1', 'arguments': '{"arg1": "val1"}'}}]}, 'response_metadata': {}, 'type': 'ai', 'name': 'ai_1', 'id': None, 'example': False, 'tool_calls': [{'name': 'tool1', 'args': {'arg1': 'val1'}, 'id': None, 'type': 'tool_call'}], 'invalid_tool_calls': [], 'usage_metadata': None}
```
  • Loading branch information
i-w-a authored Jan 30, 2025
1 parent 0626b5d commit 34db378
Showing 1 changed file with 43 additions and 18 deletions.
61 changes: 43 additions & 18 deletions src/ragas/integrations/langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,22 @@


def convert_to_ragas_messages(
messages: List[Union[HumanMessage, SystemMessage, AIMessage, ToolMessage]]
messages: List[Union[HumanMessage, SystemMessage, AIMessage, ToolMessage]], metadata: bool = False
) -> List[Union[r.HumanMessage, r.AIMessage, r.ToolMessage]]:
"""
Convert LangChain messages into Ragas messages for agent evaluation.
Convert LangChain messages into Ragas messages with metadata for agent evaluation.
Parameters
----------
messages : List[Union[HumanMessage, SystemMessage, AIMessage, ToolMessage]]
List of LangChain message objects to be converted.
metadata : bool, optional (default=False)
Whether to include metadata in the converted messages.
Returns
-------
List[Union[r.HumanMessage, r.AIMessage, r.ToolMessage]]
List of corresponding Ragas message objects.
List of corresponding Ragas message objects with metadata.
Raises
------
Expand All @@ -42,14 +44,30 @@ def _validate_string_content(message, message_type: str) -> str:
)
return message.content

MESSAGE_TYPE_MAP = {
HumanMessage: lambda m: r.HumanMessage(
content=_validate_string_content(m, "HumanMessage")
),
ToolMessage: lambda m: r.ToolMessage(
content=_validate_string_content(m, "ToolMessage")
),
}
def _extract_metadata(message) -> dict:

return {k: v for k, v in message.__dict__.items() if k != "content"}

if metadata:
MESSAGE_TYPE_MAP = {
HumanMessage: lambda m: r.HumanMessage(
content=_validate_string_content(m, "HumanMessage"),
metadata=_extract_metadata(m)
),
ToolMessage: lambda m: r.ToolMessage(
content=_validate_string_content(m, "ToolMessage"),
metadata=_extract_metadata(m)
),
}
else:
MESSAGE_TYPE_MAP = {
HumanMessage: lambda m: r.HumanMessage(
content=_validate_string_content(m, "HumanMessage")
),
ToolMessage: lambda m: r.ToolMessage(
content=_validate_string_content(m, "ToolMessage")
),
}

def _extract_tool_calls(message: AIMessage) -> List[r.ToolCall]:
tool_calls = message.additional_kwargs.get("tool_calls", [])
Expand All @@ -61,18 +79,25 @@ def _extract_tool_calls(message: AIMessage) -> List[r.ToolCall]:
for tool_call in tool_calls
]

def _convert_ai_message(message: AIMessage) -> r.AIMessage:
def _convert_ai_message(message: AIMessage, metadata: bool) -> r.AIMessage:
tool_calls = _extract_tool_calls(message) if message.additional_kwargs else None
return r.AIMessage(
content=_validate_string_content(message, "AIMessage"),
tool_calls=tool_calls,
)
if metadata:
return r.AIMessage(
content=_validate_string_content(message, "AIMessage"),
tool_calls=tool_calls,
metadata=_extract_metadata(message)
)
else:
return r.AIMessage(
content=_validate_string_content(message, "AIMessage"),
tool_calls=tool_calls
)

def _convert_message(message):
def _convert_message(message, metadata: bool = False):
if isinstance(message, SystemMessage):
return None # Skip SystemMessages
if isinstance(message, AIMessage):
return _convert_ai_message(message)
return _convert_ai_message(message, metadata)
converter = MESSAGE_TYPE_MAP.get(type(message))
if converter is None:
raise ValueError(f"Unsupported message type: {type(message).__name__}")
Expand Down

0 comments on commit 34db378

Please sign in to comment.