diff --git a/burr/integrations/persisters/b_mongodb.py b/burr/integrations/persisters/b_mongodb.py index 4913f087..1aaa8bea 100644 --- a/burr/integrations/persisters/b_mongodb.py +++ b/burr/integrations/persisters/b_mongodb.py @@ -10,14 +10,16 @@ logger = logging.getLogger(__name__) -class MongoDBPersister(persistence.BaseStatePersister): +class MongoDBBasePersister(persistence.BaseStatePersister): """A class used to represent a MongoDB Persister. Example usage: .. code-block:: python - persister = MongoDBPersister(uri='mongodb://user:pass@localhost:27017', db_name='mydatabase', collection_name='mystates') + persister = MongoDBBasePersister.from_values(uri='mongodb://user:pass@localhost:27017', + db_name='mydatabase', + collection_name='mystates') persister.save( partition_key='example_partition', app_id='example_app', @@ -28,20 +30,46 @@ class MongoDBPersister(persistence.BaseStatePersister): ) loaded_state = persister.load(partition_key='example_partition', app_id='example_app', sequence_id=1) print(loaded_state) + + Note: this is called MongoDBBasePersister because we had to change the constructor and wanted to make + this change backwards compatible. """ - def __init__( - self, + @classmethod + def from_values( + cls, uri="mongodb://localhost:27017", db_name="mydatabase", collection_name="mystates", serde_kwargs: dict = None, mongo_client_kwargs: dict = None, - ): - """Initializes the MongoDBPersister class.""" + ) -> "MongoDBBasePersister": + """Initializes the MongoDBBasePersister class.""" if mongo_client_kwargs is None: mongo_client_kwargs = {} - self.client = MongoClient(uri, **mongo_client_kwargs) + client = MongoClient(uri, **mongo_client_kwargs) + return cls( + client=client, + db_name=db_name, + collection_name=collection_name, + serde_kwargs=serde_kwargs, + ) + + def __init__( + self, + client, + db_name="mydatabase", + collection_name="mystates", + serde_kwargs: dict = None, + ): + """Initializes the MongoDBBasePersister class. + + :param client: the mongodb client to use + :param db_name: the name of the database to use + :param collection_name: the name of the collection to use + :param serde_kwargs: serializer/deserializer keyword arguments to pass to the state object + """ + self.client = client self.db = self.client[db_name] self.collection = self.db[collection_name] self.serde_kwargs = serde_kwargs or {} @@ -101,3 +129,29 @@ def save( def __del__(self): self.client.close() + + +class MongoDBPersister(MongoDBBasePersister): + """A class used to represent a MongoDB Persister. + + This class is deprecated. Please use MongoDBBasePersister instead. + """ + + def __init__( + self, + uri="mongodb://localhost:27017", + db_name="mydatabase", + collection_name="mystates", + serde_kwargs: dict = None, + mongo_client_kwargs: dict = None, + ): + """Initializes the MongoDBPersister class.""" + if mongo_client_kwargs is None: + mongo_client_kwargs = {} + client = MongoClient(uri, **mongo_client_kwargs) + super(MongoDBPersister, self).__init__( + client=client, + db_name=db_name, + collection_name=collection_name, + serde_kwargs=serde_kwargs, + ) diff --git a/docs/reference/persister.rst b/docs/reference/persister.rst index 3b37a23e..bd0e7f3d 100644 --- a/docs/reference/persister.rst +++ b/docs/reference/persister.rst @@ -53,7 +53,7 @@ Currently we support the following, although we highly recommend you contribute .. automethod:: __init__ -.. autoclass:: burr.integrations.persisters.b_mongodb.MongoDBPersister +.. autoclass:: burr.integrations.persisters.b_mongodb.MongoDBBasePersister :members: .. automethod:: __init__ diff --git a/tests/integrations/persisters/test_b_mongodb.py b/tests/integrations/persisters/test_b_mongodb.py index cf57211e..7a665aa9 100644 --- a/tests/integrations/persisters/test_b_mongodb.py +++ b/tests/integrations/persisters/test_b_mongodb.py @@ -3,7 +3,7 @@ import pytest from burr.core import state -from burr.integrations.persisters.b_mongodb import MongoDBPersister +from burr.integrations.persisters.b_mongodb import MongoDBBasePersister, MongoDBPersister if not os.environ.get("BURR_CI_INTEGRATION_TESTS") == "true": pytest.skip("Skipping integration tests", allow_module_level=True) @@ -11,7 +11,7 @@ @pytest.fixture def mongodb_persister(): - persister = MongoDBPersister( + persister = MongoDBBasePersister.from_values( uri="mongodb://localhost:27017", db_name="testdb", collection_name="testcollection" ) yield persister @@ -35,3 +35,14 @@ def test_list_app_ids(mongodb_persister): def test_load_nonexistent_key(mongodb_persister): state_data = mongodb_persister.load("pk", "nonexistent_key") assert state_data is None + + +def test_backwards_compatible_persister(): + persister = MongoDBPersister( + uri="mongodb://localhost:27017", db_name="testdb", collection_name="backwardscompatible" + ) + persister.save("pk", "app_id", 5, "pos", state.State({"a": 5, "b": 5}), "completed") + data = persister.load("pk", "app_id", 5) + assert data["state"].get_all() == {"a": 5, "b": 5} + + persister.collection.drop()