Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Redis persister #89

Merged
merged 2 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions burr/core/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def load(
row = cursor.fetchone()
if row is None:
return None
_state = State(json.loads(row[1])["_state"])
_state = State(json.loads(row[1]))
return {
"partition_key": partition_key,
"app_id": row[3],
Expand Down Expand Up @@ -277,7 +277,7 @@ def save(
status,
)
cursor = self.connection.cursor()
json_state = json.dumps(state.__dict__)
json_state = json.dumps(state.get_all())
cursor.execute(
f"INSERT INTO {self.table_name} (partition_key, app_id, sequence_id, position, state, status) "
f"VALUES (?, ?, ?, ?, ?, ?)",
Expand Down
129 changes: 129 additions & 0 deletions burr/integrations/persisters/b_redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from burr.integrations import base

try:
import redis # can't name module redis because this import wouldn't work.

except ImportError as e:
base.require_plugin(e, ["redis"], "redis")

import datetime
import json
import logging
from typing import Literal, Optional

from burr.core import persistence, state

logger = logging.getLogger(__name__)


class RedisPersister(persistence.BaseStatePersister):
"""A class used to represent a Redis Persister.

This class is responsible for persisting state data to a Redis database.
It inherits from the BaseStatePersister class.
"""

def __init__(self, host: str, port: int, db: int, password: str = None):
"""Initializes the RedisPersister class.

:param host:
:param port:
:param db:
:param password:
"""
self.connection = redis.Redis(host=host, port=port, db=db, password=password)

def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
"""List the app ids for a given partition key."""
app_ids = self.connection.zrevrange(partition_key, 0, -1)
return [app_id.decode() for app_id in app_ids]

def load(
self, partition_key: str, app_id: str, sequence_id: int = None, **kwargs
) -> Optional[persistence.PersistedStateData]:
"""Load the state data for a given partition key, app id, and sequence id.

If the sequence id is not given, it will be looked up in the Redis database. If it is not found, None will be returned.

:param partition_key:
:param app_id:
:param sequence_id:
:param kwargs:
:return: Value or None.
"""
if sequence_id is None:
sequence_id = self.connection.zscore(partition_key, app_id)
if sequence_id is None:
return None
sequence_id = int(sequence_id)
key = self.create_key(app_id, partition_key, sequence_id)
data = self.connection.hgetall(key)
if not data:
return None
_state = state.State(json.loads(data[b"state"].decode()))
return {
"partition_key": partition_key,
"app_id": app_id,
"sequence_id": sequence_id,
"position": data[b"position"].decode(),
"state": _state,
"created_at": data[b"created_at"].decode(),
"status": data[b"status"].decode(),
}

def create_key(self, app_id, partition_key, sequence_id):
"""Create a key for the Redis database."""
key = f"{partition_key}:{app_id}:{sequence_id}"
return key

def save(
self,
partition_key: str,
app_id: str,
sequence_id: int,
position: str,
state: state.State,
status: Literal["completed", "failed"],
**kwargs,
):
"""Save the state data to the Redis database.

:param partition_key:
:param app_id:
:param sequence_id:
:param position:
:param state:
:param status:
:param kwargs:
:return:
"""
key = self.create_key(app_id, partition_key, sequence_id)
if self.connection.exists(key):
raise ValueError(f"partition_key:app_id:sequence_id[{key}] already exists.")
json_state = json.dumps(state.get_all())
self.connection.hset(
key,
mapping={
"partition_key": partition_key,
"app_id": app_id,
"sequence_id": sequence_id,
"position": position,
"state": json_state,
"status": status,
"created_at": datetime.datetime.utcnow().isoformat(),
},
)
self.connection.zadd(partition_key, {app_id: sequence_id})

def __del__(self):
self.connection.close()


if __name__ == "__main__":
# test the RedisPersister class
persister = RedisPersister("localhost", 6379, 0)

persister.initialize()
persister.save("pk", "app_id", 2, "pos", state.State({"a": 1, "b": 2}), "completed")
print(persister.list_app_ids("pk"))
print(persister.load("pk", "app_id"))
4 changes: 2 additions & 2 deletions burr/integrations/persisters/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def load(
row = cursor.fetchone()
if row is None:
return None
_state = state.State(row[1]["_state"])
_state = state.State(row[1])
return {
"partition_key": partition_key,
"app_id": row[3],
Expand Down Expand Up @@ -208,7 +208,7 @@ def save(
status,
)
cursor = self.connection.cursor()
json_state = json.dumps(state.__dict__)
json_state = json.dumps(state.get_all())
cursor.execute(
f"INSERT INTO {self.table_name} (partition_key, app_id, sequence_id, position, state, status) "
"VALUES (%s, %s, %s, %s, %s, %s)",
Expand Down
6 changes: 6 additions & 0 deletions docs/reference/persister.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,11 @@ Currently we support the following, although we highly recommend you contribute
.. automethod:: __init__


.. autoclass:: burr.integrations.persisters.b_redis.RedisPersister
:members:

.. automethod:: __init__


Note that the :py:class:`LocalTrackingClient <burr.tracking.client.LocalTrackingClient>` leverages the :py:class:`BaseStateLoader <burr.core.persistence.BaseStateLoader>` to allow loading state,
although it uses different mechanisms to save state (as it tracks more than just state).
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ postgresql = [
"psycopg2-binary"
]

redis = [
"redis"
]

tests = [
"pytest",
"pytest-asyncio",
Expand Down
Loading