diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py index 0dd7d7733b..d169362648 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py @@ -1,4 +1,4 @@ -from hashlib import md5 +from hashlib import sha1 from typing import Any, List, Optional, Tuple from langchain_core.runnables import RunnableConfig @@ -245,7 +245,7 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> current_v = int(current.split(".")[0]) next_v = current_v + 1 try: - next_h = md5(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest() + next_h = sha1(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest() except EmptyChannelError: next_h = "" return f"{next_v:032}.{next_h}" diff --git a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py index 8b6f728e99..9a5a3f0334 100644 --- a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py +++ b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py @@ -1,7 +1,7 @@ import sqlite3 import threading from contextlib import closing, contextmanager -from hashlib import md5 +from hashlib import sha1 from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple from langchain_core.runnables import RunnableConfig @@ -515,7 +515,7 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> current_v = int(current.split(".")[0]) next_v = current_v + 1 try: - next_h = md5(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest() + next_h = sha1(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest() except EmptyChannelError: next_h = "" return f"{next_v:032}.{next_h}" diff --git a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py index 3ffd715c17..c78b01110a 100644 --- a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py +++ b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py @@ -1,5 +1,6 @@ import asyncio from contextlib import asynccontextmanager +from hashlib import sha1 from typing import ( Any, AsyncIterator, @@ -22,10 +23,12 @@ Checkpoint, CheckpointMetadata, CheckpointTuple, + EmptyChannelError, SerializerProtocol, get_checkpoint_id, ) from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer +from langgraph.checkpoint.serde.types import ChannelProtocol from langgraph.checkpoint.sqlite.utils import search_where T = TypeVar("T", bound=callable) @@ -498,3 +501,26 @@ async def aput_writes( for idx, (channel, value) in enumerate(writes) ], ) + + def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str: + """Generate the next version ID for a channel. + + This method creates a new version identifier for a channel based on its current version. + + Args: + current (Optional[str]): The current version identifier of the channel. + channel (BaseChannel): The channel being versioned. + + Returns: + str: The next version identifier, which is guaranteed to be monotonically increasing. + """ + if current is None: + current_v = 0 + else: + current_v = int(current.split(".")[0]) + next_v = current_v + 1 + try: + next_h = sha1(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest() + except EmptyChannelError: + next_h = "" + return f"{next_v:032}.{next_h}" diff --git a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py index 372e5ba2f6..7ca0e31486 100644 --- a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py @@ -2,6 +2,7 @@ from collections import defaultdict from contextlib import AbstractAsyncContextManager, AbstractContextManager from functools import partial +from hashlib import sha1 from types import TracebackType from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple @@ -14,10 +15,11 @@ Checkpoint, CheckpointMetadata, CheckpointTuple, + EmptyChannelError, SerializerProtocol, get_checkpoint_id, ) -from langgraph.checkpoint.serde.types import TASKS +from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol class MemorySaver( @@ -457,3 +459,17 @@ async def aput_writes( return await asyncio.get_running_loop().run_in_executor( None, self.put_writes, config, writes, task_id ) + + def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str: + if current is None: + current_v = 0 + elif isinstance(current, int): + current_v = current + else: + current_v = int(current.split(".")[0]) + next_v = current_v + 1 + try: + next_h = sha1(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest() + except EmptyChannelError: + next_h = "" + return f"{next_v:032}.{next_h}" diff --git a/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py b/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py index e777a81764..90540e2b73 100644 --- a/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py +++ b/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py @@ -17,8 +17,8 @@ IPv6Network, ) from typing import Any, Optional -from uuid import UUID +import orjson from langchain_core.load.load import Reviver from langchain_core.load.serializable import Serializable from zoneinfo import ZoneInfo @@ -58,14 +58,14 @@ def _default(self, obj): return self._encode_constructor_args( obj.__class__, method=[None, "construct"], kwargs=obj.dict() ) + elif hasattr(obj, "_asdict") and callable(obj._asdict): + return self._encode_constructor_args(obj.__class__, kwargs=obj._asdict()) elif isinstance(obj, pathlib.Path): return self._encode_constructor_args(pathlib.Path, args=obj.parts) elif isinstance(obj, re.Pattern): return self._encode_constructor_args( re.compile, args=[obj.pattern, obj.flags] ) - elif isinstance(obj, UUID): - return self._encode_constructor_args(UUID, args=[obj.hex]) elif isinstance(obj, decimal.Decimal): return self._encode_constructor_args(decimal.Decimal, args=[str(obj)]) elif isinstance(obj, (set, frozenset, deque)): @@ -165,9 +165,7 @@ def _reviver(self, value: dict[str, Any]) -> Any: return LC_REVIVER(value) def dumps(self, obj: Any) -> bytes: - return json.dumps(obj, default=self._default, ensure_ascii=False).encode( - "utf-8", "ignore" - ) + return orjson.dumps(obj, default=self._default, option=_option) def dumps_typed(self, obj: Any) -> tuple[str, bytes]: if isinstance(obj, bytes): @@ -190,3 +188,10 @@ def loads_typed(self, data: tuple[str, bytes]) -> Any: return self.loads(data_) else: raise NotImplementedError(f"Unknown serialization type: {type_}") + + +_option = ( + orjson.OPT_PASSTHROUGH_DATACLASS + | orjson.OPT_PASSTHROUGH_DATETIME + | orjson.OPT_NON_STR_KEYS +)