Skip to content

Commit

Permalink
Fixup initial provisioning of aio postgres db (#2571) (#2600)
Browse files Browse the repository at this point in the history
fixes #2570

---------

Co-authored-by: Tai Groot <[email protected]>
  • Loading branch information
hinthornw and taigrr authored Dec 3, 2024
1 parent 64b99c1 commit 4332a95
Show file tree
Hide file tree
Showing 5 changed files with 422 additions and 208 deletions.
18 changes: 8 additions & 10 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from langchain_core.runnables import RunnableConfig
from psycopg import Capabilities, Connection, Cursor, Pipeline
from psycopg.errors import UndefinedTable
from psycopg.rows import DictRow, dict_row
from psycopg.types.json import Jsonb
from psycopg_pool import ConnectionPool
Expand Down Expand Up @@ -76,16 +75,15 @@ def setup(self) -> None:
the first time checkpointer is used.
"""
with self._cursor() as cur:
try:
row = cur.execute(
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
).fetchone()
if row is None:
version = -1
else:
version = row["v"]
except UndefinedTable:
cur.execute(self.MIGRATIONS[0])
results = cur.execute(
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
)
row = results.fetchone()
if row is None:
version = -1
else:
version = row["v"]
for v, migration in zip(
range(version + 1, len(self.MIGRATIONS)),
self.MIGRATIONS[version + 1 :],
Expand Down
19 changes: 8 additions & 11 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from langchain_core.runnables import RunnableConfig
from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline, Capabilities
from psycopg.errors import UndefinedTable
from psycopg.rows import DictRow, dict_row
from psycopg.types.json import Jsonb
from psycopg_pool import AsyncConnectionPool
Expand Down Expand Up @@ -81,17 +80,15 @@ async def setup(self) -> None:
the first time checkpointer is used.
"""
async with self._cursor() as cur:
try:
results = await cur.execute(
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
)
row = await results.fetchone()
if row is None:
version = -1
else:
version = row["v"]
except UndefinedTable:
await cur.execute(self.MIGRATIONS[0])
results = await cur.execute(
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
)
row = await results.fetchone()
if row is None:
version = -1
else:
version = row["v"]
for v, migration in zip(
range(version + 1, len(self.MIGRATIONS)),
self.MIGRATIONS[version + 1 :],
Expand Down
1 change: 1 addition & 0 deletions libs/checkpoint-postgres/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from tests.embed_test_utils import CharacterEmbeddings

DEFAULT_POSTGRES_URI = "postgres://postgres:postgres@localhost:5441/"
DEFAULT_URI = "postgres://postgres:postgres@localhost:5441/postgres?sslmode=disable"


Expand Down
303 changes: 209 additions & 94 deletions libs/checkpoint-postgres/tests/test_async.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
# type: ignore

from contextlib import asynccontextmanager
from typing import Any
from uuid import uuid4

import pytest
from langchain_core.runnables import RunnableConfig
from psycopg import AsyncConnection
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool

from langgraph.checkpoint.base import (
Checkpoint,
Expand All @@ -10,104 +17,212 @@
empty_checkpoint,
)
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from tests.conftest import DEFAULT_URI


class TestAsyncPostgresSaver:
@pytest.fixture(autouse=True)
async def setup(self) -> None:
# objects for test setup
self.config_1: RunnableConfig = {
"configurable": {
"thread_id": "thread-1",
# for backwards compatibility testing
"thread_ts": "1",
"checkpoint_ns": "",
}
from tests.conftest import DEFAULT_POSTGRES_URI


@asynccontextmanager
async def _pool_saver():
    """Yield an AsyncPostgresSaver backed by a connection pool.

    Creates a throwaway database, runs migrations via ``setup()``, and
    drops the database again once the test is done.
    """
    database = f"test_{uuid4().hex[:16]}"
    # provision a throwaway database for this test run
    async with await AsyncConnection.connect(
        DEFAULT_POSTGRES_URI, autocommit=True
    ) as conn:
        await conn.execute(f"CREATE DATABASE {database}")
    try:
        # hand the checkpointer to the test, backed by a pooled connection
        async with AsyncConnectionPool(
            DEFAULT_POSTGRES_URI + database,
            max_size=10,
            kwargs={"autocommit": True, "row_factory": dict_row},
        ) as pool:
            saver = AsyncPostgresSaver(pool)
            await saver.setup()
            yield saver
    finally:
        # tear the throwaway database down even if the test failed
        async with await AsyncConnection.connect(
            DEFAULT_POSTGRES_URI, autocommit=True
        ) as conn:
            await conn.execute(f"DROP DATABASE {database}")


@asynccontextmanager
async def _pipe_saver():
    """Yield an AsyncPostgresSaver that talks to Postgres in pipeline mode.

    A fresh database is created for the test and dropped afterwards. The
    migrations run in their own pipeline before the one handed to the test,
    so setup statements are flushed before any test traffic starts.
    """
    database = f"test_{uuid4().hex[:16]}"
    # provision a throwaway database for this test run
    async with await AsyncConnection.connect(
        DEFAULT_POSTGRES_URI, autocommit=True
    ) as conn:
        await conn.execute(f"CREATE DATABASE {database}")
    try:
        async with await AsyncConnection.connect(
            DEFAULT_POSTGRES_URI + database,
            autocommit=True,
            prepare_threshold=0,
            row_factory=dict_row,
        ) as conn:
            # run migrations in a dedicated pipeline, then open a new one
            # for the checkpointer the test actually exercises
            async with conn.pipeline() as pipe:
                setup_saver = AsyncPostgresSaver(conn, pipe=pipe)
                await setup_saver.setup()
            async with conn.pipeline() as pipe:
                yield AsyncPostgresSaver(conn, pipe=pipe)
    finally:
        # tear the throwaway database down even if the test failed
        async with await AsyncConnection.connect(
            DEFAULT_POSTGRES_URI, autocommit=True
        ) as conn:
            await conn.execute(f"DROP DATABASE {database}")


@asynccontextmanager
async def _base_saver():
    """Yield an AsyncPostgresSaver on a plain (non-pooled) connection.

    Creates a throwaway database, runs migrations, and drops the database
    when the context exits.
    """
    database = f"test_{uuid4().hex[:16]}"
    # provision a throwaway database for this test run
    async with await AsyncConnection.connect(
        DEFAULT_POSTGRES_URI, autocommit=True
    ) as conn:
        await conn.execute(f"CREATE DATABASE {database}")
    try:
        async with await AsyncConnection.connect(
            DEFAULT_POSTGRES_URI + database,
            autocommit=True,
            prepare_threshold=0,
            row_factory=dict_row,
        ) as conn:
            saver = AsyncPostgresSaver(conn)
            await saver.setup()
            yield saver
    finally:
        # tear the throwaway database down even if the test failed
        async with await AsyncConnection.connect(
            DEFAULT_POSTGRES_URI, autocommit=True
        ) as conn:
            await conn.execute(f"DROP DATABASE {database}")


@asynccontextmanager
async def _saver(name: str):
    """Dispatch to the saver fixture matching *name* ('base'/'pool'/'pipe').

    An unrecognized name yields nothing, which makes the context manager
    raise at entry — same as the original if/elif chain falling through.
    """
    factory = {
        "base": _base_saver,
        "pool": _pool_saver,
        "pipe": _pipe_saver,
    }.get(name)
    if factory is not None:
        async with factory() as saver:
            yield saver


@pytest.fixture
def test_data():
    """Fixture providing test data for checkpoint tests.

    Returns a dict with three parallel lists — ``configs``, ``checkpoints``
    and ``metadata`` — indexed consistently, so ``configs[i]`` pairs with
    ``checkpoints[i]`` and ``metadata[i]``.

    NOTE(review): this span of the scraped diff interleaved the removed
    class-based fixture (``self.config_* = ...``) with the added function
    fixture, producing invalid Python; reconstructed here from the added
    lines only.
    """
    config_1: RunnableConfig = {
        "configurable": {
            "thread_id": "thread-1",
            # for backwards compatibility testing
            "thread_ts": "1",
            "checkpoint_ns": "",
        }
    }
    config_2: RunnableConfig = {
        "configurable": {
            "thread_id": "thread-2",
            "checkpoint_id": "2",
            "checkpoint_ns": "",
        }
    }
    config_3: RunnableConfig = {
        "configurable": {
            "thread_id": "thread-2",
            "checkpoint_id": "2-inner",
            "checkpoint_ns": "inner",
        }
    }

    chkpnt_1: Checkpoint = empty_checkpoint()
    chkpnt_2: Checkpoint = create_checkpoint(chkpnt_1, {}, 1)
    chkpnt_3: Checkpoint = empty_checkpoint()

    metadata_1: CheckpointMetadata = {
        "source": "input",
        "step": 2,
        "writes": {},
        "score": 1,
    }
    metadata_2: CheckpointMetadata = {
        "source": "loop",
        "step": 1,
        "writes": {"foo": "bar"},
        "score": None,
    }
    metadata_3: CheckpointMetadata = {}

    return {
        "configs": [config_1, config_2, config_3],
        "checkpoints": [chkpnt_1, chkpnt_2, chkpnt_3],
        "metadata": [metadata_1, metadata_2, metadata_3],
    }


@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"])
async def test_asearch(request, saver_name: str, test_data) -> None:
    """Search (``alist`` with filters) behaves the same across all saver modes.

    NOTE(review): this span of the scraped diff interleaved the removed
    ``DEFAULT_URI``-based class test with the added parametrized test,
    producing invalid Python; reconstructed here from the added lines only.
    """
    async with _saver(saver_name) as saver:
        configs = test_data["configs"]
        checkpoints = test_data["checkpoints"]
        metadata = test_data["metadata"]

        await saver.aput(configs[0], checkpoints[0], metadata[0], {})
        await saver.aput(configs[1], checkpoints[1], metadata[1], {})
        await saver.aput(configs[2], checkpoints[2], metadata[2], {})

        # call method / assertions
        query_1 = {"source": "input"}  # search by 1 key
        query_2 = {
            "step": 1,
            "writes": {"foo": "bar"},
            "score": None,
        }  # search by multiple keys
        query_3: dict[str, Any] = {}  # search by no keys, return all checkpoints
        query_4 = {"source": "update", "step": 1}  # no match

        search_results_1 = [c async for c in saver.alist(None, filter=query_1)]
        assert len(search_results_1) == 1
        assert search_results_1[0].metadata == metadata[0]

        search_results_2 = [c async for c in saver.alist(None, filter=query_2)]
        assert len(search_results_2) == 1
        assert search_results_2[0].metadata == metadata[1]

        search_results_3 = [c async for c in saver.alist(None, filter=query_3)]
        assert len(search_results_3) == 3

        search_results_4 = [c async for c in saver.alist(None, filter=query_4)]
        assert len(search_results_4) == 0

        # search by config (defaults to checkpoints across all namespaces)
        search_results_5 = [
            c async for c in saver.alist({"configurable": {"thread_id": "thread-2"}})
        ]
        assert len(search_results_5) == 2
        assert {
            search_results_5[0].config["configurable"]["checkpoint_ns"],
            search_results_5[1].config["configurable"]["checkpoint_ns"],
        } == {"", "inner"}


@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"])
async def test_null_chars(request, saver_name: str, test_data) -> None:
    """NUL bytes in metadata values are stripped before storage, in every saver mode."""
    async with _saver(saver_name) as saver:
        first_config = test_data["configs"][0]
        first_checkpoint = test_data["checkpoints"][0]
        config = await saver.aput(
            first_config, first_checkpoint, {"my_key": "\x00abc"}, {}
        )
        # the stored value must come back with the NUL byte removed
        stored = await saver.aget_tuple(config)
        assert stored.metadata["my_key"] == "abc"  # type: ignore
        # filtering must also match against the sanitized value
        matches = [c async for c in saver.alist(None, filter={"my_key": "abc"})]
        assert matches[0].metadata["my_key"] == "abc"
Loading

0 comments on commit 4332a95

Please sign in to comment.