diff --git a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py index aea9069b9..067144731 100644 --- a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py @@ -21,7 +21,7 @@ from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol -class MemorySaver( +class InMemorySaver( BaseCheckpointSaver[str], AbstractContextManager, AbstractAsyncContextManager ): """An in-memory checkpoint saver. @@ -29,7 +29,7 @@ class MemorySaver( This checkpoint saver stores checkpoints in memory using a defaultdict. Note: - Only use `MemorySaver` for debugging or testing purposes. + Only use `InMemorySaver` for debugging or testing purposes. For production use cases we recommend installing [langgraph-checkpoint-postgres](https://pypi.org/project/langgraph-checkpoint-postgres/) and using `PostgresSaver` / `AsyncPostgresSaver`. Args: @@ -39,7 +39,7 @@ class MemorySaver( import asyncio - from langgraph.checkpoint.memory import MemorySaver + from langgraph.checkpoint.memory import InMemorySaver from langgraph.graph import StateGraph builder = StateGraph(int) @@ -47,7 +47,7 @@ class MemorySaver( builder.set_entry_point("add_one") builder.set_finish_point("add_one") - memory = MemorySaver() + memory = InMemorySaver() graph = builder.compile(checkpointer=memory) coro = graph.ainvoke(1, {"configurable": {"thread_id": "thread-1"}}) asyncio.run(coro) # Output: 2 @@ -73,7 +73,7 @@ def __init__( self.storage = defaultdict(lambda: defaultdict(dict)) self.writes = defaultdict(dict) - def __enter__(self) -> "MemorySaver": + def __enter__(self) -> "InMemorySaver": return self def __exit__( @@ -84,7 +84,7 @@ def __exit__( ) -> Optional[bool]: return - async def __aenter__(self) -> "MemorySaver": + async def __aenter__(self) -> "InMemorySaver": return self async def __aexit__( @@ -135,15 +135,17 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: pending_writes=[ (id, c, self.serde.loads_typed(v)) for id, c, v in writes ], - parent_config={ - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": parent_checkpoint_id, + parent_config=( + { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": parent_checkpoint_id, + } } - } - if parent_checkpoint_id - else None, + if parent_checkpoint_id + else None + ), ) else: if checkpoints := self.storage[thread_id][checkpoint_ns]: @@ -176,15 +178,17 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: pending_writes=[ (id, c, self.serde.loads_typed(v)) for id, c, v in writes ], - parent_config={ - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": parent_checkpoint_id, + parent_config=( + { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": parent_checkpoint_id, + } } - } - if parent_checkpoint_id - else None, + if parent_checkpoint_id + else None + ), ) def list( @@ -285,15 +289,17 @@ def list( "pending_sends": [self.serde.loads_typed(s) for s in sends], }, metadata=metadata, - parent_config={ - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": parent_checkpoint_id, + parent_config=( + { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": parent_checkpoint_id, + } } - } - if parent_checkpoint_id - else None, + if parent_checkpoint_id + else None + ), pending_writes=[ (id, c, self.serde.loads_typed(v)) for id, c, v in writes ], @@ -474,3 +480,6 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> next_v = current_v + 1 next_h = random.random() return f"{next_v:032}.{next_h:016}" + + +MemorySaver = InMemorySaver # Kept for backwards compatibility diff --git a/libs/checkpoint/pyproject.toml b/libs/checkpoint/pyproject.toml index 376e63422..491624228 100644 --- a/libs/checkpoint/pyproject.toml +++ b/libs/checkpoint/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langgraph-checkpoint" -version = "2.0.1" +version = "2.0.2" description = "Library with base interfaces for LangGraph checkpoint savers." authors = [] license = "MIT" diff --git a/libs/checkpoint/tests/test_memory.py b/libs/checkpoint/tests/test_memory.py index 578437233..8e2fded54 100644 --- a/libs/checkpoint/tests/test_memory.py +++ b/libs/checkpoint/tests/test_memory.py @@ -9,13 +9,13 @@ create_checkpoint, empty_checkpoint, ) -from langgraph.checkpoint.memory import MemorySaver +from langgraph.checkpoint.memory import InMemorySaver class TestMemorySaver: @pytest.fixture(autouse=True) def setup(self) -> None: - self.memory_saver = MemorySaver() + self.memory_saver = InMemorySaver() # objects for test setup self.config_1: RunnableConfig = { diff --git a/libs/langgraph/tests/memory_assert.py b/libs/langgraph/tests/memory_assert.py index 6b44051f7..44d803a92 100644 --- a/libs/langgraph/tests/memory_assert.py +++ b/libs/langgraph/tests/memory_assert.py @@ -12,7 +12,7 @@ SerializerProtocol, copy_checkpoint, ) -from langgraph.checkpoint.memory import MemorySaver +from langgraph.checkpoint.memory import InMemorySaver class NoopSerializer(SerializerProtocol): @@ -23,7 +23,7 @@ def dumps_typed(self, obj: Any) -> tuple[str, bytes]: return "type", obj -class MemorySaverAssertImmutable(MemorySaver): +class MemorySaverAssertImmutable(InMemorySaver): storage_for_copies: defaultdict[str, dict[str, dict[str, Checkpoint]]] def __init__( @@ -64,7 +64,7 @@ def put( return super().put(config, checkpoint, metadata, new_versions) -class MemorySaverAssertCheckpointMetadata(MemorySaver): +class MemorySaverAssertCheckpointMetadata(InMemorySaver): """This custom checkpointer is for verifying that a run's configurable fields are merged with the previous checkpoint config for each step in the run. This is the desired behavior. Because the checkpointer's (a)put() @@ -119,7 +119,7 @@ async def aput( ) -class MemorySaverNoPending(MemorySaver): +class MemorySaverNoPending(InMemorySaver): def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: result = super().get_tuple(config) if result: diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 196fe4223..b930b0b9d 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -53,7 +53,7 @@ CheckpointMetadata, CheckpointTuple, ) -from langgraph.checkpoint.memory import MemorySaver +from langgraph.checkpoint.memory import InMemorySaver from langgraph.constants import ERROR, PULL, PUSH from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt from langgraph.graph import END, Graph @@ -268,11 +268,11 @@ def node_b(state: State) -> State: def test_checkpoint_errors() -> None: - class FaultyGetCheckpointer(MemorySaver): + class FaultyGetCheckpointer(InMemorySaver): def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: raise ValueError("Faulty get_tuple") - class FaultyPutCheckpointer(MemorySaver): + class FaultyPutCheckpointer(InMemorySaver): def put( self, config: RunnableConfig, @@ -282,13 +282,13 @@ def put( ) -> RunnableConfig: raise ValueError("Faulty put") - class FaultyPutWritesCheckpointer(MemorySaver): + class FaultyPutWritesCheckpointer(InMemorySaver): def put_writes( self, config: RunnableConfig, writes: List[Tuple[str, Any]], task_id: str ) -> RunnableConfig: raise ValueError("Faulty put_writes") - class FaultyVersionCheckpointer(MemorySaver): + class FaultyVersionCheckpointer(InMemorySaver): def get_next_version(self, current: Optional[int], channel: BaseChannel) -> int: raise ValueError("Faulty get_next_version") @@ -11250,7 +11250,7 @@ def generate_answer(state): interview_builder.add_conditional_edges("answer_question", route_messages) # Set up memory - memory = MemorySaver() + memory = InMemorySaver() # Interview interview_graph = interview_builder.compile(checkpointer=memory).with_config( @@ -11478,7 +11478,7 @@ def child_node_b(state: ChildState): parent.add_edge("parent_node", "child_graph") parent.set_entry_point("parent_node") - checkpointer = MemorySaver() + checkpointer = InMemorySaver() app = parent.compile(checkpointer=checkpointer) with pytest.raises(RandomError): app.invoke({"count": 0}, {"configurable": {"thread_id": "foo"}}) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 0290bf069..08f15265c 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -49,7 +49,7 @@ CheckpointMetadata, CheckpointTuple, ) -from langgraph.checkpoint.memory import MemorySaver +from langgraph.checkpoint.memory import InMemorySaver from langgraph.constants import ERROR, PULL, PUSH from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt from langgraph.graph import END, Graph, StateGraph @@ -89,11 +89,11 @@ async def test_checkpoint_errors() -> None: - class FaultyGetCheckpointer(MemorySaver): + class FaultyGetCheckpointer(InMemorySaver): async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: raise ValueError("Faulty get_tuple") - class FaultyPutCheckpointer(MemorySaver): + class FaultyPutCheckpointer(InMemorySaver): async def aput( self, config: RunnableConfig, @@ -103,13 +103,13 @@ async def aput( ) -> RunnableConfig: raise ValueError("Faulty put") - class FaultyPutWritesCheckpointer(MemorySaver): + class FaultyPutWritesCheckpointer(InMemorySaver): async def aput_writes( self, config: RunnableConfig, writes: List[Tuple[str, Any]], task_id: str ) -> RunnableConfig: raise ValueError("Faulty put_writes") - class FaultyVersionCheckpointer(MemorySaver): + class FaultyVersionCheckpointer(InMemorySaver): def get_next_version(self, current: Optional[int], channel: BaseChannel) -> int: raise ValueError("Faulty get_next_version")