Skip to content

Commit

Permalink
Create indexes on checkpointer tables.
Browse files Browse the repository at this point in the history
  • Loading branch information
tjni committed Dec 10, 2024
1 parent 6ddac30 commit a1d0712
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 57 deletions.
9 changes: 9 additions & 0 deletions langgraph/checkpoint/mysql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);""",
"ALTER TABLE checkpoint_blobs MODIFY COLUMN `blob` LONGBLOB;",
"""
CREATE INDEX checkpoints_thread_id_idx ON checkpoints (thread_id);
""",
"""
CREATE INDEX checkpoint_blobs_thread_id_idx ON checkpoint_blobs (thread_id);
""",
"""
CREATE INDEX checkpoint_writes_thread_id_idx ON checkpoint_writes (thread_id);
""",
]

SELECT_SQL = f"""
Expand Down
30 changes: 12 additions & 18 deletions langgraph/store/mysql/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import aiomysql # type: ignore
import orjson
import pymysql
import pymysql.constants.ER

from langgraph.checkpoint.mysql import _ainternal
from langgraph.store.base import (
Expand Down Expand Up @@ -112,24 +111,19 @@ async def setup(self) -> None:
"""

async def _get_version(cur: aiomysql.DictCursor, table: str) -> int:
try:
await cur.execute(f"SELECT v FROM {table} ORDER BY v DESC LIMIT 1")
row = await cur.fetchone()
if row is None:
version = -1
else:
version = row["v"]
except pymysql.ProgrammingError as e:
if e.args[0] != pymysql.constants.ER.NO_SUCH_TABLE:
raise
version = -1
await cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {table} (
v INTEGER PRIMARY KEY
)
"""
await cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {table} (
v INTEGER PRIMARY KEY
)
"""
)
await cur.execute(f"SELECT v FROM {table} ORDER BY v DESC LIMIT 1")
row = await cur.fetchone()
if row is None:
version = -1
else:
version = row["v"]
return version

async with _ainternal.get_connection(self.conn) as conn:
Expand Down
49 changes: 20 additions & 29 deletions langgraph/store/mysql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,6 @@ def __init__(
self.conn = conn
self.lock = threading.Lock()

@staticmethod
def _is_no_such_table_error(e: Exception) -> bool:
raise NotImplementedError

@staticmethod
def _get_cursor_from_connection(conn: _internal.C) -> R:
raise NotImplementedError
Expand Down Expand Up @@ -440,32 +436,27 @@ def setup(self) -> None:
already exist and runs database migrations. It MUST be called directly by the user
the first time the store is used.
"""
with self._cursor() as cur:
try:
cur.execute("SELECT v FROM store_migrations ORDER BY v DESC LIMIT 1")
row = cur.fetchone()
if row is None:
version = -1
else:
version = row["v"]
print(f"Version: {version}")
except Exception as e:
if not self._is_no_such_table_error(e):
raise
cast(Any, self.conn).rollback()
version = -1
# Create store_migrations table if it doesn't exist
cur.execute(
"""
CREATE TABLE IF NOT EXISTS store_migrations (
v INTEGER PRIMARY KEY
)
"""

def _get_version(cur: R, table: str) -> int:
cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {table} (
v INTEGER PRIMARY KEY
)
for v, migration in enumerate(
self.MIGRATIONS[version + 1 :], start=version + 1
):
cur.execute(migration)
"""
)
cur.execute(f"SELECT v FROM {table} ORDER BY v DESC LIMIT 1")
row = cur.fetchone()
if row is None:
version = -1
else:
version = row["v"]
return version

with self._cursor() as cur:
version = _get_version(cur, table="store_migrations")
for v, sql in enumerate(self.MIGRATIONS[version + 1 :], start=version + 1):
cur.execute(sql)
cur.execute("INSERT INTO store_migrations (v) VALUES (%s)", (v,))


Expand Down
9 changes: 0 additions & 9 deletions langgraph/store/mysql/pymysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any, Iterator

import pymysql
import pymysql.constants.ER
from pymysql.cursors import DictCursor
from typing_extensions import Self, override

Expand Down Expand Up @@ -49,14 +48,6 @@ def from_conn_string(
) as conn:
yield cls(conn)

@override
@staticmethod
def _is_no_such_table_error(e: Exception) -> bool:
return (
isinstance(e, pymysql.ProgrammingError)
and e.args[0] == pymysql.constants.ER.NO_SUCH_TABLE
)

@override
@staticmethod
def _get_cursor_from_connection(conn: pymysql.Connection) -> DictCursor:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph-checkpoint-mysql"
version = "2.0.5"
version = "2.0.6"
description = "Library with a MySQL implementation of LangGraph checkpoint saver."
authors = ["Theodore Ni <[email protected]>"]
license = "MIT"
Expand Down

0 comments on commit a1d0712

Please sign in to comment.