Skip to content

Commit

Permalink
use capabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Nov 18, 2024
1 parent f050515 commit f807b73
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 27 deletions.
16 changes: 3 additions & 13 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import Any, Iterator, Optional, Sequence, Union

from langchain_core.runnables import RunnableConfig
from psycopg import Connection, Cursor, Pipeline
from psycopg.errors import NotSupportedError, UndefinedTable
from psycopg import Capabilities, Connection, Cursor, Pipeline
from psycopg.errors import UndefinedTable
from psycopg.rows import DictRow, dict_row
from psycopg.types.json import Jsonb
from psycopg_pool import ConnectionPool
Expand Down Expand Up @@ -52,6 +52,7 @@ def __init__(
self.conn = conn
self.pipe = pipe
self.lock = threading.Lock()
self.supports_pipeline = Capabilities().has_pipeline()

@classmethod
@contextmanager
Expand Down Expand Up @@ -363,16 +364,6 @@ def put_writes(
),
)

def _check_pipeline_support(self, conn: Connection[DictRow]) -> None:
if self.supports_pipeline is not None:
return

try:
with conn.pipeline():
self.supports_pipeline = True
except NotSupportedError:
self.supports_pipeline = False

@contextmanager
def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
"""Create a database cursor as a context manager.
Expand All @@ -394,7 +385,6 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
if pipeline:
self.pipe.sync()
elif pipeline:
self._check_pipeline_support(conn)
# a connection not in pipeline mode can only be used by one
# thread/coroutine at a time, so we acquire a lock
if self.supports_pipeline:
Expand Down
16 changes: 3 additions & 13 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Union

from langchain_core.runnables import RunnableConfig
from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline
from psycopg.errors import NotSupportedError, UndefinedTable
from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline, Capabilities
from psycopg.errors import UndefinedTable
from psycopg.rows import DictRow, dict_row
from psycopg.types.json import Jsonb
from psycopg_pool import AsyncConnectionPool
Expand Down Expand Up @@ -55,6 +55,7 @@ def __init__(
self.pipe = pipe
self.lock = asyncio.Lock()
self.loop = asyncio.get_running_loop()
self.supports_pipeline = Capabilities().has_pipeline()

@classmethod
@asynccontextmanager
Expand Down Expand Up @@ -319,16 +320,6 @@ async def aput_writes(
async with self._cursor(pipeline=True) as cur:
await cur.executemany(query, params)

async def _check_pipeline_support(self, conn: AsyncConnection[DictRow]) -> None:
if self.supports_pipeline is not None:
return

try:
async with conn.pipeline():
self.supports_pipeline = True
except NotSupportedError:
self.supports_pipeline = False

@asynccontextmanager
async def _cursor(
self, *, pipeline: bool = False
Expand All @@ -352,7 +343,6 @@ async def _cursor(
if pipeline:
await self.pipe.sync()
elif pipeline:
await self._check_pipeline_support(conn)
# a connection not in pipeline mode can only be used by one
# thread/coroutine at a time, so we acquire a lock
if self.supports_pipeline:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class BasePostgresSaver(BaseCheckpointSaver[str]):
INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL

jsonplus_serde = JsonPlusSerializer()
supports_pipeline: Optional[bool] = None
supports_pipeline: bool

def _load_checkpoint(
self,
Expand Down

0 comments on commit f807b73

Please sign in to comment.