Skip to content

Commit

Permalink
added suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
arpitgupta-it committed Nov 16, 2024
1 parent f8e3145 commit 101b88b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
7 changes: 4 additions & 3 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2208,13 +2208,14 @@ def with_state_persister(
if on_every != "step":
raise ValueError(f"on_every {on_every} not supported")

# Check if 'is_initialized' exists and whether it returns False, indicating the persister is uninitialized
if hasattr(persister, 'is_initialized'):
# Check if 'is_initialized' exists and is False; raise RuntimeError, else continue if not implemented
try:
if not persister.is_initialized():
raise RuntimeError(
"RuntimeError: Uninitialized persister. Make sure to call .initialize() before passing it to the ApplicationBuilder"
)
# If the persister is valid and initialized, add it to lifecycle adapters
except NotImplementedError:
pass
if not isinstance(persister, persistence.BaseStateSaver):
self.lifecycle_adapters.append(persister)
else:
Expand Down
14 changes: 13 additions & 1 deletion tests/core/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
_validate_start,
)
from burr.core.graph import Graph, GraphBuilder, Transition
from burr.core.persistence import BaseStatePersister, DevNullPersister, PersistedStateData
from burr.core.persistence import BaseStatePersister, DevNullPersister, PersistedStateData, SQLLitePersister
from burr.core.typing import TypingSystem
from burr.lifecycle import (
PostRunStepHook,
Expand Down Expand Up @@ -3230,3 +3230,15 @@ def test_builder_captures_typing_system():
_, _, state = app.run(halt_after=["result"])
assert isinstance(state.data, CounterState)
assert state.data["count"] == 10

# Define a mock persister that does not implement is_initialized
class PersisterWithoutIsInitialized:
def initialize(self):
pass

def test_with_state_persister_no_is_initialized_method():
builder = ApplicationBuilder()
persister = PersisterWithoutIsInitialized()

# Add the persister to the builder, expecting no exceptions
builder.with_state_persister(persister) # No exception should be raised

0 comments on commit 101b88b

Please sign in to comment.