Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

functional api expose previous value #3073

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import sys

Check notice on line 1 in libs/langgraph/langgraph/constants.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 62.8 ms +- 1.9 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 56.0 ms +- 1.8 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 80.2 ms +- 2.9 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 109 ms +- 10 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 669 ms +- 36 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 535 ms +- 16 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 806 ms +- 25 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 994 ms +- 26 ms ......................................... react_agent_10x: Mean +- std dev: 30.5 ms +- 0.7 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.8 ms +- 0.2 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 38.3 ms +- 0.7 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.6 ms +- 0.7 ms ......................................... react_agent_100x: Mean +- std dev: 341 ms +- 7 ms ......................................... react_agent_100x_sync: Mean +- std dev: 271 ms +- 3 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 681 ms +- 28 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 657 ms +- 26 ms ......................................... wide_state_25x300: Mean +- std dev: 23.7 ms +- 0.6 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.8 ms +- 0.3 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 253 ms +- 17 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 253 ms +- 19 ms ......................................... wide_state_15x600: Mean +- std dev: 28.0 ms +- 0.8 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 18.1 ms +- 0.3 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 437 ms +- 19 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 433 ms +- 21 ms ......................................... wide_state_9x1200: Mean +- std dev: 28.0 ms +- 0.6 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 18.2 ms +- 0.3 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 286 ms +- 18 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 278 ms +- 14 ms

Check notice on line 1 in libs/langgraph/langgraph/constants.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+----------+-----------------------+ | Benchmark | main | changes | +=========================================+==========+=======================+ | fanout_to_subgraph_100x_checkpoint_sync | 1.01 sec | 994 ms: 1.02x faster | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_10x | 63.7 ms | 62.8 ms: 1.01x faster | +-----------------------------------------+----------+-----------------------+ | react_agent_100x_sync | 270 ms | 271 ms: 1.00x slower | +-----------------------------------------+----------+-----------------------+ | wide_state_15x600_sync | 17.9 ms | 18.1 ms: 1.01x slower | +-----------------------------------------+----------+-----------------------+ | wide_state_9x1200_sync | 17.9 ms | 18.2 ms: 1.01x slower | +-----------------------------------------+----------+-----------------------+ | wide_state_25x300 | 23.3 ms | 23.7 ms: 1.01x slower | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_100x_sync | 526 ms | 535 ms: 1.02x slower | +-----------------------------------------+----------+-----------------------+ | wide_state_25x300_checkpoint | 248 ms | 253 ms: 1.02x slower | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_10x_sync | 55.0 ms | 56.0 ms: 1.02x slower | +-----------------------------------------+----------+-----------------------+ | wide_state_15x600_checkpoint | 429 ms | 437 ms: 1.02x slower | +-----------------------------------------+----------+-----------------------+ | wide_state_15x600_checkpoint_sync | 424 ms | 433 ms: 1.02x slower | +-----------------------------------------+----------+-----------------------+ | wide_state_25x300_sync | 15.5 ms | 15.8 ms: 1.02x slower | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 78.1 ms | 80.2 ms: 1.03x slower | +-----------------------------------------+----------+-----------------------+ | wide_state_15x600 | 27.2 ms | 28.0 ms: 1.03x slower | +-----------------------------------------+----------+-----------------------+ | wide_state_9x1200 | 27.3 ms | 28.0 ms: 1.03x slower | +-----------------------------------------+----------+-----------------------+ | wide_state_25x300_checkpoint_sync | 245 ms | 253 ms: 1.04x slower | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_100x | 641 ms | 669 ms: 1.04x slower | +-----------------------------------------+----------+-----------------------+ | react_agent_100x_checkpoint_sync | 619 ms | 657 ms: 1.06x slower | +-----------------------------------------+----------+-----------------------+ | react_agent_100x_checkpoint | 639 ms | 681 ms: 1.06x slower | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 99.0 ms | 109 ms: 1.10x slower | +-----------------------------------------+----------+-----------------------+ | Geometric mean | (ref) | 1.02x slower | +-----------------------------------------+----------+-----------------------+ Benchmark hidden because not significant (8): react_agent_10x, react_agent_10x_checkpoint, react_agent_10x_sync, react_agent_10x_checkpoint_sync, fanout_to_subgraph_100x_checkpoint, react_agent_100x, wide_state_9x1200_checkpoint, wide_state_9x1200_checkpoint_sync
from types import MappingProxyType
from typing import Any, Literal, Mapping, cast

Expand Down Expand Up @@ -78,6 +78,8 @@
# holds a callback to be called when a node is finished
CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad")
# holds a mutable dict for temporary storage scoped to the current task
CONFIG_KEY_END = sys.intern("__pregel_previous")
# holds the previous return value from a stateful Pregel graph.

# --- Other constants ---
PUSH = sys.intern("__pregel_push")
Expand Down
96 changes: 85 additions & 11 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,22 +113,96 @@ def entrypoint(
config_schema: Optional[type[Any]] = None,
) -> Callable[[types.FunctionType], Pregel]:
def _imp(func: types.FunctionType) -> Pregel:
# wrap generators in a function that writes to StreamWriter
if inspect.isgeneratorfunction(func):
"""Convert a function into a Pregel graph.

def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any:
for chunk in func(*args, **kwargs):
writer(chunk)
Args:
func: The function to convert. Support both sync and async functions, as well
as generator and async generator functions.

Returns:
A Pregel graph.
"""
# wrap generators in a function that writes to StreamWriter
if inspect.isgeneratorfunction(func):
original_sig = inspect.signature(func)
# Check if original signature has a writer argument with a matching type.
# If not, we'll inject it into the decorator, but not pass it
# to the wrapped function.
if "writer" in original_sig.parameters:

@functools.wraps(func)
def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any:
chunks = []
for chunk in func(*args, writer=writer, **kwargs):
writer(chunk)
chunks.append(chunk)
return chunks
else:

@functools.wraps(func)
def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any:
chunks = []
# Do not pass the writer argument to the wrapped function
# as it does not have a matching parameter
for chunk in func(*args, **kwargs):
writer(chunk)
chunks.append(chunk)
return chunks

# Create a new parameter for the writer argument
extra_param = inspect.Parameter(
"writer",
inspect.Parameter.KEYWORD_ONLY,
# The extra argument is a keyword-only argument
default=lambda _: None,
)
# Update the function's signature to include the extra argument
new_params = list(original_sig.parameters.values()) + [extra_param]
new_sig = original_sig.replace(parameters=new_params)
# Update the signature of the wrapper function
gen_wrapper.__signature__ = new_sig # type: ignore
bound = get_runnable_for_func(gen_wrapper)
stream_mode: StreamMode = "custom"
elif inspect.isasyncgenfunction(func):

async def agen_wrapper(
*args: Any, writer: StreamWriter, **kwargs: Any
) -> Any:
async for chunk in func(*args, **kwargs):
writer(chunk)
original_sig = inspect.signature(func)
# Check if original signature has a writer argument with a matching type.
# If not, we'll inject it into the decorator, but not pass it
# to the wrapped function.
if "writer" in original_sig.parameters:

@functools.wraps(func)
async def agen_wrapper(
*args: Any, writer: StreamWriter, **kwargs: Any
) -> Any:
chunks = []
async for chunk in func(*args, writer=writer, **kwargs):
writer(chunk)
chunks.append(chunk)
return chunks
else:

@functools.wraps(func)
async def agen_wrapper(
*args: Any, writer: StreamWriter, **kwargs: Any
) -> Any:
chunks = []
async for chunk in func(*args, **kwargs):
writer(chunk)
chunks.append(chunk)
return chunks

# Create a new parameter for the writer argument
extra_param = inspect.Parameter(
"writer",
inspect.Parameter.KEYWORD_ONLY,
# The extra argument is a keyword-only argument
default=lambda _: None,
)
# Update the function's signature to include the extra argument
new_params = list(original_sig.parameters.values()) + [extra_param]
new_sig = original_sig.replace(parameters=new_params)
# Update the signature of the wrapper function
agen_wrapper.__signature__ = new_sig # type: ignore

bound = get_runnable_for_func(agen_wrapper)
stream_mode = "custom"
Expand Down
10 changes: 10 additions & 0 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_END,
CONFIG_KEY_READ,
CONFIG_KEY_SCRATCHPAD,
CONFIG_KEY_SEND,
Expand Down Expand Up @@ -507,6 +508,9 @@ def prepare_single_task(
pending_writes,
task_id,
),
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down Expand Up @@ -616,6 +620,9 @@ def prepare_single_task(
pending_writes,
task_id,
),
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down Expand Up @@ -737,6 +744,9 @@ def prepare_single_task(
pending_writes,
task_id,
),
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down
50 changes: 38 additions & 12 deletions libs/langgraph/langgraph/utils/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from typing_extensions import TypeGuard

from langgraph.constants import CONF, CONFIG_KEY_STORE, CONFIG_KEY_STREAM_WRITER
from langgraph.constants import (
CONF,
CONFIG_KEY_END,
CONFIG_KEY_STORE,
CONFIG_KEY_STREAM_WRITER,
)
from langgraph.store.base import BaseStore
from langgraph.types import StreamWriter
from langgraph.utils.config import (
Expand All @@ -58,6 +63,10 @@ class StrEnum(str, enum.Enum):
"""A string enum."""


# Special type to denote any type is accepted
ANY_TYPE = object()


ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)

KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
Expand All @@ -73,6 +82,12 @@ class StrEnum(str, enum.Enum):
CONFIG_KEY_STORE,
inspect.Parameter.empty,
),
(
sys.intern("previous"),
(ANY_TYPE,),
CONFIG_KEY_END,
inspect.Parameter.empty,
),
)
"""List of kwargs that can be passed to functions, and their corresponding
config keys, default values and type annotations.
Expand Down Expand Up @@ -135,9 +150,12 @@ def __init__(
self.func_accepts: dict[str, bool] = {}
for kw, typ, _, _ in KWARGS_CONFIG_KEYS:
p = params.get(kw)
self.func_accepts[kw] = (
p is not None and p.annotation in typ and p.kind in VALID_KINDS
)
if typ == (ANY_TYPE,):
self.func_accepts[kw] = p is not None and p.kind in VALID_KINDS
else:
self.func_accepts[kw] = (
p is not None and p.annotation in typ and p.kind in VALID_KINDS
)

def __repr__(self) -> str:
repr_args = {
Expand All @@ -162,16 +180,20 @@ def invoke(
if self.func_accepts_config:
kwargs["config"] = config
_conf = config[CONF]
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS:
if not self.func_accepts[kw]:
continue

if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf:
if (
default_value is inspect.Parameter.empty
and kw not in kwargs
and config_key not in _conf
):
raise ValueError(
f"Missing required config key '{ck}' for '{self.name}'."
f"Missing required config key '{config_key}' for '{self.name}'."
)
elif kwargs.get(kw) is None:
kwargs[kw] = _conf.get(ck, defv)
kwargs[kw] = _conf.get(config_key, default_value)

context = copy_context()
if self.trace:
Expand Down Expand Up @@ -210,16 +232,20 @@ async def ainvoke(
if self.func_accepts_config:
kwargs["config"] = config
_conf = config[CONF]
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS:
if not self.func_accepts[kw]:
continue

if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf:
if (
default_value is inspect.Parameter.empty
and kw not in kwargs
and config_key not in _conf
):
raise ValueError(
f"Missing required config key '{ck}' for '{self.name}'."
f"Missing required config key '{config_key}' for '{self.name}'."
)
elif kwargs.get(kw) is None:
kwargs[kw] = _conf.get(ck, defv)
kwargs[kw] = _conf.get(config_key, default_value)
context = copy_context()
if self.trace:
callback_manager = get_async_callback_manager_for_config(config, self.tags)
Expand Down
Loading
Loading