Skip to content

Commit

Permalink
fix: make database saver classes inheritance-friendly (#2615)
Browse files Browse the repository at this point in the history
Replace hardcoded database saver class names with `cls` in
`from_conn_string` factory methods to improve subclassing support

## Changes
* Replaced direct class instantiations with `cls(conn)` in
`from_conn_string` classmethods across all database implementations
* Updated both synchronous and asynchronous variants for DuckDB,
PostgreSQL, and SQLite savers

## Why
This refactor makes the database saver classes more extensible by
following Python's convention of using `cls` in class methods. This
enables proper inheritance patterns where subclasses can reuse the
factory methods without needing to override them. Previously, the
hardcoded class names would always instantiate the parent class, even
when called from a subclass.

## Testing
The change is backward compatible and doesn't alter existing
functionality. All existing tests should continue to pass as this is
purely a structural refactoring that preserves the current behavior
while improving extensibility.

## Notes
This PR addresses follow up on comments from #2518 - AsyncPostgresSaver
didn't need to be fixed but many of the other DB saver classes did.
  • Loading branch information
phoenixAja authored Dec 4, 2024
1 parent 5fa196a commit aca6710
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def from_conn_string(cls, conn_string: str) -> Iterator["DuckDBSaver"]:
DuckDBSaver: A new DuckDBSaver instance.
"""
with duckdb.connect(conn_string) as conn:
yield DuckDBSaver(conn)
yield cls(conn)

def setup(self) -> None:
"""Set up the checkpoint database asynchronously.
Expand Down
2 changes: 1 addition & 1 deletion libs/checkpoint-duckdb/langgraph/checkpoint/duckdb/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def from_conn_string(
AsyncDuckDBSaver: A new AsyncDuckDBSaver instance.
"""
with duckdb.connect(conn_string) as conn:
yield AsyncDuckDBSaver(conn)
yield cls(conn)

async def setup(self) -> None:
"""Set up the checkpoint database asynchronously.
Expand Down
2 changes: 1 addition & 1 deletion libs/checkpoint-duckdb/langgraph/store/duckdb/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async def from_conn_string(
AsyncDuckDBStore: A new AsyncDuckDBStore instance.
"""
with duckdb.connect(conn_string) as conn:
yield AsyncDuckDBStore(conn)
yield cls(conn)

async def setup(self) -> None:
"""Set up the store database asynchronously.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def from_conn_string(
pipeline: bool = False,
serde: Optional[SerializerProtocol] = None,
) -> AsyncIterator["AsyncPostgresSaver"]:
"""Create a new PostgresSaver instance from a connection string.
"""Create a new AsyncPostgresSaver instance from a connection string.
Args:
conn_string (str): The Postgres connection info string.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def from_conn_string(cls, conn_string: str) -> Iterator["SqliteSaver"]:
check_same_thread=False,
)
) as conn:
yield SqliteSaver(conn)
yield cls(conn)

def setup(self) -> None:
"""Set up the checkpoint database.
Expand Down
2 changes: 1 addition & 1 deletion libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def from_conn_string(
AsyncSqliteSaver: A new AsyncSqliteSaver instance.
"""
async with aiosqlite.connect(conn_string) as conn:
yield AsyncSqliteSaver(conn)
yield cls(conn)

def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the database.
Expand Down

0 comments on commit aca6710

Please sign in to comment.