From 217b07ae65cdfece3eda1c9e728aaf7015ae3cb7 Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Wed, 11 Sep 2024 14:27:08 -0700
Subject: [PATCH] kafka: Make consumer and producer classes configurable

- Define protocol for sync and async producer and consumer
- Accept consumer/producer as init args in Orchestrator/Executor
- If not passed in, create default consumer/producer as before
---
 libs/scheduler-kafka/README.md                 |  13 +-
 .../scheduler/kafka/default_async.py           |  16 +++
 .../langgraph/scheduler/kafka/executor.py      | 128 ++++++++++--------
 .../langgraph/scheduler/kafka/orchestrator.py  | 112 ++++++++-------
 .../langgraph/scheduler/kafka/types.py         |  40 +++++-
 5 files changed, 194 insertions(+), 115 deletions(-)
 create mode 100644 libs/scheduler-kafka/langgraph/scheduler/kafka/default_async.py

diff --git a/libs/scheduler-kafka/README.md b/libs/scheduler-kafka/README.md
index c279f5e38..2d934d59f 100644
--- a/libs/scheduler-kafka/README.md
+++ b/libs/scheduler-kafka/README.md
@@ -34,9 +34,9 @@ from your_lib import graph # graph expected to be a compiled LangGraph graph
 logger = logging.getLogger(__name__)
 
 topics = Topics(
-    orchestrator: os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
-    executor: os.environ['KAFKA_TOPIC_EXECUTOR'],
-    error: os.environ['KAFKA_TOPIC_ERROR'],
+    orchestrator=os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
+    executor=os.environ['KAFKA_TOPIC_EXECUTOR'],
+    error=os.environ['KAFKA_TOPIC_ERROR'],
 )
 
 async def main():
@@ -64,9 +64,9 @@ from your_lib import graph # graph expected to be a compiled LangGraph graph
 logger = logging.getLogger(__name__)
 
 topics = Topics(
-    orchestrator: os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
-    executor: os.environ['KAFKA_TOPIC_EXECUTOR'],
-    error: os.environ['KAFKA_TOPIC_ERROR'],
+    orchestrator=os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
+    executor=os.environ['KAFKA_TOPIC_EXECUTOR'],
+    error=os.environ['KAFKA_TOPIC_ERROR'],
 )
 
 async def main():
@@ -91,7 +91,6 @@ python executor.py &
 
 You can pass any of the following values as `kwargs` to either `KafkaOrchestrator` or `KafkaExecutor` to configure the consumer:
 
-- group_id (str): a name for the consumer group. Defaults to 'orchestrator' or 'executor', respectively.
 - batch_max_n (int): Maximum number of messages to include in a single batch. Default: 10.
 - batch_max_ms (int): Maximum time in milliseconds to wait for messages to include in a batch. Default: 1000.
 - retry_policy (langgraph.pregel.types.RetryPolicy): Controls which graph-level errors will be retried when processing messages. A good use for this is to retry database errors thrown by the checkpointer. Defaults to None.
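The kwargs documented above configure the default clients; the new `consumer` and `producer` init args added by this patch let you inject your own instead (replacing the removed `group_id`, `consumer_kwargs`, and `producer_kwargs` options). A minimal sketch of that injection point, using the aiokafka-backed defaults introduced below; `KAFKA_BOOTSTRAP_SERVERS` and `your_lib` are illustrative placeholders, not part of this patch:

```python
import asyncio
import logging
import os

from langgraph.scheduler.kafka.default_async import (
    DefaultAsyncConsumer,
    DefaultAsyncProducer,
)
from langgraph.scheduler.kafka.orchestrator import KafkaOrchestrator
from langgraph.scheduler.kafka.types import Topics

from your_lib import graph  # placeholder: a compiled LangGraph graph

logger = logging.getLogger(__name__)

topics = Topics(
    orchestrator=os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
    executor=os.environ['KAFKA_TOPIC_EXECUTOR'],
    error=os.environ['KAFKA_TOPIC_ERROR'],
)

async def main():
    # pre-configured clients; the connection settings here are illustrative
    consumer = DefaultAsyncConsumer(
        topics.orchestrator,
        group_id='orchestrator',
        auto_offset_reset='earliest',
        enable_auto_commit=False,
        bootstrap_servers=os.environ['KAFKA_BOOTSTRAP_SERVERS'],
    )
    producer = DefaultAsyncProducer(
        bootstrap_servers=os.environ['KAFKA_BOOTSTRAP_SERVERS'],
    )
    # clients you pass in are started and stopped by the caller; the
    # scheduler only manages the lifecycle of defaults it creates itself
    await consumer.start()
    await producer.start()
    try:
        async with KafkaOrchestrator(
            graph, topics, consumer=consumer, producer=producer
        ) as orch:
            async for msgs in orch:
                logger.info('Processed %d messages', len(msgs))
    finally:
        await producer.stop()
        await consumer.stop()

asyncio.run(main())
```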
diff --git a/libs/scheduler-kafka/langgraph/scheduler/kafka/default_async.py b/libs/scheduler-kafka/langgraph/scheduler/kafka/default_async.py
new file mode 100644
index 000000000..ce6703a03
--- /dev/null
+++ b/libs/scheduler-kafka/langgraph/scheduler/kafka/default_async.py
@@ -0,0 +1,16 @@
+import dataclasses
+from typing import Any, Sequence
+
+import aiokafka
+
+
+class DefaultAsyncConsumer(aiokafka.AIOKafkaConsumer):
+    async def getmany(
+        self, timeout_ms: int, max_records: int
+    ) -> dict[str, Sequence[dict[str, Any]]]:
+        batch = await super().getmany(timeout_ms=timeout_ms, max_records=max_records)
+        return {t: [dataclasses.asdict(m) for m in msgs] for t, msgs in batch.items()}
+
+
+class DefaultAsyncProducer(aiokafka.AIOKafkaProducer):
+    pass
diff --git a/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py b/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py
index 86812c258..f7bceb86a 100644
--- a/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py
+++ b/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py
@@ -3,7 +3,6 @@
 from functools import partial
 from typing import Any, Optional, Sequence
 
-import aiokafka
 import orjson
 from langchain_core.runnables import RunnableConfig
 from typing_extensions import Self
@@ -19,6 +18,8 @@
 from langgraph.pregel.types import RetryPolicy
 from langgraph.scheduler.kafka.retry import aretry
 from langgraph.scheduler.kafka.types import (
+    AsyncConsumer,
+    AsyncProducer,
     ErrorMessage,
     MessageToExecutor,
     MessageToOrchestrator,
@@ -28,51 +29,56 @@
 
 
 class KafkaExecutor(AbstractAsyncContextManager):
+    consumer: AsyncConsumer
+
+    producer: AsyncProducer
+
     def __init__(
         self,
         graph: Pregel,
         topics: Topics,
         *,
-        group_id: str = "executor",
         batch_max_n: int = 10,
         batch_max_ms: int = 1000,
         retry_policy: Optional[RetryPolicy] = None,
-        consumer_kwargs: Optional[dict[str, Any]] = None,
-        producer_kwargs: Optional[dict[str, Any]] = None,
+        consumer: Optional[AsyncConsumer] = None,
+        producer: Optional[AsyncProducer] = None,
         **kwargs: Any,
     ) -> None:
         self.graph = graph
         self.topics = topics
         self.stack = AsyncExitStack()
         self.kwargs = kwargs
-        self.consumer_kwargs = consumer_kwargs or {}
-        self.producer_kwargs = producer_kwargs or {}
-        self.group_id = group_id
+        self.consumer = consumer
+        self.producer = producer
         self.batch_max_n = batch_max_n
         self.batch_max_ms = batch_max_ms
         self.retry_policy = retry_policy
 
     async def __aenter__(self) -> Self:
-        self.consumer = await self.stack.enter_async_context(
-            aiokafka.AIOKafkaConsumer(
-                self.topics.executor,
-                value_deserializer=serde.loads,
-                auto_offset_reset="earliest",
-                group_id=self.group_id,
-                enable_auto_commit=False,
-                **self.kwargs,
-            )
-        )
-        self.producer = await self.stack.enter_async_context(
-            aiokafka.AIOKafkaProducer(
-                key_serializer=serde.dumps,
-                value_serializer=serde.dumps,
-                **self.kwargs,
-            )
-        )
         self.subgraphs = {
             k: v async for k, v in self.graph.aget_subgraphs(recurse=True)
         }
+        if self.consumer is None:
+            from langgraph.scheduler.kafka.default_async import DefaultAsyncConsumer
+
+            self.consumer = await self.stack.enter_async_context(
+                DefaultAsyncConsumer(
+                    self.topics.executor,
+                    auto_offset_reset="earliest",
+                    group_id="executor",
+                    enable_auto_commit=False,
+                    **self.kwargs,
+                )
+            )
+        if self.producer is None:
+            from langgraph.scheduler.kafka.default_async import DefaultAsyncProducer
+
+            self.producer = await self.stack.enter_async_context(
+                DefaultAsyncProducer(
+                    **self.kwargs,
+                )
+            )
         return self
 
     async def __aexit__(self, *args: Any) -> None:
@@ -83,15 +89,12 @@ def __aiter__(self) -> Self:
     async def __anext__(self) -> Sequence[MessageToExecutor]:
         # wait for next batch
-        try:
-            recs = await self.consumer.getmany(
-                timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
-            )
-            msgs: list[MessageToExecutor] = [
-                msg.value for msgs in recs.values() for msg in msgs
-            ]
-        except aiokafka.ConsumerStoppedError:
-            raise StopAsyncIteration from None
+        recs = await self.consumer.getmany(
+            timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
+        )
+        msgs: list[MessageToExecutor] = [
+            serde.loads(msg["value"]) for msgs in recs.values() for msg in msgs
+        ]
         # process batch
         await asyncio.gather(*(self.each(msg) for msg in msgs))
         # commit offsets
@@ -106,30 +109,38 @@ async def each(self, msg: MessageToExecutor) -> None:
             pass
         except GraphDelegate as exc:
             for arg in exc.args:
-                await self.producer.send_and_wait(
+                fut = await self.producer.send(
                     self.topics.orchestrator,
-                    value=MessageToOrchestrator(
-                        config=arg["config"],
-                        input=orjson.Fragment(
-                            self.graph.checkpointer.serde.dumps(arg["input"])
-                        ),
-                        finally_executor=[msg],
+                    value=serde.dumps(
+                        MessageToOrchestrator(
+                            config=arg["config"],
+                            input=orjson.Fragment(
+                                self.graph.checkpointer.serde.dumps(arg["input"])
+                            ),
+                            finally_executor=[msg],
+                        )
                     ),
                     # use thread_id, checkpoint_ns as partition key
-                    key=(
-                        arg["config"]["configurable"]["thread_id"],
-                        arg["config"]["configurable"].get("checkpoint_ns"),
+                    key=serde.dumps(
+                        (
+                            arg["config"]["configurable"]["thread_id"],
+                            arg["config"]["configurable"].get("checkpoint_ns"),
+                        )
                     ),
                 )
+                await fut
         except Exception as exc:
-            await self.producer.send_and_wait(
+            fut = await self.producer.send(
                 self.topics.error,
-                value=ErrorMessage(
-                    topic=self.topics.executor,
-                    msg=msg,
-                    error=repr(exc),
+                value=serde.dumps(
+                    ErrorMessage(
+                        topic=self.topics.executor,
+                        msg=msg,
+                        error=repr(exc),
+                    )
                 ),
             )
+            await fut
 
     async def attempt(self, msg: MessageToExecutor) -> None:
         # find graph
@@ -182,19 +193,24 @@ async def attempt(self, msg: MessageToExecutor) -> None:
                 msg["config"], [(ERROR, TaskNotFound())]
             )
         # notify orchestrator
-        await self.producer.send_and_wait(
+        fut = await self.producer.send(
             self.topics.orchestrator,
-            value=MessageToOrchestrator(
-                input=None,
-                config=msg["config"],
-                finally_executor=msg.get("finally_executor"),
+            value=serde.dumps(
+                MessageToOrchestrator(
+                    input=None,
+                    config=msg["config"],
+                    finally_executor=msg.get("finally_executor"),
+                )
             ),
             # use thread_id, checkpoint_ns as partition key
-            key=(
-                msg["config"]["configurable"]["thread_id"],
-                msg["config"]["configurable"].get("checkpoint_ns"),
+            key=serde.dumps(
+                (
+                    msg["config"]["configurable"]["thread_id"],
+                    msg["config"]["configurable"].get("checkpoint_ns"),
+                )
             ),
         )
+        await fut
 
     def _put_writes(
         self,
diff --git a/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py b/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py
index c13de9c08..e594d88ba 100644
--- a/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py
+++ b/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py
@@ -2,7 +2,6 @@
 from contextlib import AbstractAsyncContextManager, AsyncExitStack
 from typing import Any, Optional
 
-import aiokafka
 from langchain_core.runnables import ensure_config
 from typing_extensions import Self
 
@@ -21,6 +20,8 @@
 from langgraph.pregel.types import RetryPolicy
 from langgraph.scheduler.kafka.retry import aretry
 from langgraph.scheduler.kafka.types import (
+    AsyncConsumer,
+    AsyncProducer,
     ErrorMessage,
     ExecutorTask,
     MessageToExecutor,
@@ -31,50 +32,55 @@
 
 
 class KafkaOrchestrator(AbstractAsyncContextManager):
+    consumer: AsyncConsumer
+
+    producer: AsyncProducer
+
     def __init__(
         self,
         graph: Pregel,
         topics: Topics,
-        group_id: str = "orchestrator",
        batch_max_n: int = 10,
         batch_max_ms: int = 1000,
         retry_policy: Optional[RetryPolicy] = None,
-        consumer_kwargs: Optional[dict[str, Any]] = None,
-        producer_kwargs: Optional[dict[str, Any]] = None,
+        consumer: Optional[AsyncConsumer] = None,
+        producer: Optional[AsyncProducer] = None,
         **kwargs: Any,
     ) -> None:
         self.graph = graph
         self.topics = topics
         self.stack = AsyncExitStack()
         self.kwargs = kwargs
-        self.consumer_kwargs = consumer_kwargs or {}
-        self.producer_kwargs = producer_kwargs or {}
-        self.group_id = group_id
+        self.consumer = consumer
+        self.producer = producer
         self.batch_max_n = batch_max_n
         self.batch_max_ms = batch_max_ms
         self.retry_policy = retry_policy
 
     async def __aenter__(self) -> Self:
-        self.consumer = await self.stack.enter_async_context(
-            aiokafka.AIOKafkaConsumer(
-                self.topics.orchestrator,
-                auto_offset_reset="earliest",
-                group_id=self.group_id,
-                enable_auto_commit=False,
-                **self.kwargs,
-                **self.consumer_kwargs,
-            )
-        )
-        self.producer = await self.stack.enter_async_context(
-            aiokafka.AIOKafkaProducer(
-                value_serializer=serde.dumps,
-                **self.kwargs,
-                **self.producer_kwargs,
-            )
-        )
         self.subgraphs = {
             k: v async for k, v in self.graph.aget_subgraphs(recurse=True)
         }
+        if self.consumer is None:
+            from langgraph.scheduler.kafka.default_async import DefaultAsyncConsumer
+
+            self.consumer = await self.stack.enter_async_context(
+                DefaultAsyncConsumer(
+                    self.topics.orchestrator,
+                    auto_offset_reset="earliest",
+                    group_id="orchestrator",
+                    enable_auto_commit=False,
+                    **self.kwargs,
+                )
+            )
+        if self.producer is None:
+            from langgraph.scheduler.kafka.default_async import DefaultAsyncProducer
+
+            self.producer = await self.stack.enter_async_context(
+                DefaultAsyncProducer(
+                    **self.kwargs,
+                )
+            )
         return self
 
     async def __aexit__(self, *args: Any) -> None:
@@ -85,15 +91,12 @@ def __aiter__(self) -> Self:
     async def __anext__(self) -> list[MessageToOrchestrator]:
         # wait for next batch
-        try:
-            recs = await self.consumer.getmany(
-                timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
-            )
-            # dedupe messages, eg. if multiple nodes finish around same time
-            uniq = set(msg.value for msgs in recs.values() for msg in msgs)
-            msgs: list[MessageToOrchestrator] = [serde.loads(msg) for msg in uniq]
-        except aiokafka.ConsumerStoppedError:
-            raise StopAsyncIteration from None
+        recs = await self.consumer.getmany(
+            timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
+        )
+        # dedupe messages, e.g. if multiple nodes finish around the same time
+        uniq = set(msg["value"] for msgs in recs.values() for msg in msgs)
+        msgs: list[MessageToOrchestrator] = [serde.loads(msg) for msg in uniq]
         # process batch
         await asyncio.gather(*(self.each(msg) for msg in msgs))
         # commit offsets
@@ -109,14 +112,17 @@ async def each(self, msg: MessageToOrchestrator) -> None:
         except GraphInterrupt:
             pass
         except Exception as exc:
-            await self.producer.send_and_wait(
+            fut = await self.producer.send(
                 self.topics.error,
-                value=ErrorMessage(
-                    topic=self.topics.orchestrator,
-                    msg=msg,
-                    error=repr(exc),
+                value=serde.dumps(
+                    ErrorMessage(
+                        topic=self.topics.orchestrator,
+                        msg=msg,
+                        error=repr(exc),
+                    )
                 ),
             )
+            await fut
 
     async def attempt(self, msg: MessageToOrchestrator) -> None:
         # find graph
@@ -155,21 +161,25 @@ async def attempt(self, msg: MessageToOrchestrator) -> None:
         # schedule any new tasks
         if new_tasks := [t for t in loop.tasks.values() if not t.scheduled]:
             # send messages to executor
-            futures: list[asyncio.Future] = await asyncio.gather(
+            futures = await asyncio.gather(
                 *(
                     self.producer.send(
                         self.topics.executor,
-                        value=MessageToExecutor(
-                            config=patch_configurable(
-                                loop.config,
-                                {
-                                    **loop.checkpoint_config["configurable"],
-                                    CONFIG_KEY_DEDUPE_TASKS: True,
-                                    CONFIG_KEY_ENSURE_LATEST: True,
-                                },
-                            ),
-                            task=ExecutorTask(id=task.id, path=task.path),
-                            finally_executor=msg.get("finally_executor"),
+                        value=serde.dumps(
+                            MessageToExecutor(
+                                config=patch_configurable(
+                                    loop.config,
+                                    {
+                                        **loop.checkpoint_config[
+                                            "configurable"
+                                        ],
+                                        CONFIG_KEY_DEDUPE_TASKS: True,
+                                        CONFIG_KEY_ENSURE_LATEST: True,
+                                    },
+                                ),
+                                task=ExecutorTask(id=task.id, path=task.path),
+                                finally_executor=msg.get("finally_executor"),
+                            )
                         ),
                     )
                     for task in new_tasks
@@ -197,7 +207,7 @@ async def attempt(self, msg: MessageToOrchestrator) -> None:
         # schedule any finally_executor tasks
         futs = await asyncio.gather(
             *(
-                self.producer.send(self.topics.executor, value=m)
+                self.producer.send(self.topics.executor, value=serde.dumps(m))
                 for m in msg["finally_executor"]
             )
         )
diff --git a/libs/scheduler-kafka/langgraph/scheduler/kafka/types.py b/libs/scheduler-kafka/langgraph/scheduler/kafka/types.py
index 3bc298b93..65385a3a0 100644
--- a/libs/scheduler-kafka/langgraph/scheduler/kafka/types.py
+++ b/libs/scheduler-kafka/langgraph/scheduler/kafka/types.py
@@ -1,4 +1,6 @@
-from typing import Any, NamedTuple, Optional, Sequence, TypedDict, Union
+import asyncio
+import concurrent.futures
+from typing import Any, NamedTuple, Optional, Protocol, Sequence, TypedDict, Union
 
 from langchain_core.runnables import RunnableConfig
 
@@ -30,3 +32,39 @@ class ErrorMessage(TypedDict):
     topic: str
     error: str
     msg: Union[MessageToExecutor, MessageToOrchestrator]
+
+
+class Consumer(Protocol):
+    def getmany(
+        self, timeout_ms: int, max_records: int
+    ) -> dict[str, Sequence[dict[str, Any]]]: ...
+
+    def commit(self) -> None: ...
+
+
+class AsyncConsumer(Protocol):
+    async def getmany(
+        self, timeout_ms: int, max_records: int
+    ) -> dict[str, Sequence[dict[str, Any]]]: ...
+
+    async def commit(self) -> None: ...
+
+
+class Producer(Protocol):
+    def send(
+        self,
+        topic: str,
+        *,
+        key: Optional[bytes] = None,
+        value: Optional[bytes] = None,
+    ) -> concurrent.futures.Future: ...
+
+
+class AsyncProducer(Protocol):
+    async def send(
+        self,
+        topic: str,
+        *,
+        key: Optional[bytes] = None,
+        value: Optional[bytes] = None,
+    ) -> asyncio.Future: ...
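Because `Consumer`/`Producer` and their async variants are structural protocols, any object with matching method signatures satisfies them, without importing aiokafka at all. For example, a hypothetical in-memory pair that could serve as a test double (every name below is illustrative, not part of this patch):

```python
import asyncio
from collections import defaultdict
from typing import Any, Optional, Sequence


class InMemoryAsyncProducer:
    """Illustrative AsyncProducer: appends records to shared in-memory queues."""

    def __init__(self, queues: dict[str, list[dict[str, Any]]]) -> None:
        self.queues = queues

    async def send(
        self,
        topic: str,
        *,
        key: Optional[bytes] = None,
        value: Optional[bytes] = None,
    ) -> asyncio.Future:
        self.queues[topic].append({"key": key, "value": value})
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        fut.set_result(None)  # in-memory "delivery" completes immediately
        return fut


class InMemoryAsyncConsumer:
    """Illustrative AsyncConsumer: drains a single topic from the shared queues."""

    def __init__(self, topic: str, queues: dict[str, list[dict[str, Any]]]) -> None:
        self.topic = topic
        self.queues = queues

    async def getmany(
        self, timeout_ms: int, max_records: int
    ) -> dict[str, Sequence[dict[str, Any]]]:
        # non-blocking drain; a real implementation would wait up to timeout_ms
        batch = self.queues[self.topic][:max_records]
        del self.queues[self.topic][:max_records]
        return {self.topic: batch} if batch else {}

    async def commit(self) -> None:
        pass  # offsets are implicit in an in-memory queue


queues: dict[str, list[dict[str, Any]]] = defaultdict(list)
consumer = InMemoryAsyncConsumer("executor", queues)
producer = InMemoryAsyncProducer(queues)
# both now satisfy the protocols and can be passed to
# KafkaExecutor(graph, topics, consumer=consumer, producer=producer)
```

Each record is a plain dict with `key`/`value` entries, which is the shape the orchestrator and executor read back out of `getmany` batches (e.g. `msg["value"]`).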