Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactors Redis Persister #471

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 65 additions & 17 deletions burr/integrations/persisters/b_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,50 @@
logger = logging.getLogger(__name__)


class RedisPersister(persistence.BaseStatePersister):
"""A class used to represent a Redis Persister.
class RedisBasePersister(persistence.BaseStatePersister):
"""Main class for Redis persister.

Use this class if you want to directly control injecting the Redis client.

This class is responsible for persisting state data to a Redis database.
It inherits from the BaseStatePersister class.

Note: We didn't create the right constructor for the initial implementation of the RedisPersister class,
so this is an attempt to fix that in a backwards compatible way.
"""

def __init__(
self,
@classmethod
def from_values(
cls,
host: str,
port: int,
db: int,
password: str = None,
serde_kwargs: dict = None,
redis_client_kwargs: dict = None,
namespace: str = None,
) -> "RedisBasePersister":
"""Creates a new instance of the RedisBasePersister from passed in values."""
if redis_client_kwargs is None:
redis_client_kwargs = {}
connection = redis.Redis(
host=host, port=port, db=db, password=password, **redis_client_kwargs
)
return cls(connection, serde_kwargs, namespace)

def __init__(
self,
connection,
serde_kwargs: dict = None,
namespace: str = None,
):
"""Initializes the RedisPersister class.

:param host:
:param port:
:param db:
:param password:
:param serde_kwargs:
:param redis_client_kwargs: Additional keyword arguments to pass to the redis.Redis client.
:param connection: the redis connection object.
:param serde_kwargs: serialization and deserialization keyword arguments to pass to state SERDE.
:param namespace: The name of the project to optionally use in the key prefix.
"""
if redis_client_kwargs is None:
redis_client_kwargs = {}
self.connection = redis.Redis(
host=host, port=port, db=db, password=password, **redis_client_kwargs
)
self.connection = connection
self.serde_kwargs = serde_kwargs or {}
self.namespace = namespace if namespace else ""

Expand Down Expand Up @@ -149,9 +161,45 @@ def __del__(self):
self.connection.close()


class RedisPersister(RedisBasePersister):
"""A class used to represent a Redis Persister.

This class is deprecated. Use RedisBasePersister.from_values() instead.
"""

def __init__(
self,
host: str,
port: int,
db: int,
password: str = None,
serde_kwargs: dict = None,
redis_client_kwargs: dict = None,
namespace: str = None,
):
"""Initializes the RedisPersister class.

This is deprecated. Use RedisBasePersister.from_values() instead.

:param host:
:param port:
:param db:
:param password:
:param serde_kwargs:
:param redis_client_kwargs: Additional keyword arguments to pass to the redis.Redis client.
:param namespace: The name of the project to optionally use in the key prefix.
"""
if redis_client_kwargs is None:
redis_client_kwargs = {}
connection = redis.Redis(
host=host, port=port, db=db, password=password, **redis_client_kwargs
)
super(RedisPersister, self).__init__(connection, serde_kwargs, namespace)


if __name__ == "__main__":
# test the RedisPersister class
persister = RedisPersister("localhost", 6379, 0)
# test the RedisBasePersister class
persister = RedisBasePersister.from_values("localhost", 6379, 0)

persister.initialize()
persister.save("pk", "app_id", 2, "pos", state.State({"a": 1, "b": 2}), "completed")
Expand Down
3 changes: 1 addition & 2 deletions docs/reference/persister.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ Currently we support the following, although we highly recommend you contribute

.. automethod:: __init__


.. autoclass:: burr.integrations.persisters.b_redis.RedisPersister
.. autoclass:: burr.integrations.persisters.b_redis.RedisBasePersister
:members:

.. automethod:: __init__
Expand Down
15 changes: 12 additions & 3 deletions tests/integrations/persisters/test_b_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@
import pytest

from burr.core import state
from burr.integrations.persisters.b_redis import RedisPersister
from burr.integrations.persisters.b_redis import RedisBasePersister, RedisPersister

if not os.environ.get("BURR_CI_INTEGRATION_TESTS") == "true":
pytest.skip("Skipping integration tests", allow_module_level=True)


@pytest.fixture
def redis_persister():
persister = RedisPersister(host="localhost", port=6379, db=0)
persister = RedisBasePersister.from_values(host="localhost", port=6379, db=0)
yield persister
persister.connection.close()


@pytest.fixture
def redis_persister_with_ns():
persister = RedisPersister(host="localhost", port=6379, db=0, namespace="test")
persister = RedisBasePersister.from_values(host="localhost", port=6379, db=0, namespace="test")
yield persister
persister.connection.close()

Expand Down Expand Up @@ -61,3 +61,12 @@ def test_list_app_ids_with_ns(redis_persister_with_ns):
def test_load_nonexistent_key_with_ns(redis_persister_with_ns):
state_data = redis_persister_with_ns.load("pk", "nonexistent_key")
assert state_data is None


def test_redis_persister_class_backwards_compatible():
"""Tests that the RedisPersister class is still backwards compatible."""
persister = RedisPersister(host="localhost", port=6379, db=0, namespace="backwardscompatible")
persister.save("pk", "app_id", 2, "pos", state.State({"a": 4, "b": 5}), "completed")
data = persister.load("pk", "app_id", 2)
assert data["state"].get_all() == {"a": 4, "b": 5}
persister.connection.close()
Loading