From 7be933f6b8ef7fdce19ace91f97920dea95f72eb Mon Sep 17 00:00:00 2001 From: Pratyush Verma Date: Thu, 12 Sep 2024 17:39:07 +0100 Subject: [PATCH] Use correct serde in postgres checkpoint --- .../langgraph/checkpoint/postgres/aio.py | 10 +++++++--- .../langgraph/checkpoint/postgres/base.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py index 9c246ef74..9aa08ad0f 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py @@ -57,7 +57,11 @@ def __init__( @classmethod @asynccontextmanager async def from_conn_string( - cls, conn_string: str, *, pipeline: bool = False + cls, + conn_string: str, + *, + pipeline: bool = False, + serde: Optional[SerializerProtocol] = None, ) -> AsyncIterator["AsyncPostgresSaver"]: """Create a new PostgresSaver instance from a connection string. @@ -73,9 +77,9 @@ async def from_conn_string( ) as conn: if pipeline: async with conn.pipeline() as pipe: - yield AsyncPostgresSaver(conn, pipe) + yield AsyncPostgresSaver(conn=conn, pipe=pipe, serde=serde) else: - yield AsyncPostgresSaver(conn) + yield AsyncPostgresSaver(conn=conn, serde=serde) async def setup(self) -> None: """Set up the checkpoint database asynchronously. diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py index 0dd7d7733..e3480643b 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py @@ -224,7 +224,7 @@ def _dump_writes( ] def _load_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - return self.jsonplus_serde.loads(self.jsonplus_serde.dumps(metadata)) + return self.serde.loads(self.serde.dumps(metadata)) def _dump_metadata(self, metadata) -> str: serialized_metadata_type, serialized_metadata = self.jsonplus_serde.dumps_typed(