Skip to content

Commit

Permalink
Refactors mongodb persister a little (#472)
Browse files Browse the repository at this point in the history
This is so that we can more easily test the persister,
but allowing one to inject a custom client. The old
way prevented that. This now then makes the behavior
inline with the other persisters.
  • Loading branch information
skrawcz authored Dec 13, 2024
1 parent c24cad4 commit 1c36e6c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 10 deletions.
68 changes: 61 additions & 7 deletions burr/integrations/persisters/b_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
logger = logging.getLogger(__name__)


class MongoDBPersister(persistence.BaseStatePersister):
class MongoDBBasePersister(persistence.BaseStatePersister):
"""A class used to represent a MongoDB Persister.
Example usage:
.. code-block:: python
persister = MongoDBPersister(uri='mongodb://user:pass@localhost:27017', db_name='mydatabase', collection_name='mystates')
persister = MongoDBBasePersister.from_values(uri='mongodb://user:pass@localhost:27017',
db_name='mydatabase',
collection_name='mystates')
persister.save(
partition_key='example_partition',
app_id='example_app',
Expand All @@ -28,20 +30,46 @@ class MongoDBPersister(persistence.BaseStatePersister):
)
loaded_state = persister.load(partition_key='example_partition', app_id='example_app', sequence_id=1)
print(loaded_state)
Note: this is called MongoDBBasePersister because we had to change the constructor and wanted to make
this change backwards compatible.
"""

def __init__(
self,
@classmethod
def from_values(
cls,
uri="mongodb://localhost:27017",
db_name="mydatabase",
collection_name="mystates",
serde_kwargs: dict = None,
mongo_client_kwargs: dict = None,
):
"""Initializes the MongoDBPersister class."""
) -> "MongoDBBasePersister":
"""Initializes the MongoDBBasePersister class."""
if mongo_client_kwargs is None:
mongo_client_kwargs = {}
self.client = MongoClient(uri, **mongo_client_kwargs)
client = MongoClient(uri, **mongo_client_kwargs)
return cls(
client=client,
db_name=db_name,
collection_name=collection_name,
serde_kwargs=serde_kwargs,
)

def __init__(
self,
client,
db_name="mydatabase",
collection_name="mystates",
serde_kwargs: dict = None,
):
"""Initializes the MongoDBBasePersister class.
:param client: the mongodb client to use
:param db_name: the name of the database to use
:param collection_name: the name of the collection to use
:param serde_kwargs: serializer/deserializer keyword arguments to pass to the state object
"""
self.client = client
self.db = self.client[db_name]
self.collection = self.db[collection_name]
self.serde_kwargs = serde_kwargs or {}
Expand Down Expand Up @@ -101,3 +129,29 @@ def save(

def __del__(self):
self.client.close()


class MongoDBPersister(MongoDBBasePersister):
"""A class used to represent a MongoDB Persister.
This class is deprecated. Please use MongoDBBasePersister instead.
"""

def __init__(
self,
uri="mongodb://localhost:27017",
db_name="mydatabase",
collection_name="mystates",
serde_kwargs: dict = None,
mongo_client_kwargs: dict = None,
):
"""Initializes the MongoDBPersister class."""
if mongo_client_kwargs is None:
mongo_client_kwargs = {}
client = MongoClient(uri, **mongo_client_kwargs)
super(MongoDBPersister, self).__init__(
client=client,
db_name=db_name,
collection_name=collection_name,
serde_kwargs=serde_kwargs,
)
2 changes: 1 addition & 1 deletion docs/reference/persister.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Currently we support the following, although we highly recommend you contribute

.. automethod:: __init__

.. autoclass:: burr.integrations.persisters.b_mongodb.MongoDBPersister
.. autoclass:: burr.integrations.persisters.b_mongodb.MongoDBBasePersister
:members:

.. automethod:: __init__
Expand Down
15 changes: 13 additions & 2 deletions tests/integrations/persisters/test_b_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import pytest

from burr.core import state
from burr.integrations.persisters.b_mongodb import MongoDBPersister
from burr.integrations.persisters.b_mongodb import MongoDBBasePersister, MongoDBPersister

if not os.environ.get("BURR_CI_INTEGRATION_TESTS") == "true":
pytest.skip("Skipping integration tests", allow_module_level=True)


@pytest.fixture
def mongodb_persister():
persister = MongoDBPersister(
persister = MongoDBBasePersister.from_values(
uri="mongodb://localhost:27017", db_name="testdb", collection_name="testcollection"
)
yield persister
Expand All @@ -35,3 +35,14 @@ def test_list_app_ids(mongodb_persister):
def test_load_nonexistent_key(mongodb_persister):
state_data = mongodb_persister.load("pk", "nonexistent_key")
assert state_data is None


def test_backwards_compatible_persister():
persister = MongoDBPersister(
uri="mongodb://localhost:27017", db_name="testdb", collection_name="backwardscompatible"
)
persister.save("pk", "app_id", 5, "pos", state.State({"a": 5, "b": 5}), "completed")
data = persister.load("pk", "app_id", 5)
assert data["state"].get_all() == {"a": 5, "b": 5}

persister.collection.drop()

0 comments on commit 1c36e6c

Please sign in to comment.