Skip to content

Commit

Permalink
Update name to InMemorySaver
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Oct 8, 2024
1 parent 7c2a89d commit 4a76da6
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 49 deletions.
69 changes: 39 additions & 30 deletions libs/checkpoint/langgraph/checkpoint/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@
from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol


class MemorySaver(
class InMemorySaver(
BaseCheckpointSaver[str], AbstractContextManager, AbstractAsyncContextManager
):
"""An in-memory checkpoint saver.
This checkpoint saver stores checkpoints in memory using a defaultdict.
Note:
Only use `MemorySaver` for debugging or testing purposes.
Only use `InMemorySaver` for debugging or testing purposes.
For production use cases we recommend installing [langgraph-checkpoint-postgres](https://pypi.org/project/langgraph-checkpoint-postgres/) and using `PostgresSaver` / `AsyncPostgresSaver`.
Args:
Expand All @@ -39,15 +39,15 @@ class MemorySaver(
import asyncio
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import StateGraph
builder = StateGraph(int)
builder.add_node("add_one", lambda x: x + 1)
builder.set_entry_point("add_one")
builder.set_finish_point("add_one")
memory = MemorySaver()
memory = InMemorySaver()
graph = builder.compile(checkpointer=memory)
coro = graph.ainvoke(1, {"configurable": {"thread_id": "thread-1"}})
asyncio.run(coro) # Output: 2
Expand All @@ -73,7 +73,7 @@ def __init__(
self.storage = defaultdict(lambda: defaultdict(dict))
self.writes = defaultdict(dict)

def __enter__(self) -> "MemorySaver":
def __enter__(self) -> "InMemorySaver":
return self

def __exit__(
Expand All @@ -84,7 +84,7 @@ def __exit__(
) -> Optional[bool]:
return

async def __aenter__(self) -> "MemorySaver":
async def __aenter__(self) -> "InMemorySaver":
return self

async def __aexit__(
Expand Down Expand Up @@ -135,15 +135,17 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
pending_writes=[
(id, c, self.serde.loads_typed(v)) for id, c, v in writes
],
parent_config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
parent_config=(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
}
if parent_checkpoint_id
else None,
if parent_checkpoint_id
else None
),
)
else:
if checkpoints := self.storage[thread_id][checkpoint_ns]:
Expand Down Expand Up @@ -176,15 +178,17 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
pending_writes=[
(id, c, self.serde.loads_typed(v)) for id, c, v in writes
],
parent_config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
parent_config=(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
}
if parent_checkpoint_id
else None,
if parent_checkpoint_id
else None
),
)

def list(
Expand Down Expand Up @@ -285,15 +289,17 @@ def list(
"pending_sends": [self.serde.loads_typed(s) for s in sends],
},
metadata=metadata,
parent_config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
parent_config=(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
}
if parent_checkpoint_id
else None,
if parent_checkpoint_id
else None
),
pending_writes=[
(id, c, self.serde.loads_typed(v)) for id, c, v in writes
],
Expand Down Expand Up @@ -474,3 +480,6 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) ->
next_v = current_v + 1
next_h = random.random()
return f"{next_v:032}.{next_h:016}"


MemorySaver = InMemorySaver # Kept for backwards compatibility
2 changes: 1 addition & 1 deletion libs/checkpoint/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph-checkpoint"
version = "2.0.1"
version = "2.0.2"
description = "Library with base interfaces for LangGraph checkpoint savers."
authors = []
license = "MIT"
Expand Down
4 changes: 2 additions & 2 deletions libs/checkpoint/tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
create_checkpoint,
empty_checkpoint,
)
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.memory import InMemorySaver


class TestMemorySaver:
@pytest.fixture(autouse=True)
def setup(self) -> None:
self.memory_saver = MemorySaver()
self.memory_saver = InMemorySaver()

# objects for test setup
self.config_1: RunnableConfig = {
Expand Down
8 changes: 4 additions & 4 deletions libs/langgraph/tests/memory_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
SerializerProtocol,
copy_checkpoint,
)
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.memory import InMemorySaver


class NoopSerializer(SerializerProtocol):
Expand All @@ -23,7 +23,7 @@ def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
return "type", obj


class MemorySaverAssertImmutable(MemorySaver):
class MemorySaverAssertImmutable(InMemorySaver):
storage_for_copies: defaultdict[str, dict[str, dict[str, Checkpoint]]]

def __init__(
Expand Down Expand Up @@ -64,7 +64,7 @@ def put(
return super().put(config, checkpoint, metadata, new_versions)


class MemorySaverAssertCheckpointMetadata(MemorySaver):
class MemorySaverAssertCheckpointMetadata(InMemorySaver):
"""This custom checkpointer is for verifying that a run's configurable
fields are merged with the previous checkpoint config for each step in
the run. This is the desired behavior. Because the checkpointer's (a)put()
Expand Down Expand Up @@ -119,7 +119,7 @@ async def aput(
)


class MemorySaverNoPending(MemorySaver):
class MemorySaverNoPending(InMemorySaver):
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
result = super().get_tuple(config)
if result:
Expand Down
14 changes: 7 additions & 7 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
CheckpointMetadata,
CheckpointTuple,
)
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import ERROR, PULL, PUSH
from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt
from langgraph.graph import END, Graph
Expand Down Expand Up @@ -268,11 +268,11 @@ def node_b(state: State) -> State:


def test_checkpoint_errors() -> None:
class FaultyGetCheckpointer(MemorySaver):
class FaultyGetCheckpointer(InMemorySaver):
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
raise ValueError("Faulty get_tuple")

class FaultyPutCheckpointer(MemorySaver):
class FaultyPutCheckpointer(InMemorySaver):
def put(
self,
config: RunnableConfig,
Expand All @@ -282,13 +282,13 @@ def put(
) -> RunnableConfig:
raise ValueError("Faulty put")

class FaultyPutWritesCheckpointer(MemorySaver):
class FaultyPutWritesCheckpointer(InMemorySaver):
def put_writes(
self, config: RunnableConfig, writes: List[Tuple[str, Any]], task_id: str
) -> RunnableConfig:
raise ValueError("Faulty put_writes")

class FaultyVersionCheckpointer(MemorySaver):
class FaultyVersionCheckpointer(InMemorySaver):
def get_next_version(self, current: Optional[int], channel: BaseChannel) -> int:
raise ValueError("Faulty get_next_version")

Expand Down Expand Up @@ -11250,7 +11250,7 @@ def generate_answer(state):
interview_builder.add_conditional_edges("answer_question", route_messages)

# Set up memory
memory = MemorySaver()
memory = InMemorySaver()

# Interview
interview_graph = interview_builder.compile(checkpointer=memory).with_config(
Expand Down Expand Up @@ -11478,7 +11478,7 @@ def child_node_b(state: ChildState):
parent.add_edge("parent_node", "child_graph")
parent.set_entry_point("parent_node")

checkpointer = MemorySaver()
checkpointer = InMemorySaver()
app = parent.compile(checkpointer=checkpointer)
with pytest.raises(RandomError):
app.invoke({"count": 0}, {"configurable": {"thread_id": "foo"}})
Expand Down
10 changes: 5 additions & 5 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
CheckpointMetadata,
CheckpointTuple,
)
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import ERROR, PULL, PUSH
from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt
from langgraph.graph import END, Graph, StateGraph
Expand Down Expand Up @@ -89,11 +89,11 @@


async def test_checkpoint_errors() -> None:
class FaultyGetCheckpointer(MemorySaver):
class FaultyGetCheckpointer(InMemorySaver):
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
raise ValueError("Faulty get_tuple")

class FaultyPutCheckpointer(MemorySaver):
class FaultyPutCheckpointer(InMemorySaver):
async def aput(
self,
config: RunnableConfig,
Expand All @@ -103,13 +103,13 @@ async def aput(
) -> RunnableConfig:
raise ValueError("Faulty put")

class FaultyPutWritesCheckpointer(MemorySaver):
class FaultyPutWritesCheckpointer(InMemorySaver):
async def aput_writes(
self, config: RunnableConfig, writes: List[Tuple[str, Any]], task_id: str
) -> RunnableConfig:
raise ValueError("Faulty put_writes")

class FaultyVersionCheckpointer(MemorySaver):
class FaultyVersionCheckpointer(InMemorySaver):
def get_next_version(self, current: Optional[int], channel: BaseChannel) -> int:
raise ValueError("Faulty get_next_version")

Expand Down

0 comments on commit 4a76da6

Please sign in to comment.