Skip to content

Commit

Permalink
Merge pull request #2413 from langchain-ai/vb/fix-pipeline
Browse files Browse the repository at this point in the history
checkpoint-postgres: handle cases when conn.pipeline is not supported
  • Loading branch information
nfcampos authored Nov 18, 2024
2 parents f5bb2a3 + f807b73 commit a2d6837
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 10 deletions.
25 changes: 20 additions & 5 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Iterator, Optional, Sequence, Union

from langchain_core.runnables import RunnableConfig
from psycopg import Connection, Cursor, Pipeline
from psycopg import Capabilities, Connection, Cursor, Pipeline
from psycopg.errors import UndefinedTable
from psycopg.rows import DictRow, dict_row
from psycopg.types.json import Jsonb
Expand Down Expand Up @@ -52,6 +52,7 @@ def __init__(
self.conn = conn
self.pipe = pipe
self.lock = threading.Lock()
self.supports_pipeline = Capabilities().has_pipeline()

@classmethod
@contextmanager
Expand Down Expand Up @@ -365,6 +366,13 @@ def put_writes(

@contextmanager
def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
"""Create a database cursor as a context manager.
Args:
pipeline (bool): whether to use pipeline for the DB operations inside the context manager.
Will be applied regardless of whether the PostgresSaver instance was initialized with a pipeline.
If pipeline mode is not supported, will fall back to using transaction context manager.
"""
with _get_connection(self.conn) as conn:
if self.pipe:
# a connection in pipeline mode can be used concurrently
Expand All @@ -379,10 +387,17 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
elif pipeline:
# a connection not in pipeline mode can only be used by one
# thread/coroutine at a time, so we acquire a lock
with self.lock, conn.pipeline(), conn.cursor(
binary=True, row_factory=dict_row
) as cur:
yield cur
if self.supports_pipeline:
with self.lock, conn.pipeline(), conn.cursor(
binary=True, row_factory=dict_row
) as cur:
yield cur
else:
# Use connection's transaction context manager when pipeline mode not supported
with self.lock, conn.transaction(), conn.cursor(
binary=True, row_factory=dict_row
) as cur:
yield cur
else:
with self.lock, conn.cursor(binary=True, row_factory=dict_row) as cur:
yield cur
Expand Down
25 changes: 20 additions & 5 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Union

from langchain_core.runnables import RunnableConfig
from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline
from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline, Capabilities
from psycopg.errors import UndefinedTable
from psycopg.rows import DictRow, dict_row
from psycopg.types.json import Jsonb
Expand Down Expand Up @@ -55,6 +55,7 @@ def __init__(
self.pipe = pipe
self.lock = asyncio.Lock()
self.loop = asyncio.get_running_loop()
self.supports_pipeline = Capabilities().has_pipeline()

@classmethod
@asynccontextmanager
Expand Down Expand Up @@ -323,6 +324,13 @@ async def aput_writes(
async def _cursor(
self, *, pipeline: bool = False
) -> AsyncIterator[AsyncCursor[DictRow]]:
"""Create a database cursor as a context manager.
Args:
pipeline (bool): whether to use pipeline for the DB operations inside the context manager.
Will be applied regardless of whether the AsyncPostgresSaver instance was initialized with a pipeline.
If pipeline mode is not supported, will fall back to using transaction context manager.
"""
async with _get_connection(self.conn) as conn:
if self.pipe:
# a connection in pipeline mode can be used concurrently
Expand All @@ -337,10 +345,17 @@ async def _cursor(
elif pipeline:
# a connection not in pipeline mode can only be used by one
# thread/coroutine at a time, so we acquire a lock
async with self.lock, conn.pipeline(), conn.cursor(
binary=True, row_factory=dict_row
) as cur:
yield cur
if self.supports_pipeline:
async with self.lock, conn.pipeline(), conn.cursor(
binary=True, row_factory=dict_row
) as cur:
yield cur
else:
# Use connection's transaction context manager when pipeline mode not supported
async with self.lock, conn.transaction(), conn.cursor(
binary=True, row_factory=dict_row
) as cur:
yield cur
else:
async with self.lock, conn.cursor(
binary=True, row_factory=dict_row
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class BasePostgresSaver(BaseCheckpointSaver[str]):
INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL

jsonplus_serde = JsonPlusSerializer()
supports_pipeline: bool

def _load_checkpoint(
self,
Expand Down

0 comments on commit a2d6837

Please sign in to comment.