diff --git a/burr/core/application.py b/burr/core/application.py
index f9aa54a2..d92878c9 100644
--- a/burr/core/application.py
+++ b/burr/core/application.py
@@ -2248,9 +2248,19 @@ def with_state_persister(
         """
         if on_every != "step":
             raise ValueError(f"on_every {on_every} not supported")
+
         if not isinstance(persister, persistence.BaseStateSaver):
             self.lifecycle_adapters.append(persister)
         else:
+            # Reject uninitialized persisters; skip the check if is_initialized() is unimplemented.
+            try:
+                if not persister.is_initialized():
+                    raise RuntimeError(
+                        "Uninitialized persister. Make sure to call .initialize() before passing it to "
+                        "the ApplicationBuilder."
+                    )
+            except NotImplementedError:
+                pass
             self.lifecycle_adapters.append(persistence.PersisterHook(persister))
         return self
 
diff --git a/burr/core/persistence.py b/burr/core/persistence.py
index 22c26174..36c8a23c 100644
--- a/burr/core/persistence.py
+++ b/burr/core/persistence.py
@@ -54,6 +54,10 @@ def initialize(self):
         """Initializes the app for saving, set up any databases, etc.. you want to here."""
         pass
 
+    def is_initialized(self) -> bool:
+        """Check if the persister has been initialized appropriately."""
+        raise NotImplementedError("Implement this method in your subclass if you need to.")
+
     @abc.abstractmethod
     def save(
         self,
@@ -88,8 +92,6 @@ class BaseStatePersister(BaseStateLoader, BaseStateSaver, metaclass=ABCMeta):
     Extend this class if you want an easy way to implement custom state storage.
     """
 
-    pass
-
 
 class PersisterHook(PostRunStepHook):
     """Wrapper class for bridging the persistence interface with lifecycle hooks. This is used internally."""
@@ -164,6 +166,7 @@ def __init__(
             db_path, **connect_kwargs if connect_kwargs is not None else {}
         )
         self.serde_kwargs = serde_kwargs or {}
+        self._initialized = False
 
     def create_table_if_not_exists(self, table_name: str):
         """Helper function to create the table where things are stored if it doesn't exist."""
@@ -192,6 +195,21 @@ def initialize(self):
         """Creates the table if it doesn't exist"""
         # Usage
         self.create_table_if_not_exists(self.table_name)
+        self._initialized = True
+
+    def is_initialized(self) -> bool:
+        """This checks to see if the table has been created in the database or not.
+        It defaults to using the initialized field, else queries the database to see if the table exists.
+        It then sets the initialized field to True if the table exists.
+        """
+        if self._initialized:
+            return True
+        cursor = self.connection.cursor()
+        cursor.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (self.table_name,)
+        )
+        self._initialized = cursor.fetchone() is not None
+        return self._initialized
 
     def list_app_ids(self, partition_key: Optional[str], **kwargs) -> list[str]:
         partition_key = (
diff --git a/burr/integrations/persisters/postgresql.py b/burr/integrations/persisters/postgresql.py
index b3263b6b..c5d152c2 100644
--- a/burr/integrations/persisters/postgresql.py
+++ b/burr/integrations/persisters/postgresql.py
@@ -84,6 +84,7 @@ def __init__(self, connection, table_name: str = "burr_state", serde_kwargs: dic
         self.table_name = table_name
         self.connection = connection
         self.serde_kwargs = serde_kwargs or {}
+        self._initialized = False
 
     def set_serde_kwargs(self, serde_kwargs: dict):
         """Sets the serde_kwargs for the persister."""
@@ -115,14 +116,30 @@ def create_table(self, table_name: str):
     def initialize(self):
         """Creates the table"""
         self.create_table(self.table_name)
+        self._initialized = True
+
+    def is_initialized(self) -> bool:
+        """This checks to see if the table has been created in the database or not.
+        It defaults to using the initialized field, else asks the database whether the table name
+        resolves via the connection's search_path; the initialized field caches the result.
+        """
+        if self._initialized:
+            return True
+        cursor = self.connection.cursor()
+        cursor.execute(
+            "SELECT to_regclass(%s) IS NOT NULL",
+            (self.table_name,),
+        )
+        self._initialized = cursor.fetchone()[0]
+        return self._initialized
 
     def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
         """Lists the app_ids for a given partition_key."""
         cursor = self.connection.cursor()
         cursor.execute(
             f"SELECT DISTINCT app_id, created_at FROM {self.table_name} "
-            f"WHERE partition_key = %s "
-            f"ORDER BY created_at DESC",
+            "WHERE partition_key = %s "
+            "ORDER BY created_at DESC",
             (partition_key,),
         )
         app_ids = [row[0] for row in cursor.fetchall()]
diff --git a/tests/core/test_application.py b/tests/core/test_application.py
index 11256f16..8861f912 100644
--- a/tests/core/test_application.py
+++ b/tests/core/test_application.py
@@ -41,7 +41,12 @@
     _validate_start,
 )
 from burr.core.graph import Graph, GraphBuilder, Transition
-from burr.core.persistence import BaseStatePersister, DevNullPersister, PersistedStateData
+from burr.core.persistence import (
+    BaseStatePersister,
+    DevNullPersister,
+    PersistedStateData,
+    SQLLitePersister,
+)
 from burr.core.typing import TypingSystem
 from burr.lifecycle import (
     PostRunStepHook,
@@ -3230,3 +3235,46 @@ def test_builder_captures_typing_system():
     _, _, state = app.run(halt_after=["result"])
     assert isinstance(state.data, CounterState)
     assert state.data["count"] == 10
+
+
+def test_with_state_persister_is_not_initialized_error():
+    builder = ApplicationBuilder()
+    persister = SQLLitePersister(db_path=":memory:", table_name="test_table")
+
+    with pytest.raises(RuntimeError):
+        # we have not initialized
+        builder.with_state_persister(persister)
+
+
+def test_with_state_persister_is_initialized_not_implemented():
+    builder = ApplicationBuilder()
+
+    class FakePersister(BaseStatePersister):
+        # does not implement is_initialized
+        def list_app_ids(self, partition_key, **kwargs):
+            return []
+
+        def save(
+            self,
+            partition_key: Optional[str],
+            app_id: str,
+            sequence_id: int,
+            position: str,
+            state: State,
+            status: Literal["completed", "failed"],
+            **kwargs,
+        ):
+            pass
+
+        def load(
+            self,
+            partition_key: str,
+            app_id: Optional[str],
+            sequence_id: Optional[int] = None,
+            **kwargs,
+        ):
+            return None
+
+    persister = FakePersister()
+    # Add the persister to the builder, expecting no exceptions
+    builder.with_state_persister(persister)
diff --git a/tests/core/test_persistence.py b/tests/core/test_persistence.py
index 7f7acc0d..bd43c7fb 100644
--- a/tests/core/test_persistence.py
+++ b/tests/core/test_persistence.py
@@ -35,6 +35,24 @@ def test_persistence_lists_app_ids(persistence):
     assert set(app_ids) == set(["app_id1", "app_id2"])
 
 
+def test_persistence_is_initialized_false(persistence):
+    assert not persistence.is_initialized()
+
+
+def test_persistence_is_initialized_true(persistence):
+    persistence.initialize()
+    assert persistence.is_initialized()
+
+
+def test_persistence_is_initialized_true_new_connection(tmp_path):
+    db_path = tmp_path / "test.db"
+    p = SQLLitePersister(db_path=db_path, table_name="test_table")
+    p.initialize()
+    assert p.is_initialized()
+    p2 = SQLLitePersister(db_path=db_path, table_name="test_table")
+    assert p2.is_initialized()
+
+
 @pytest.mark.parametrize(
     "method_name,kwargs",
     [
diff --git a/tests/integrations/persisters/test_postgresql.py b/tests/integrations/persisters/test_postgresql.py
index 345d7bac..c45c6fa0 100644
--- a/tests/integrations/persisters/test_postgresql.py
+++ b/tests/integrations/persisters/test_postgresql.py
@@ -41,3 +41,29 @@ def test_list_app_ids(postgresql_persister):
 def test_load_nonexistent_key(postgresql_persister):
     state_data = postgresql_persister.load("pk", "nonexistent_key")
     assert state_data is None
+
+
+def test_is_initialized(postgresql_persister):
+    """Tests that a new connection also returns True for is_initialized."""
+    assert postgresql_persister.is_initialized()
+    persister2 = PostgreSQLPersister.from_values(
+        db_name="postgres",
+        user="postgres",
+        password="postgres",
+        host="localhost",
+        port=5432,
+        table_name="testtable",
+    )
+    assert persister2.is_initialized()
+
+
+def test_is_initialized_false():
+    persister = PostgreSQLPersister.from_values(
+        db_name="postgres",
+        user="postgres",
+        password="postgres",
+        host="localhost",
+        port=5432,
+        table_name="testtable2",
+    )
+    assert not persister.is_initialized()