Skip to content

Commit

Permalink
Merge pull request #2661 from langchain-ai/nc/5dec/perf
Browse files Browse the repository at this point in the history
lib: Performance improvements
  • Loading branch information
nfcampos authored Dec 10, 2024
2 parents 02f1904 + 30f852e commit dc0398e
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 47 deletions.
35 changes: 6 additions & 29 deletions libs/checkpoint/langgraph/checkpoint/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import asyncio
import logging
import os
import pickle
import random
import shutil
from collections import defaultdict
from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack
from functools import partial
from types import TracebackType
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple, Type

Expand Down Expand Up @@ -395,9 +393,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
Returns:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.get_tuple, config
)
return self.get_tuple(config)

async def alist(
self,
Expand All @@ -418,24 +414,8 @@ async def alist(
Yields:
AsyncIterator[CheckpointTuple]: An asynchronous iterator of checkpoint tuples.
"""
loop = asyncio.get_running_loop()
iter = await loop.run_in_executor(
None,
partial(
self.list,
before=before,
limit=limit,
filter=filter,
),
config,
)
while True:
# handling StopIteration exception inside coroutine won't work
# as expected, so using next() with default value to break the loop
if item := await loop.run_in_executor(None, next, iter, None):
yield item
else:
break
for item in self.list(config, filter=filter, before=before, limit=limit):
yield item

async def aput(
self,
Expand All @@ -455,9 +435,7 @@ async def aput(
Returns:
RunnableConfig: The updated config containing the saved checkpoint's timestamp.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.put, config, checkpoint, metadata, new_versions
)
return self.put(config, checkpoint, metadata, new_versions)

async def aput_writes(
self,
Expand All @@ -474,10 +452,9 @@ async def aput_writes(
config (RunnableConfig): The config to associate with the writes.
writes (List[Tuple[str, Any]]): The writes to save, each as a (channel, value) pair.
task_id (str): Identifier for the task creating the writes.
return self.put_writes(config, writes, task_id)
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.put_writes, config, writes, task_id
)
return self.put_writes(config, writes, task_id)

def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
if current is None:
Expand Down
12 changes: 8 additions & 4 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,9 @@ def accept_push(
) -> Optional[PregelExecutableTask]:
"""Accept a PUSH from a task, potentially returning a new task to start."""
# don't start if we should interrupt *after* the original task
if should_interrupt(self.checkpoint, self.interrupt_after, [task]):
if self.interrupt_after and should_interrupt(
self.checkpoint, self.interrupt_after, [task]
):
self.to_interrupt.append(task)
return
if pushed := cast(
Expand All @@ -333,7 +335,9 @@ def accept_push(
),
):
# don't start if we should interrupt *before* the new task
if should_interrupt(self.checkpoint, self.interrupt_before, [pushed]):
if self.interrupt_before and should_interrupt(
self.checkpoint, self.interrupt_before, [pushed]
):
self.to_interrupt.append(pushed)
return
# produce debug output
Expand Down Expand Up @@ -409,7 +413,7 @@ def tick(
}
)
# after execution, check if we should interrupt
if should_interrupt(
if self.interrupt_after and should_interrupt(
self.checkpoint, self.interrupt_after, self.tasks.values()
):
self.status = "interrupt_after"
Expand Down Expand Up @@ -481,7 +485,7 @@ def tick(
return self.tick(input_keys=input_keys)

# before execution, check if we should interrupt
if should_interrupt(
if self.interrupt_before and should_interrupt(
self.checkpoint, self.interrupt_before, self.tasks.values()
):
self.status = "interrupt_before"
Expand Down
16 changes: 4 additions & 12 deletions libs/langgraph/langgraph/utils/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,12 +404,10 @@ def invoke(
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
)
context = copy_context()
context.run(_set_config_context, config)
if i == 0:
input = context.run(step.invoke, input, config, **kwargs)
input = step.invoke(input, config, **kwargs)
else:
input = context.run(step.invoke, input, config)
input = step.invoke(input, config)
# finish the root run
except BaseException as e:
run_manager.on_chain_error(e)
Expand Down Expand Up @@ -443,16 +441,10 @@ async def ainvoke(
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
)
context = copy_context()
context.run(_set_config_context, config)
if i == 0:
coro = step.ainvoke(input, config, **kwargs)
else:
coro = step.ainvoke(input, config)
if ASYNCIO_ACCEPTS_CONTEXT:
input = await asyncio.create_task(coro, context=context)
input = await step.ainvoke(input, config, **kwargs)
else:
input = await asyncio.create_task(coro)
input = await step.ainvoke(input, config)
# finish the root run
except BaseException as e:
await run_manager.on_chain_error(e)
Expand Down
4 changes: 2 additions & 2 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12901,7 +12901,7 @@ def edit(state: JokeState):
metadata={
"step": 1,
"source": "loop",
"writes": {"edit": None},
"writes": None,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
Expand Down Expand Up @@ -12946,7 +12946,7 @@ def edit(state: JokeState):
metadata={
"step": 1,
"source": "loop",
"writes": {"edit": None},
"writes": None,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
Expand Down

0 comments on commit dc0398e

Please sign in to comment.