Skip to content

Commit

Permalink
Merge pull request #2378 from langchain-ai/nc/8nov/send-future
Browse files Browse the repository at this point in the history
Imperative API
  • Loading branch information
nfcampos authored Dec 7, 2024
2 parents 6784a5a + 2fa2469 commit 1af1911
Show file tree
Hide file tree
Showing 15 changed files with 1,103 additions and 245 deletions.
4 changes: 2 additions & 2 deletions libs/langgraph/langgraph/channels/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, Optional, Sequence, Type, TypeVar
from typing import Any, Generic, Optional, Sequence, TypeVar

from typing_extensions import Self

Expand All @@ -13,7 +13,7 @@
class BaseChannel(Generic[Value, Update, C], ABC):
__slots__ = ("key", "typ")

def __init__(self, typ: Type[Any], key: str = "") -> None:
def __init__(self, typ: Any, key: str = "") -> None:
self.typ = typ
self.key = key

Expand Down
4 changes: 4 additions & 0 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,16 @@
# marker to signal node was scheduled (in distributed mode)
TASKS = sys.intern("__pregel_tasks")
# for Send objects returned by nodes/edges, corresponds to PUSH below
RETURN = sys.intern("__return__")
# for writes of a task where we simply record the return value

# --- Reserved config.configurable keys ---
CONFIG_KEY_SEND = sys.intern("__pregel_send")
# holds the `write` function that accepts writes to state/edges/reserved keys
CONFIG_KEY_READ = sys.intern("__pregel_read")
# holds the `read` function that returns a copy of the current state
CONFIG_KEY_CALL = sys.intern("__pregel_call")
# holds the `call` function that accepts a node/func, args and returns a future
CONFIG_KEY_CHECKPOINTER = sys.intern("__pregel_checkpointer")
# holds a `BaseCheckpointSaver` passed from parent graph to child graphs
CONFIG_KEY_STREAM = sys.intern("__pregel_stream")
Expand Down
97 changes: 97 additions & 0 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import asyncio
import concurrent
import concurrent.futures
import types
from functools import partial, update_wrapper
from typing import (
Any,
Awaitable,
Callable,
Optional,
TypeVar,
Union,
overload,
)

from typing_extensions import ParamSpec

from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.last_value import LastValue
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import END, START, TAG_HIDDEN
from langgraph.pregel import Pregel
from langgraph.pregel.call import get_runnable_for_func
from langgraph.pregel.read import PregelNode
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import RetryPolicy

P = ParamSpec("P")
P1 = TypeVar("P1")
T = TypeVar("T")


def call(
func: Callable[[P1], T],
input: P1,
*,
retry: Optional[RetryPolicy] = None,
) -> concurrent.futures.Future[T]:
from langgraph.constants import CONFIG_KEY_CALL
from langgraph.utils.config import get_configurable

conf = get_configurable()
impl = conf[CONFIG_KEY_CALL]
fut = impl(func, input, retry=retry)
return fut


@overload
def task(
*, retry: Optional[RetryPolicy] = None
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, asyncio.Future[T]]]: ...


@overload
def task( # type: ignore[overload-cannot-match]
*, retry: Optional[RetryPolicy] = None
) -> Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]]: ...


def task(
*, retry: Optional[RetryPolicy] = None
) -> Union[
Callable[[Callable[P, Awaitable[T]]], Callable[P, asyncio.Future[T]]],
Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]],
]:
def _task(func: Callable[P, T]) -> Callable[P, concurrent.futures.Future[T]]:
return update_wrapper(partial(call, func, retry=retry), func)

return _task


def entrypoint(
*,
checkpointer: Optional[BaseCheckpointSaver] = None,
store: Optional[BaseStore] = None,
) -> Callable[[types.FunctionType], Pregel]:
def _imp(func: types.FunctionType) -> Pregel:
return Pregel(
nodes={
func.__name__: PregelNode(
bound=get_runnable_for_func(func),
triggers=[START],
channels=[START],
writers=[ChannelWrite([ChannelWriteEntry(END)], tags=[TAG_HIDDEN])],
)
},
channels={START: EphemeralValue(Any), END: LastValue(Any, END)},
input_channels=START,
output_channels=END,
stream_channels=END,
stream_mode="updates",
checkpointer=checkpointer,
store=store,
)

return _imp
Loading

0 comments on commit 1af1911

Please sign in to comment.