Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RedisConversationMemoryDriver #787

Merged
merged 12 commits into from
May 16, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added
- `RedisConversationMemoryDriver` to save conversation memory in redis.

### Changed
- Default behavior of OpenAiStructureConfig to utilize `gpt-4o` for prompt_driver.

Expand Down
30 changes: 30 additions & 0 deletions docs/griptape-framework/drivers/conversation-memory-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,33 @@ agent = Agent(conversation_memory=ConversationMemory(driver=dynamodb_driver))
agent.run("My name is Jeff.")
agent.run("What is my name?")
```


### Redis Conversation Memory Driver

!!! info
This driver requires the `drivers-memory-conversation-redis` [extra](../index.md#extras).

The [RedisConversationMemoryDriver](../../reference/griptape/drivers/memory/conversation/redis_conversation_memory_driver.md) allows you to persist Conversation Memory in [Redis](https://redis.io/).

```python
import os
import uuid
from griptape.drivers import RedisConversationMemoryDriver
from griptape.memory.structure import ConversationMemory
from griptape.structures import Agent

conversation_id = uuid.uuid4().hex
redis_conversation_driver = RedisConversationMemoryDriver(
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
index=os.environ["REDIS_INDEX"],
conversation_id = conversation_id
)

agent = Agent(conversation_memory=ConversationMemory(driver=redis_conversation_driver))

agent.run("My name is Jeff.")
agent.run("What is my name?")
```
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .memory.conversation.base_conversation_memory_driver import BaseConversationMemoryDriver
from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver
from .memory.conversation.amazon_dynamodb_conversation_memory_driver import AmazonDynamoDbConversationMemoryDriver
from .memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver

from .embedding.base_embedding_driver import BaseEmbeddingDriver
from .embedding.openai_embedding_driver import OpenAiEmbeddingDriver
Expand Down Expand Up @@ -123,6 +124,7 @@
"BaseConversationMemoryDriver",
"LocalConversationMemoryDriver",
"AmazonDynamoDbConversationMemoryDriver",
"RedisConversationMemoryDriver",
"BaseEmbeddingDriver",
"OpenAiEmbeddingDriver",
"AzureOpenAiEmbeddingDriver",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations
import uuid
from attr import define, field, Factory
from typing import Optional, TYPE_CHECKING
from griptape.drivers import BaseConversationMemoryDriver
from griptape.memory.structure import BaseConversationMemory
from griptape.utils.import_utils import import_optional_dependency

if TYPE_CHECKING:
from redis import Redis


@define
class RedisConversationMemoryDriver(BaseConversationMemoryDriver):
"""A Conversation Memory Driver for Redis.

This driver interfaces with a Redis instance and utilizes the Redis hashes and RediSearch module to store,
retrieve, and query conversations in a structured manner.
Proper setup of the Redis instance and RediSearch is necessary for the driver to function correctly.

Attributes:
host: The host of the Redis instance.
port: The port of the Redis instance.
db: The database of the Redis instance.
password: The password of the Redis instance.
index: The name of the index to use.
conversation_id: The id of the conversation.
"""

host: str = field(kw_only=True, metadata={"serializable": True})
port: int = field(kw_only=True, metadata={"serializable": True})
db: int = field(kw_only=True, default=0, metadata={"serializable": True})
password: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
index: str = field(kw_only=True, metadata={"serializable": True})
conversation_id: str = field(kw_only=True, default=uuid.uuid4().hex)

client: Redis = field(
default=Factory(
lambda self: import_optional_dependency("redis").Redis(
host=self.host, port=self.port, db=self.db, password=self.password, decode_responses=False
),
takes_self=True,
)
)

def store(self, memory: BaseConversationMemory) -> None:
self.client.hset(self.index, self.conversation_id, memory.to_json())

def load(self) -> Optional[BaseConversationMemory]:
key = self.index
memory_json = self.client.hget(key, self.conversation_id)
if memory_json:
memory = BaseConversationMemory.from_json(memory_json)
memory.driver = self
return memory
return None
27 changes: 25 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ drivers-sql-snowflake = ["snowflake-sqlalchemy", "snowflake", "snowflake-connect
drivers-sql-postgres = ["pgvector", "psycopg2-binary"]

drivers-memory-conversation-amazon-dynamodb = ["boto3"]
drivers-memory-conversation-redis = ["redis"]

drivers-vector-marqo = ["marqo"]
drivers-vector-pinecone = ["pinecone-client"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
import redis
from griptape.memory.structure.base_conversation_memory import BaseConversationMemory
from griptape.drivers.memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver

TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": "Hi There, Hello", "output": "Hello! How can I assist you today?"}], "max_runs": 2}'
CONVERSATION_ID = "117151897f344ff684b553d0655d8f39"
INDEX = "griptape_converstaion"
HOST = "127.0.0.1"
PORT = 6379
PASSWORD = ""


class TestRedisConversationMemoryDriver:
@pytest.fixture(autouse=True)
def mock_redis(self, mocker):
mocker.patch.object(redis.StrictRedis, "hset", return_value=None)
mocker.patch.object(redis.StrictRedis, "keys", return_value=[b"test"])
mocker.patch.object(redis.StrictRedis, "hget", return_value=TEST_CONVERSATION)

fake_redisearch = mocker.MagicMock()
fake_redisearch.search = mocker.MagicMock(return_value=mocker.MagicMock(docs=[]))
fake_redisearch.info = mocker.MagicMock(side_effect=Exception("Index not found"))
fake_redisearch.create_index = mocker.MagicMock(return_value=None)

mocker.patch.object(redis.StrictRedis, "ft", return_value=fake_redisearch)

@pytest.fixture
def driver(self):
return RedisConversationMemoryDriver(host=HOST, port=PORT, db=0, index=INDEX, conversation_id=CONVERSATION_ID)

def test_store(self, driver):
memory = BaseConversationMemory.from_json(TEST_CONVERSATION)
assert driver.store(memory) == None

def test_load(self, driver):
memory = driver.load()
assert memory.type == "ConversationMemory"
assert memory.max_runs == 2
assert memory.runs == BaseConversationMemory.from_json(TEST_CONVERSATION).runs
Loading