diff --git a/CHANGELOG.md b/CHANGELOG.md index c24f744c0..9f237ce6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added +- `RedisConversationMemoryDriver` to save conversation memory in redis. + ### Changed - Default behavior of OpenAiStructureConfig to utilize `gpt-4o` for prompt_driver. diff --git a/docs/griptape-framework/drivers/conversation-memory-drivers.md b/docs/griptape-framework/drivers/conversation-memory-drivers.md index d429be527..4c3de1e65 100644 --- a/docs/griptape-framework/drivers/conversation-memory-drivers.md +++ b/docs/griptape-framework/drivers/conversation-memory-drivers.md @@ -45,3 +45,33 @@ agent = Agent(conversation_memory=ConversationMemory(driver=dynamodb_driver)) agent.run("My name is Jeff.") agent.run("What is my name?") ``` + + +### Redis Conversation Memory Driver + +!!! info + This driver requires the `drivers-memory-conversation-redis` [extra](../index.md#extras). + +The [RedisConversationMemoryDriver](../../reference/griptape/drivers/memory/conversation/redis_conversation_memory_driver.md) allows you to persist Conversation Memory in [Redis](https://redis.io/). + +```python +import os +import uuid +from griptape.drivers import RedisConversationMemoryDriver +from griptape.memory.structure import ConversationMemory +from griptape.structures import Agent + +conversation_id = uuid.uuid4().hex +redis_conversation_driver = RedisConversationMemoryDriver( + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + index=os.environ["REDIS_INDEX"], + conversation_id = conversation_id +) + +agent = Agent(conversation_memory=ConversationMemory(driver=redis_conversation_driver)) + +agent.run("My name is Jeff.") +agent.run("What is my name?") +``` diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index fdc5cdb59..7941ed61e 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -16,6 +16,7 @@ from .memory.conversation.base_conversation_memory_driver import BaseConversationMemoryDriver from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver from .memory.conversation.amazon_dynamodb_conversation_memory_driver import AmazonDynamoDbConversationMemoryDriver +from .memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver from .embedding.base_embedding_driver import BaseEmbeddingDriver from .embedding.openai_embedding_driver import OpenAiEmbeddingDriver @@ -123,6 +124,7 @@ "BaseConversationMemoryDriver", "LocalConversationMemoryDriver", "AmazonDynamoDbConversationMemoryDriver", + "RedisConversationMemoryDriver", "BaseEmbeddingDriver", "OpenAiEmbeddingDriver", "AzureOpenAiEmbeddingDriver", diff --git a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py new file mode 100644 index 000000000..0531d5f9d --- /dev/null +++ b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py @@ -0,0 +1,56 @@ +from __future__ import annotations +import uuid +from attr import define, field, Factory +from typing import Optional, TYPE_CHECKING +from griptape.drivers import BaseConversationMemoryDriver +from griptape.memory.structure import BaseConversationMemory +from griptape.utils.import_utils import import_optional_dependency + +if TYPE_CHECKING: + from redis import Redis + + +@define +class RedisConversationMemoryDriver(BaseConversationMemoryDriver): + """A Conversation Memory Driver for Redis. + + This driver interfaces with a Redis instance and utilizes the Redis hashes and RediSearch module to store, + retrieve, and query conversations in a structured manner. + Proper setup of the Redis instance and RediSearch is necessary for the driver to function correctly. + + Attributes: + host: The host of the Redis instance. + port: The port of the Redis instance. + db: The database of the Redis instance. + password: The password of the Redis instance. + index: The name of the index to use. + conversation_id: The id of the conversation. + """ + + host: str = field(kw_only=True, metadata={"serializable": True}) + port: int = field(kw_only=True, metadata={"serializable": True}) + db: int = field(kw_only=True, default=0, metadata={"serializable": True}) + password: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) + index: str = field(kw_only=True, metadata={"serializable": True}) + conversation_id: str = field(kw_only=True, default=uuid.uuid4().hex) + + client: Redis = field( + default=Factory( + lambda self: import_optional_dependency("redis").Redis( + host=self.host, port=self.port, db=self.db, password=self.password, decode_responses=False + ), + takes_self=True, + ) + ) + + def store(self, memory: BaseConversationMemory) -> None: + self.client.hset(self.index, self.conversation_id, memory.to_json()) + + def load(self) -> Optional[BaseConversationMemory]: + key = self.index + memory_json = self.client.hget(key, self.conversation_id) + if memory_json: + memory = BaseConversationMemory.from_json(memory_json) + memory.driver = self + return memory + return None diff --git a/poetry.lock b/poetry.lock index c66d806c6..ccb990264 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -3746,6 +3746,7 @@ files = [ {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"}, {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"}, {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"}, + {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"}, @@ -4802,30 +4803,51 @@ description = "Database Abstraction Library" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ + {file = "SQLAlchemy-1.4.51-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:1a09d5bd1a40d76ad90e5570530e082ddc000e1d92de495746f6257dc08f166b"}, {file = "SQLAlchemy-1.4.51-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2be4e6294c53f2ec8ea36486b56390e3bcaa052bf3a9a47005687ccf376745d1"}, {file = "SQLAlchemy-1.4.51-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ca484ca11c65e05639ffe80f20d45e6be81fbec7683d6c9a15cd421e6e8b340"}, {file = "SQLAlchemy-1.4.51-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:0535d5b57d014d06ceeaeffd816bb3a6e2dddeb670222570b8c4953e2d2ea678"}, {file = "SQLAlchemy-1.4.51-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af55cc207865d641a57f7044e98b08b09220da3d1b13a46f26487cc2f898a072"}, + {file = "SQLAlchemy-1.4.51-cp310-cp310-win32.whl", hash = "sha256:7af40425ac535cbda129d9915edcaa002afe35d84609fd3b9d6a8c46732e02ee"}, + {file = "SQLAlchemy-1.4.51-cp310-cp310-win_amd64.whl", hash = "sha256:8d1d7d63e5d2f4e92a39ae1e897a5d551720179bb8d1254883e7113d3826d43c"}, + {file = "SQLAlchemy-1.4.51-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:eaeeb2464019765bc4340214fca1143081d49972864773f3f1e95dba5c7edc7d"}, {file = "SQLAlchemy-1.4.51-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7deeae5071930abb3669b5185abb6c33ddfd2398f87660fafdb9e6a5fb0f3f2f"}, {file = "SQLAlchemy-1.4.51-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0892e7ac8bc76da499ad3ee8de8da4d7905a3110b952e2a35a940dab1ffa550e"}, + {file = "SQLAlchemy-1.4.51-cp311-cp311-win32.whl", hash = "sha256:50e074aea505f4427151c286955ea025f51752fa42f9939749336672e0674c81"}, + {file = "SQLAlchemy-1.4.51-cp311-cp311-win_amd64.whl", hash = "sha256:3b0cd89a7bd03f57ae58263d0f828a072d1b440c8c2949f38f3b446148321171"}, + {file = "SQLAlchemy-1.4.51-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a33cb3f095e7d776ec76e79d92d83117438b6153510770fcd57b9c96f9ef623d"}, {file = "SQLAlchemy-1.4.51-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6cacc0b2dd7d22a918a9642fc89840a5d3cee18a0e1fe41080b1141b23b10916"}, {file = "SQLAlchemy-1.4.51-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:245c67c88e63f1523e9216cad6ba3107dea2d3ee19adc359597a628afcabfbcb"}, + {file = "SQLAlchemy-1.4.51-cp312-cp312-win32.whl", hash = "sha256:8e702e7489f39375601c7ea5a0bef207256828a2bc5986c65cb15cd0cf097a87"}, + {file = "SQLAlchemy-1.4.51-cp312-cp312-win_amd64.whl", hash = "sha256:0525c4905b4b52d8ccc3c203c9d7ab2a80329ffa077d4bacf31aefda7604dc65"}, + {file = "SQLAlchemy-1.4.51-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:1980e6eb6c9be49ea8f89889989127daafc43f0b1b6843d71efab1514973cca0"}, {file = "SQLAlchemy-1.4.51-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ec7a0ed9b32afdf337172678a4a0e6419775ba4e649b66f49415615fa47efbd"}, {file = "SQLAlchemy-1.4.51-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:352df882088a55293f621328ec33b6ffca936ad7f23013b22520542e1ab6ad1b"}, {file = "SQLAlchemy-1.4.51-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:86a22143a4001f53bf58027b044da1fb10d67b62a785fc1390b5c7f089d9838c"}, {file = "SQLAlchemy-1.4.51-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c37bc677690fd33932182b85d37433845de612962ed080c3e4d92f758d1bd894"}, + {file = "SQLAlchemy-1.4.51-cp36-cp36m-win32.whl", hash = "sha256:d0a83afab5e062abffcdcbcc74f9d3ba37b2385294dd0927ad65fc6ebe04e054"}, + {file = "SQLAlchemy-1.4.51-cp36-cp36m-win_amd64.whl", hash = "sha256:a61184c7289146c8cff06b6b41807c6994c6d437278e72cf00ff7fe1c7a263d1"}, + {file = "SQLAlchemy-1.4.51-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:3f0ef620ecbab46e81035cf3dedfb412a7da35340500ba470f9ce43a1e6c423b"}, {file = "SQLAlchemy-1.4.51-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c55040d8ea65414de7c47f1a23823cd9f3fad0dc93e6b6b728fee81230f817b"}, {file = "SQLAlchemy-1.4.51-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38ef80328e3fee2be0a1abe3fe9445d3a2e52a1282ba342d0dab6edf1fef4707"}, {file = "SQLAlchemy-1.4.51-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f8cafa6f885a0ff5e39efa9325195217bb47d5929ab0051636610d24aef45ade"}, {file = "SQLAlchemy-1.4.51-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8f2df79a46e130235bc5e1bbef4de0583fb19d481eaa0bffa76e8347ea45ec6"}, + {file = "SQLAlchemy-1.4.51-cp37-cp37m-win32.whl", hash = "sha256:f2e5b6f5cf7c18df66d082604a1d9c7a2d18f7d1dbe9514a2afaccbb51cc4fc3"}, + {file = "SQLAlchemy-1.4.51-cp37-cp37m-win_amd64.whl", hash = "sha256:5e180fff133d21a800c4f050733d59340f40d42364fcb9d14f6a67764bdc48d2"}, + {file = "SQLAlchemy-1.4.51-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:7d8139ca0b9f93890ab899da678816518af74312bb8cd71fb721436a93a93298"}, {file = "SQLAlchemy-1.4.51-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb18549b770351b54e1ab5da37d22bc530b8bfe2ee31e22b9ebe650640d2ef12"}, {file = "SQLAlchemy-1.4.51-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55e699466106d09f028ab78d3c2e1f621b5ef2c8694598242259e4515715da7c"}, {file = "SQLAlchemy-1.4.51-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2ad16880ccd971ac8e570550fbdef1385e094b022d6fc85ef3ce7df400dddad3"}, {file = "SQLAlchemy-1.4.51-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b97fd5bb6b7c1a64b7ac0632f7ce389b8ab362e7bd5f60654c2a418496be5d7f"}, + {file = "SQLAlchemy-1.4.51-cp38-cp38-win32.whl", hash = "sha256:cecb66492440ae8592797dd705a0cbaa6abe0555f4fa6c5f40b078bd2740fc6b"}, + {file = "SQLAlchemy-1.4.51-cp38-cp38-win_amd64.whl", hash = "sha256:39b02b645632c5fe46b8dd30755682f629ffbb62ff317ecc14c998c21b2896ff"}, + {file = "SQLAlchemy-1.4.51-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:b03850c290c765b87102959ea53299dc9addf76ca08a06ea98383348ae205c99"}, {file = "SQLAlchemy-1.4.51-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e646b19f47d655261b22df9976e572f588185279970efba3d45c377127d35349"}, {file = "SQLAlchemy-1.4.51-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3cf56cc36d42908495760b223ca9c2c0f9f0002b4eddc994b24db5fcb86a9e4"}, {file = "SQLAlchemy-1.4.51-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:0d661cff58c91726c601cc0ee626bf167b20cc4d7941c93c5f3ac28dc34ddbea"}, {file = "SQLAlchemy-1.4.51-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3823dda635988e6744d4417e13f2e2b5fe76c4bf29dd67e95f98717e1b094cad"}, + {file = "SQLAlchemy-1.4.51-cp39-cp39-win32.whl", hash = "sha256:b00cf0471888823b7a9f722c6c41eb6985cf34f077edcf62695ac4bed6ec01ee"}, + {file = "SQLAlchemy-1.4.51-cp39-cp39-win_amd64.whl", hash = "sha256:a055ba17f4675aadcda3005df2e28a86feb731fdcc865e1f6b4f209ed1225cba"}, {file = "SQLAlchemy-1.4.51.tar.gz", hash = "sha256:e7908c2025eb18394e32d65dd02d2e37e17d733cdbe7d78231c2b6d7eb20cdb9"}, ] @@ -5692,6 +5714,7 @@ drivers-embedding-voyageai = ["voyageai"] drivers-event-listener-amazon-iot = ["boto3"] drivers-event-listener-amazon-sqs = ["boto3"] drivers-memory-conversation-amazon-dynamodb = ["boto3"] +drivers-memory-conversation-redis = ["redis"] drivers-prompt-amazon-bedrock = ["anthropic", "boto3"] drivers-prompt-amazon-sagemaker = ["boto3", "transformers"] drivers-prompt-anthropic = ["anthropic"] @@ -5718,4 +5741,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "895031a580cca4467c6bb500ef92e9c84e911f3add4046ed94385893dcb643b1" +content-hash = "edfa749ceeaae8216026c5c18245dbb284383a3b6772d648319253cc9657a2d1" diff --git a/pyproject.toml b/pyproject.toml index 77ae12331..06ea541f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ drivers-sql-snowflake = ["snowflake-sqlalchemy", "snowflake", "snowflake-connect drivers-sql-postgres = ["pgvector", "psycopg2-binary"] drivers-memory-conversation-amazon-dynamodb = ["boto3"] +drivers-memory-conversation-redis = ["redis"] drivers-vector-marqo = ["marqo"] drivers-vector-pinecone = ["pinecone-client"] diff --git a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py new file mode 100644 index 000000000..dee840508 --- /dev/null +++ b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py @@ -0,0 +1,40 @@ +import pytest +import redis +from griptape.memory.structure.base_conversation_memory import BaseConversationMemory +from griptape.drivers.memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver + +TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": "Hi There, Hello", "output": "Hello! How can I assist you today?"}], "max_runs": 2}' +CONVERSATION_ID = "117151897f344ff684b553d0655d8f39" +INDEX = "griptape_converstaion" +HOST = "127.0.0.1" +PORT = 6379 +PASSWORD = "" + + +class TestRedisConversationMemoryDriver: + @pytest.fixture(autouse=True) + def mock_redis(self, mocker): + mocker.patch.object(redis.StrictRedis, "hset", return_value=None) + mocker.patch.object(redis.StrictRedis, "keys", return_value=[b"test"]) + mocker.patch.object(redis.StrictRedis, "hget", return_value=TEST_CONVERSATION) + + fake_redisearch = mocker.MagicMock() + fake_redisearch.search = mocker.MagicMock(return_value=mocker.MagicMock(docs=[])) + fake_redisearch.info = mocker.MagicMock(side_effect=Exception("Index not found")) + fake_redisearch.create_index = mocker.MagicMock(return_value=None) + + mocker.patch.object(redis.StrictRedis, "ft", return_value=fake_redisearch) + + @pytest.fixture + def driver(self): + return RedisConversationMemoryDriver(host=HOST, port=PORT, db=0, index=INDEX, conversation_id=CONVERSATION_ID) + + def test_store(self, driver): + memory = BaseConversationMemory.from_json(TEST_CONVERSATION) + assert driver.store(memory) == None + + def test_load(self, driver): + memory = driver.load() + assert memory.type == "ConversationMemory" + assert memory.max_runs == 2 + assert memory.runs == BaseConversationMemory.from_json(TEST_CONVERSATION).runs