From 14d604272dc18992bc912c1f82901506c1a451f3 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Wed, 20 Mar 2024 23:56:15 -0700 Subject: [PATCH 1/2] Adds Redis persister This is a simple implementation. It assumes that data is JSON serializable. --- burr/integrations/persisters/b_redis.py | 129 ++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 burr/integrations/persisters/b_redis.py diff --git a/burr/integrations/persisters/b_redis.py b/burr/integrations/persisters/b_redis.py new file mode 100644 index 00000000..d643ff2b --- /dev/null +++ b/burr/integrations/persisters/b_redis.py @@ -0,0 +1,129 @@ +from burr.integrations import base + +try: + import redis # can't name module redis because this import wouldn't work. + +except ImportError as e: + base.require_plugin(e, ["redis"], "redis") + +import datetime +import json +import logging +from typing import Literal, Optional + +from burr.core import persistence, state + +logger = logging.getLogger(__name__) + + +class RedisPersister(persistence.BaseStatePersister): + """A class used to represent a Redis Persister. + + This class is responsible for persisting state data to a Redis database. + It inherits from the BaseStatePersister class. + """ + + def __init__(self, host: str, port: int, db: int, password: str = None): + """Initializes the RedisPersister class. + + :param host: + :param port: + :param db: + :param password: + """ + self.connection = redis.Redis(host=host, port=port, db=db, password=password) + + def list_app_ids(self, partition_key: str, **kwargs) -> list[str]: + """List the app ids for a given partition key.""" + app_ids = self.connection.zrevrange(partition_key, 0, -1) + return [app_id.decode() for app_id in app_ids] + + def load( + self, partition_key: str, app_id: str, sequence_id: int = None, **kwargs + ) -> Optional[persistence.PersistedStateData]: + """Load the state data for a given partition key, app id, and sequence id. + + If the sequence id is not given, it will be looked up in the Redis database. If it is not found, None will be returned. + + :param partition_key: + :param app_id: + :param sequence_id: + :param kwargs: + :return: Value or None. + """ + if sequence_id is None: + sequence_id = self.connection.zscore(partition_key, app_id) + if sequence_id is None: + return None + sequence_id = int(sequence_id) + key = self.create_key(app_id, partition_key, sequence_id) + data = self.connection.hgetall(key) + if not data: + return None + _state = state.State(json.loads(data[b"state"].decode())["_state"]) + return { + "partition_key": partition_key, + "app_id": app_id, + "sequence_id": sequence_id, + "position": data[b"position"].decode(), + "state": _state, + "created_at": data[b"created_at"].decode(), + "status": data[b"status"].decode(), + } + + def create_key(self, app_id, partition_key, sequence_id): + """Create a key for the Redis database.""" + key = f"{partition_key}:{app_id}:{sequence_id}" + return key + + def save( + self, + partition_key: str, + app_id: str, + sequence_id: int, + position: str, + state: state.State, + status: Literal["completed", "failed"], + **kwargs, + ): + """Save the state data to the Redis database. + + :param partition_key: + :param app_id: + :param sequence_id: + :param position: + :param state: + :param status: + :param kwargs: + :return: + """ + key = self.create_key(app_id, partition_key, sequence_id) + if self.connection.exists(key): + raise ValueError(f"partition_key:app_id:sequence_id[{key}] already exists.") + json_state = json.dumps(state.__dict__) + self.connection.hset( + key, + mapping={ + "partition_key": partition_key, + "app_id": app_id, + "sequence_id": sequence_id, + "position": position, + "state": json_state, + "status": status, + "created_at": datetime.datetime.utcnow().isoformat(), + }, + ) + self.connection.zadd(partition_key, {app_id: sequence_id}) + + def __del__(self): + self.connection.close() + + +if __name__ == "__main__": + # test the RedisPersister class + persister = RedisPersister("localhost", 6379, 0) + + persister.initialize() + persister.save("pk", "app_id", 1, "pos", state.State({"a": 1, "b": 2}), "completed") + print(persister.list_app_ids("pk")) + print(persister.load("pk", "app_id")) From 02592e36758127a51586121ce28937ba78022f8d Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Thu, 21 Mar 2024 10:53:12 -0700 Subject: [PATCH 2/2] Changes state serialization to use get_all() This is technically backwards incompatible, but we're so early, we can migrate people manually. Otherwise adds Redis to docs. --- burr/core/persistence.py | 4 ++-- burr/integrations/persisters/b_redis.py | 6 +++--- burr/integrations/persisters/postgresql.py | 4 ++-- docs/reference/persister.rst | 6 ++++++ pyproject.toml | 4 ++++ 5 files changed, 17 insertions(+), 7 deletions(-) diff --git a/burr/core/persistence.py b/burr/core/persistence.py index ef3a2ae6..a0d21083 100644 --- a/burr/core/persistence.py +++ b/burr/core/persistence.py @@ -229,7 +229,7 @@ def load( row = cursor.fetchone() if row is None: return None - _state = State(json.loads(row[1])["_state"]) + _state = State(json.loads(row[1])) return { "partition_key": partition_key, "app_id": row[3], @@ -277,7 +277,7 @@ def save( status, ) cursor = self.connection.cursor() - json_state = json.dumps(state.__dict__) + json_state = json.dumps(state.get_all()) cursor.execute( f"INSERT INTO {self.table_name} (partition_key, app_id, sequence_id, position, state, status) " f"VALUES (?, ?, ?, ?, ?, ?)", diff --git a/burr/integrations/persisters/b_redis.py b/burr/integrations/persisters/b_redis.py index d643ff2b..0b03fbc0 100644 --- a/burr/integrations/persisters/b_redis.py +++ b/burr/integrations/persisters/b_redis.py @@ -60,7 +60,7 @@ def load( data = self.connection.hgetall(key) if not data: return None - _state = state.State(json.loads(data[b"state"].decode())["_state"]) + _state = state.State(json.loads(data[b"state"].decode())) return { "partition_key": partition_key, "app_id": app_id, @@ -100,7 +100,7 @@ def save( key = self.create_key(app_id, partition_key, sequence_id) if self.connection.exists(key): raise ValueError(f"partition_key:app_id:sequence_id[{key}] already exists.") - json_state = json.dumps(state.__dict__) + json_state = json.dumps(state.get_all()) self.connection.hset( key, mapping={ @@ -124,6 +124,6 @@ def __del__(self): persister = RedisPersister("localhost", 6379, 0) persister.initialize() - persister.save("pk", "app_id", 1, "pos", state.State({"a": 1, "b": 2}), "completed") + persister.save("pk", "app_id", 2, "pos", state.State({"a": 1, "b": 2}), "completed") print(persister.list_app_ids("pk")) print(persister.load("pk", "app_id")) diff --git a/burr/integrations/persisters/postgresql.py b/burr/integrations/persisters/postgresql.py index 40eac979..80ce0107 100644 --- a/burr/integrations/persisters/postgresql.py +++ b/burr/integrations/persisters/postgresql.py @@ -160,7 +160,7 @@ def load( row = cursor.fetchone() if row is None: return None - _state = state.State(row[1]["_state"]) + _state = state.State(row[1]) return { "partition_key": partition_key, "app_id": row[3], @@ -208,7 +208,7 @@ def save( status, ) cursor = self.connection.cursor() - json_state = json.dumps(state.__dict__) + json_state = json.dumps(state.get_all()) cursor.execute( f"INSERT INTO {self.table_name} (partition_key, app_id, sequence_id, position, state, status) " "VALUES (%s, %s, %s, %s, %s, %s)", diff --git a/docs/reference/persister.rst b/docs/reference/persister.rst index 0f0de180..3786847a 100644 --- a/docs/reference/persister.rst +++ b/docs/reference/persister.rst @@ -47,5 +47,11 @@ Currently we support the following, although we highly recommend you contribute .. automethod:: __init__ +.. autoclass:: burr.integrations.persisters.b_redis.RedisPersister + :members: + + .. automethod:: __init__ + + Note that the :py:class:`LocalTrackingClient ` leverages the :py:class:`BaseStateLoader ` to allow loading state, although it uses different mechanisms to save state (as it tracks more than just state). diff --git a/pyproject.toml b/pyproject.toml index e73f8bcf..f8ac7698 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,10 @@ postgresql = [ "psycopg2-binary" ] +redis = [ + "redis" +] + tests = [ "pytest", "pytest-asyncio",