diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index 2dda24754..26d583ed2 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -1,6 +1,7 @@ import asyncio import concurrent import concurrent.futures +import inspect import types from functools import partial, update_wrapper from typing import ( @@ -24,7 +25,7 @@ from langgraph.pregel.read import PregelNode from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry from langgraph.store.base import BaseStore -from langgraph.types import RetryPolicy +from langgraph.types import RetryPolicy, StreamMode, StreamWriter P = ParamSpec("P") P1 = TypeVar("P1") @@ -76,10 +77,32 @@ def entrypoint( store: Optional[BaseStore] = None, ) -> Callable[[types.FunctionType], Pregel]: def _imp(func: types.FunctionType) -> Pregel: + if inspect.isgeneratorfunction(func): + + def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any: + for chunk in func(*args, **kwargs): + writer(chunk) + + bound = get_runnable_for_func(gen_wrapper) + stream_mode: StreamMode = "custom" + elif inspect.isasyncgenfunction(func): + + async def agen_wrapper( + *args: Any, writer: StreamWriter, **kwargs: Any + ) -> Any: + async for chunk in func(*args, **kwargs): + writer(chunk) + + bound = get_runnable_for_func(agen_wrapper) + stream_mode = "custom" + else: + bound = get_runnable_for_func(func) + stream_mode = "updates" + return Pregel( nodes={ func.__name__: PregelNode( - bound=get_runnable_for_func(func), + bound=bound, triggers=[START], channels=[START], writers=[ChannelWrite([ChannelWriteEntry(END)], tags=[TAG_HIDDEN])], @@ -89,7 +112,7 @@ def _imp(func: types.FunctionType) -> Pregel: input_channels=START, output_channels=END, stream_channels=END, - stream_mode="updates", + stream_mode=stream_mode, checkpointer=checkpointer, store=store, )