Skip to content

Commit

Permalink
Performance improvements in checkpointer libs
Browse files Browse the repository at this point in the history
- Use sha1 instead of md5 for hashing (faster in python 3.x)
- Use orjson instead of json for json dumping (sadly can't use for json loading)
  • Loading branch information
nfcampos committed Sep 11, 2024
1 parent 620a82b commit 8426e10
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from hashlib import md5
from hashlib import sha1
from typing import Any, List, Optional, Tuple

from langchain_core.runnables import RunnableConfig
Expand Down Expand Up @@ -245,7 +245,7 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) ->
current_v = int(current.split(".")[0])
next_v = current_v + 1
try:
next_h = md5(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest()
next_h = sha1(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest()
except EmptyChannelError:
next_h = ""
return f"{next_v:032}.{next_h}"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sqlite3
import threading
from contextlib import closing, contextmanager
from hashlib import md5
from hashlib import sha1
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple

from langchain_core.runnables import RunnableConfig
Expand Down Expand Up @@ -515,7 +515,7 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) ->
current_v = int(current.split(".")[0])
next_v = current_v + 1
try:
next_h = md5(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest()
next_h = sha1(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest()
except EmptyChannelError:
next_h = ""
return f"{next_v:032}.{next_h}"
26 changes: 26 additions & 0 deletions libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from contextlib import asynccontextmanager
from hashlib import sha1
from typing import (
Any,
AsyncIterator,
Expand All @@ -22,10 +23,12 @@
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
EmptyChannelError,
SerializerProtocol,
get_checkpoint_id,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import ChannelProtocol
from langgraph.checkpoint.sqlite.utils import search_where

T = TypeVar("T", bound=callable)
Expand Down Expand Up @@ -498,3 +501,26 @@ async def aput_writes(
for idx, (channel, value) in enumerate(writes)
],
)

def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
"""Generate the next version ID for a channel.
This method creates a new version identifier for a channel based on its current version.
Args:
current (Optional[str]): The current version identifier of the channel.
channel (BaseChannel): The channel being versioned.
Returns:
str: The next version identifier, which is guaranteed to be monotonically increasing.
"""
if current is None:
current_v = 0
else:
current_v = int(current.split(".")[0])
next_v = current_v + 1
try:
next_h = sha1(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest()
except EmptyChannelError:
next_h = ""
return f"{next_v:032}.{next_h}"
18 changes: 17 additions & 1 deletion libs/checkpoint/langgraph/checkpoint/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from functools import partial
from hashlib import sha1
from types import TracebackType
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple

Expand All @@ -14,10 +15,11 @@
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
EmptyChannelError,
SerializerProtocol,
get_checkpoint_id,
)
from langgraph.checkpoint.serde.types import TASKS
from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol


class MemorySaver(
Expand Down Expand Up @@ -457,3 +459,17 @@ async def aput_writes(
return await asyncio.get_running_loop().run_in_executor(
None, self.put_writes, config, writes, task_id
)

def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
if current is None:
current_v = 0
elif isinstance(current, int):
current_v = current
else:
current_v = int(current.split(".")[0])
next_v = current_v + 1
try:
next_h = sha1(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest()
except EmptyChannelError:
next_h = ""
return f"{next_v:032}.{next_h}"
17 changes: 11 additions & 6 deletions libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
IPv6Network,
)
from typing import Any, Optional
from uuid import UUID

import orjson
from langchain_core.load.load import Reviver
from langchain_core.load.serializable import Serializable
from zoneinfo import ZoneInfo
Expand Down Expand Up @@ -58,14 +58,14 @@ def _default(self, obj):
return self._encode_constructor_args(
obj.__class__, method=[None, "construct"], kwargs=obj.dict()
)
elif hasattr(obj, "_asdict") and callable(obj._asdict):
return self._encode_constructor_args(obj.__class__, kwargs=obj._asdict())
elif isinstance(obj, pathlib.Path):
return self._encode_constructor_args(pathlib.Path, args=obj.parts)
elif isinstance(obj, re.Pattern):
return self._encode_constructor_args(
re.compile, args=[obj.pattern, obj.flags]
)
elif isinstance(obj, UUID):
return self._encode_constructor_args(UUID, args=[obj.hex])
elif isinstance(obj, decimal.Decimal):
return self._encode_constructor_args(decimal.Decimal, args=[str(obj)])
elif isinstance(obj, (set, frozenset, deque)):
Expand Down Expand Up @@ -165,9 +165,7 @@ def _reviver(self, value: dict[str, Any]) -> Any:
return LC_REVIVER(value)

def dumps(self, obj: Any) -> bytes:
return json.dumps(obj, default=self._default, ensure_ascii=False).encode(
"utf-8", "ignore"
)
return orjson.dumps(obj, default=self._default, option=_option)

def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
if isinstance(obj, bytes):
Expand All @@ -190,3 +188,10 @@ def loads_typed(self, data: tuple[str, bytes]) -> Any:
return self.loads(data_)
else:
raise NotImplementedError(f"Unknown serialization type: {type_}")


_option = (
orjson.OPT_PASSTHROUGH_DATACLASS
| orjson.OPT_PASSTHROUGH_DATETIME
| orjson.OPT_NON_STR_KEYS
)

0 comments on commit 8426e10

Please sign in to comment.