Merge pull request #1686 from langchain-ai/nc/11sep/kafka-configurable-consumer-producer

kafka: Make consumer and producer classes configurable
nfcampos authored Sep 11, 2024
2 parents cd5f5c1 + 217b07a commit 115c37a
Showing 5 changed files with 194 additions and 115 deletions.
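In short, `KafkaOrchestrator` and `KafkaExecutor` now accept pre-built `consumer` and `producer` objects, and only construct the `aiokafka`-based defaults when none are supplied. A hypothetical usage sketch inferred from the diffs below, not a documented example; when you pass your own clients, starting and stopping them appears to be your responsibility, since `__aenter__` only manages the instances it creates itself:

```python
# Hedged sketch of the new configuration surface, based on this diff.
# Assumes `graph` and `topics` are built as in the README snippets below.
from langgraph.scheduler.kafka.default_async import (
    DefaultAsyncConsumer,
    DefaultAsyncProducer,
)
from langgraph.scheduler.kafka.executor import KafkaExecutor


async def run(graph, topics) -> None:
    # Build the clients yourself; their start/stop lifecycle is then yours,
    # since the executor only enters contexts for clients it creates.
    consumer = DefaultAsyncConsumer(
        topics.executor,
        group_id="executor",
        auto_offset_reset="earliest",
        enable_auto_commit=False,
    )
    producer = DefaultAsyncProducer()
    await consumer.start()
    await producer.start()
    try:
        async with KafkaExecutor(
            graph, topics, consumer=consumer, producer=producer
        ) as executor:
            async for _ in executor:
                pass  # each iteration processes and commits one batch
    finally:
        await producer.stop()
        await consumer.stop()
```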
13 changes: 6 additions & 7 deletions libs/scheduler-kafka/README.md
@@ -34,9 +34,9 @@ from your_lib import graph  # graph expected to be a compiled LangGraph graph
 logger = logging.getLogger(__name__)
 
 topics = Topics(
-    orchestrator: os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
-    executor: os.environ['KAFKA_TOPIC_EXECUTOR'],
-    error: os.environ['KAFKA_TOPIC_ERROR'],
+    orchestrator=os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
+    executor=os.environ['KAFKA_TOPIC_EXECUTOR'],
+    error=os.environ['KAFKA_TOPIC_ERROR'],
 )
 
 async def main():
@@ -64,9 +64,9 @@ from your_lib import graph  # graph expected to be a compiled LangGraph graph
 logger = logging.getLogger(__name__)
 
 topics = Topics(
-    orchestrator: os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
-    executor: os.environ['KAFKA_TOPIC_EXECUTOR'],
-    error: os.environ['KAFKA_TOPIC_ERROR'],
+    orchestrator=os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
+    executor=os.environ['KAFKA_TOPIC_EXECUTOR'],
+    error=os.environ['KAFKA_TOPIC_ERROR'],
 )
 
 async def main():
@@ -91,7 +91,6 @@ python executor.py &

 You can pass any of the following values as `kwargs` to either `KafkaOrchestrator` or `KafkaExecutor` to configure the consumer:
 
-- group_id (str): a name for the consumer group. Defaults to 'orchestrator' or 'executor', respectively.
 - batch_max_n (int): Maximum number of messages to include in a single batch. Default: 10.
 - batch_max_ms (int): Maximum time in milliseconds to wait for messages to include in a batch. Default: 1000.
 - retry_policy (langgraph.pregel.types.RetryPolicy): Controls which graph-level errors will be retried when processing messages. A good use for this is to retry database errors thrown by the checkpointer. Defaults to None.
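Note that this commit removes the `group_id` option (the deleted first bullet): the consumer group is now fixed to 'orchestrator' or 'executor', and deeper customization goes through the new `consumer`/`producer` arguments instead. A minimal sketch of the remaining knobs, assuming `graph` and `topics` are defined as in the README snippets above and that `RetryPolicy`'s defaults are acceptable:

```python
# Sketch: tuning batching and retries via the kwargs listed above.
from langgraph.pregel.types import RetryPolicy
from langgraph.scheduler.kafka.executor import KafkaExecutor

executor = KafkaExecutor(
    graph,
    topics,
    batch_max_n=100,             # pull up to 100 messages per batch
    batch_max_ms=500,            # or whatever arrives within 500 ms
    retry_policy=RetryPolicy(),  # retry transient graph-level errors
)
```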
16 changes: 16 additions & 0 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/default_async.py
@@ -0,0 +1,16 @@
+import dataclasses
+from typing import Any, Sequence
+
+import aiokafka
+
+
+class DefaultAsyncConsumer(aiokafka.AIOKafkaConsumer):
+    async def getmany(
+        self, timeout_ms: int, max_records: int
+    ) -> dict[str, Sequence[dict[str, Any]]]:
+        batch = await super().getmany(timeout_ms=timeout_ms, max_records=max_records)
+        return {t: [dataclasses.asdict(m) for m in msgs] for t, msgs in batch.items()}
+
+
+class DefaultAsyncProducer(aiokafka.AIOKafkaProducer):
+    pass
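This new module supplies the fallback implementations. `DefaultAsyncConsumer` overrides `getmany` to flatten each `aiokafka` record into a plain dict, which is why `executor.py` below reads fields as `msg["value"]`. A replacement passed via the new `consumer` argument presumably just needs to preserve that shape; a hypothetical subclass for illustration, not part of the commit:

```python
# Hypothetical: a drop-in consumer that logs batch sizes, reusing the default.
from typing import Any, Sequence

from langgraph.scheduler.kafka.default_async import DefaultAsyncConsumer


class LoggingConsumer(DefaultAsyncConsumer):
    async def getmany(
        self, timeout_ms: int, max_records: int
    ) -> dict[str, Sequence[dict[str, Any]]]:
        batch = await super().getmany(timeout_ms=timeout_ms, max_records=max_records)
        total = sum(len(msgs) for msgs in batch.values())
        print(f"fetched {total} records")  # swap in real logging as needed
        return batch
```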
128 changes: 72 additions & 56 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py
@@ -3,7 +3,6 @@
 from functools import partial
 from typing import Any, Optional, Sequence
 
-import aiokafka
 import orjson
 from langchain_core.runnables import RunnableConfig
 from typing_extensions import Self
@@ -19,6 +18,8 @@
 from langgraph.pregel.types import RetryPolicy
 from langgraph.scheduler.kafka.retry import aretry
 from langgraph.scheduler.kafka.types import (
+    AsyncConsumer,
+    AsyncProducer,
     ErrorMessage,
     MessageToExecutor,
     MessageToOrchestrator,
@@ -28,51 +29,56 @@


 class KafkaExecutor(AbstractAsyncContextManager):
+    consumer: AsyncConsumer
+
+    producer: AsyncProducer
+
     def __init__(
         self,
         graph: Pregel,
         topics: Topics,
         *,
-        group_id: str = "executor",
         batch_max_n: int = 10,
         batch_max_ms: int = 1000,
         retry_policy: Optional[RetryPolicy] = None,
-        consumer_kwargs: Optional[dict[str, Any]] = None,
-        producer_kwargs: Optional[dict[str, Any]] = None,
+        consumer: Optional[AsyncConsumer] = None,
+        producer: Optional[AsyncProducer] = None,
        **kwargs: Any,
     ) -> None:
         self.graph = graph
         self.topics = topics
         self.stack = AsyncExitStack()
         self.kwargs = kwargs
-        self.consumer_kwargs = consumer_kwargs or {}
-        self.producer_kwargs = producer_kwargs or {}
-        self.group_id = group_id
+        self.consumer = consumer
+        self.producer = producer
         self.batch_max_n = batch_max_n
         self.batch_max_ms = batch_max_ms
         self.retry_policy = retry_policy
 
     async def __aenter__(self) -> Self:
-        self.consumer = await self.stack.enter_async_context(
-            aiokafka.AIOKafkaConsumer(
-                self.topics.executor,
-                value_deserializer=serde.loads,
-                auto_offset_reset="earliest",
-                group_id=self.group_id,
-                enable_auto_commit=False,
-                **self.kwargs,
-            )
-        )
-        self.producer = await self.stack.enter_async_context(
-            aiokafka.AIOKafkaProducer(
-                key_serializer=serde.dumps,
-                value_serializer=serde.dumps,
-                **self.kwargs,
-            )
-        )
         self.subgraphs = {
             k: v async for k, v in self.graph.aget_subgraphs(recurse=True)
         }
+        if self.consumer is None:
+            from langgraph.scheduler.kafka.default_async import DefaultAsyncConsumer
+
+            self.consumer = await self.stack.enter_async_context(
+                DefaultAsyncConsumer(
+                    self.topics.executor,
+                    auto_offset_reset="earliest",
+                    group_id="executor",
+                    enable_auto_commit=False,
+                    **self.kwargs,
+                )
+            )
+        if self.producer is None:
+            from langgraph.scheduler.kafka.default_async import DefaultAsyncProducer
+
+            self.producer = await self.stack.enter_async_context(
+                DefaultAsyncProducer(
+                    **self.kwargs,
+                )
+            )
         return self
 
     async def __aexit__(self, *args: Any) -> None:
@@ -83,15 +89,12 @@ def __aiter__(self) -> Self:

     async def __anext__(self) -> Sequence[MessageToExecutor]:
         # wait for next batch
-        try:
-            recs = await self.consumer.getmany(
-                timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
-            )
-            msgs: list[MessageToExecutor] = [
-                msg.value for msgs in recs.values() for msg in msgs
-            ]
-        except aiokafka.ConsumerStoppedError:
-            raise StopAsyncIteration from None
+        recs = await self.consumer.getmany(
+            timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
+        )
+        msgs: list[MessageToExecutor] = [
+            serde.loads(msg["value"]) for msgs in recs.values() for msg in msgs
+        ]
         # process batch
         await asyncio.gather(*(self.each(msg) for msg in msgs))
         # commit offsets
@@ -106,30 +109,38 @@ async def each(self, msg: MessageToExecutor) -> None:
             pass
         except GraphDelegate as exc:
             for arg in exc.args:
-                await self.producer.send_and_wait(
+                fut = await self.producer.send(
                     self.topics.orchestrator,
-                    value=MessageToOrchestrator(
-                        config=arg["config"],
-                        input=orjson.Fragment(
-                            self.graph.checkpointer.serde.dumps(arg["input"])
-                        ),
-                        finally_executor=[msg],
+                    value=serde.dumps(
+                        MessageToOrchestrator(
+                            config=arg["config"],
+                            input=orjson.Fragment(
+                                self.graph.checkpointer.serde.dumps(arg["input"])
+                            ),
+                            finally_executor=[msg],
+                        )
                     ),
                     # use thread_id, checkpoint_ns as partition key
-                    key=(
-                        arg["config"]["configurable"]["thread_id"],
-                        arg["config"]["configurable"].get("checkpoint_ns"),
+                    key=serde.dumps(
+                        (
+                            arg["config"]["configurable"]["thread_id"],
+                            arg["config"]["configurable"].get("checkpoint_ns"),
+                        )
                     ),
                 )
+                await fut
         except Exception as exc:
-            await self.producer.send_and_wait(
+            fut = await self.producer.send(
                 self.topics.error,
-                value=ErrorMessage(
-                    topic=self.topics.executor,
-                    msg=msg,
-                    error=repr(exc),
+                value=serde.dumps(
+                    ErrorMessage(
+                        topic=self.topics.executor,
+                        msg=msg,
+                        error=repr(exc),
+                    )
                 ),
             )
+            await fut
 
     async def attempt(self, msg: MessageToExecutor) -> None:
         # find graph
@@ -182,19 +193,24 @@ async def attempt(self, msg: MessageToExecutor) -> None:
msg["config"], [(ERROR, TaskNotFound())]
)
# notify orchestrator
await self.producer.send_and_wait(
fut = await self.producer.send(
self.topics.orchestrator,
value=MessageToOrchestrator(
input=None,
config=msg["config"],
finally_executor=msg.get("finally_executor"),
value=serde.dumps(
MessageToOrchestrator(
input=None,
config=msg["config"],
finally_executor=msg.get("finally_executor"),
)
),
# use thread_id, checkpoint_ns as partition key
key=(
msg["config"]["configurable"]["thread_id"],
msg["config"]["configurable"].get("checkpoint_ns"),
key=serde.dumps(
(
msg["config"]["configurable"]["thread_id"],
msg["config"]["configurable"].get("checkpoint_ns"),
)
),
)
await fut

def _put_writes(
self,
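One pattern worth calling out in the executor changes above: `send_and_wait` waited for broker acknowledgement inline, while the new code calls `send`, which returns a delivery future once the message is enqueued, and awaits that future afterwards. A generic `aiokafka` sketch of the same pattern, with a placeholder topic name and payload:

```python
# Generic aiokafka sketch of the send -> await-future pattern adopted above.
import asyncio

from aiokafka import AIOKafkaProducer


async def produce() -> None:
    producer = AIOKafkaProducer(bootstrap_servers="localhost:9092")
    await producer.start()
    try:
        # send() enqueues the message and returns a delivery future...
        fut = await producer.send("some-topic", value=b"payload")
        # ...so other work can run before awaiting the broker's ack.
        metadata = await fut
        print(metadata.topic, metadata.partition, metadata.offset)
    finally:
        await producer.stop()


if __name__ == "__main__":
    asyncio.run(produce())
```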