Skip to content

Commit

Permalink
Add RedisConversationMemoryDriver (#787)
Browse files Browse the repository at this point in the history
Co-authored-by: torabshaikh <[email protected]>
  • Loading branch information
vachillo and torabshaikh authored May 16, 2024
1 parent 44a2c62 commit f1a2dba
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added
- `RedisConversationMemoryDriver` to save conversation memory in redis.

### Changed
- Default behavior of OpenAiStructureConfig to utilize `gpt-4o` for prompt_driver.

Expand Down
30 changes: 30 additions & 0 deletions docs/griptape-framework/drivers/conversation-memory-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,33 @@ agent = Agent(conversation_memory=ConversationMemory(driver=dynamodb_driver))
agent.run("My name is Jeff.")
agent.run("What is my name?")
```


### Redis Conversation Memory Driver

!!! info
This driver requires the `drivers-memory-conversation-redis` [extra](../index.md#extras).

The [RedisConversationMemoryDriver](../../reference/griptape/drivers/memory/conversation/redis_conversation_memory_driver.md) allows you to persist Conversation Memory in [Redis](https://redis.io/).

```python
import os
import uuid
from griptape.drivers import RedisConversationMemoryDriver
from griptape.memory.structure import ConversationMemory
from griptape.structures import Agent

conversation_id = uuid.uuid4().hex
redis_conversation_driver = RedisConversationMemoryDriver(
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
index=os.environ["REDIS_INDEX"],
conversation_id = conversation_id
)

agent = Agent(conversation_memory=ConversationMemory(driver=redis_conversation_driver))

agent.run("My name is Jeff.")
agent.run("What is my name?")
```
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .memory.conversation.base_conversation_memory_driver import BaseConversationMemoryDriver
from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver
from .memory.conversation.amazon_dynamodb_conversation_memory_driver import AmazonDynamoDbConversationMemoryDriver
from .memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver

from .embedding.base_embedding_driver import BaseEmbeddingDriver
from .embedding.openai_embedding_driver import OpenAiEmbeddingDriver
Expand Down Expand Up @@ -123,6 +124,7 @@
"BaseConversationMemoryDriver",
"LocalConversationMemoryDriver",
"AmazonDynamoDbConversationMemoryDriver",
"RedisConversationMemoryDriver",
"BaseEmbeddingDriver",
"OpenAiEmbeddingDriver",
"AzureOpenAiEmbeddingDriver",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations
import uuid
from attr import define, field, Factory
from typing import Optional, TYPE_CHECKING
from griptape.drivers import BaseConversationMemoryDriver
from griptape.memory.structure import BaseConversationMemory
from griptape.utils.import_utils import import_optional_dependency

if TYPE_CHECKING:
from redis import Redis


@define
class RedisConversationMemoryDriver(BaseConversationMemoryDriver):
"""A Conversation Memory Driver for Redis.
This driver interfaces with a Redis instance and utilizes the Redis hashes and RediSearch module to store,
retrieve, and query conversations in a structured manner.
Proper setup of the Redis instance and RediSearch is necessary for the driver to function correctly.
Attributes:
host: The host of the Redis instance.
port: The port of the Redis instance.
db: The database of the Redis instance.
password: The password of the Redis instance.
index: The name of the index to use.
conversation_id: The id of the conversation.
"""

host: str = field(kw_only=True, metadata={"serializable": True})
port: int = field(kw_only=True, metadata={"serializable": True})
db: int = field(kw_only=True, default=0, metadata={"serializable": True})
password: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
index: str = field(kw_only=True, metadata={"serializable": True})
conversation_id: str = field(kw_only=True, default=uuid.uuid4().hex)

client: Redis = field(
default=Factory(
lambda self: import_optional_dependency("redis").Redis(
host=self.host, port=self.port, db=self.db, password=self.password, decode_responses=False
),
takes_self=True,
)
)

def store(self, memory: BaseConversationMemory) -> None:
self.client.hset(self.index, self.conversation_id, memory.to_json())

def load(self) -> Optional[BaseConversationMemory]:
key = self.index
memory_json = self.client.hget(key, self.conversation_id)
if memory_json:
memory = BaseConversationMemory.from_json(memory_json)
memory.driver = self
return memory
return None
27 changes: 25 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ drivers-sql-snowflake = ["snowflake-sqlalchemy", "snowflake", "snowflake-connect
drivers-sql-postgres = ["pgvector", "psycopg2-binary"]

drivers-memory-conversation-amazon-dynamodb = ["boto3"]
drivers-memory-conversation-redis = ["redis"]

drivers-vector-marqo = ["marqo"]
drivers-vector-pinecone = ["pinecone-client"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
import redis
from griptape.memory.structure.base_conversation_memory import BaseConversationMemory
from griptape.drivers.memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver

TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": "Hi There, Hello", "output": "Hello! How can I assist you today?"}], "max_runs": 2}'
CONVERSATION_ID = "117151897f344ff684b553d0655d8f39"
INDEX = "griptape_converstaion"
HOST = "127.0.0.1"
PORT = 6379
PASSWORD = ""


class TestRedisConversationMemoryDriver:
@pytest.fixture(autouse=True)
def mock_redis(self, mocker):
mocker.patch.object(redis.StrictRedis, "hset", return_value=None)
mocker.patch.object(redis.StrictRedis, "keys", return_value=[b"test"])
mocker.patch.object(redis.StrictRedis, "hget", return_value=TEST_CONVERSATION)

fake_redisearch = mocker.MagicMock()
fake_redisearch.search = mocker.MagicMock(return_value=mocker.MagicMock(docs=[]))
fake_redisearch.info = mocker.MagicMock(side_effect=Exception("Index not found"))
fake_redisearch.create_index = mocker.MagicMock(return_value=None)

mocker.patch.object(redis.StrictRedis, "ft", return_value=fake_redisearch)

@pytest.fixture
def driver(self):
return RedisConversationMemoryDriver(host=HOST, port=PORT, db=0, index=INDEX, conversation_id=CONVERSATION_ID)

def test_store(self, driver):
memory = BaseConversationMemory.from_json(TEST_CONVERSATION)
assert driver.store(memory) == None

def test_load(self, driver):
memory = driver.load()
assert memory.type == "ConversationMemory"
assert memory.max_runs == 2
assert memory.runs == BaseConversationMemory.from_json(TEST_CONVERSATION).runs

0 comments on commit f1a2dba

Please sign in to comment.