kafka: Make consumer and producer classes configurable #1686

Merged · 1 commit · Sep 11, 2024
13 changes: 6 additions & 7 deletions libs/scheduler-kafka/README.md
```diff
@@ -34,9 +34,9 @@ from your_lib import graph # graph expected to be a compiled LangGraph graph
 logger = logging.getLogger(__name__)
 
 topics = Topics(
-    orchestrator: os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
-    executor: os.environ['KAFKA_TOPIC_EXECUTOR'],
-    error: os.environ['KAFKA_TOPIC_ERROR'],
+    orchestrator=os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
+    executor=os.environ['KAFKA_TOPIC_EXECUTOR'],
+    error=os.environ['KAFKA_TOPIC_ERROR'],
 )
 
 async def main():
```
```diff
@@ -64,9 +64,9 @@ from your_lib import graph # graph expected to be a compiled LangGraph graph
 logger = logging.getLogger(__name__)
 
 topics = Topics(
-    orchestrator: os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
-    executor: os.environ['KAFKA_TOPIC_EXECUTOR'],
-    error: os.environ['KAFKA_TOPIC_ERROR'],
+    orchestrator=os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
+    executor=os.environ['KAFKA_TOPIC_EXECUTOR'],
+    error=os.environ['KAFKA_TOPIC_ERROR'],
 )
 
 async def main():
```
```diff
@@ -91,7 +91,6 @@ python executor.py &
 
 You can pass any of the following values as `kwargs` to either `KafkaOrchestrator` or `KafkaExecutor` to configure the consumer:
 
-- group_id (str): a name for the consumer group. Defaults to 'orchestrator' or 'executor', respectively.
 - batch_max_n (int): Maximum number of messages to include in a single batch. Default: 10.
 - batch_max_ms (int): Maximum time in milliseconds to wait for messages to include in a batch. Default: 1000.
 - retry_policy (langgraph.pregel.types.RetryPolicy): Controls which graph-level errors will be retried when processing messages. A good use for this is to retry database errors thrown by the checkpointer. Defaults to None.
```
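With this change, `KafkaOrchestrator` and `KafkaExecutor` also accept ready-made `consumer=` and `producer=` instances instead of building the aiokafka defaults internally. A minimal sketch of the new injection point, assuming `graph`, `topics`, and `logger` are defined as in the README snippets above and a broker is reachable at `localhost:9092` (an assumed address). Only internally created defaults are managed on the executor's exit stack, so injected clients are started and stopped by the caller:

```python
import asyncio

from langgraph.scheduler.kafka.default_async import (
    DefaultAsyncConsumer,
    DefaultAsyncProducer,
)
from langgraph.scheduler.kafka.executor import KafkaExecutor


async def main() -> None:
    # aiokafka clients are async context managers, so `async with` handles
    # start/stop for the clients we own ourselves.
    async with DefaultAsyncConsumer(
        topics.executor,
        bootstrap_servers="localhost:9092",  # assumed broker address
        auto_offset_reset="earliest",
        group_id="executor",
        enable_auto_commit=False,
    ) as consumer, DefaultAsyncProducer(
        bootstrap_servers="localhost:9092",  # assumed broker address
    ) as producer:
        async with KafkaExecutor(
            graph, topics, consumer=consumer, producer=producer
        ) as executor:
            async for msgs in executor:
                logger.info("processed %d messages", len(msgs))


asyncio.run(main())
```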
16 changes: 16 additions & 0 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/default_async.py
```diff
@@ -0,0 +1,16 @@
+import dataclasses
+from typing import Any, Sequence
+
+import aiokafka
+
+
+class DefaultAsyncConsumer(aiokafka.AIOKafkaConsumer):
+    async def getmany(
+        self, timeout_ms: int, max_records: int
+    ) -> dict[str, Sequence[dict[str, Any]]]:
+        batch = await super().getmany(timeout_ms=timeout_ms, max_records=max_records)
+        return {t: [dataclasses.asdict(m) for m in msgs] for t, msgs in batch.items()}
+
+
+class DefaultAsyncProducer(aiokafka.AIOKafkaProducer):
+    pass
```
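`DefaultAsyncConsumer` overrides `getmany` so that aiokafka's `ConsumerRecord` dataclasses come back as plain dicts, keeping aiokafka types out of the executor. The `AsyncConsumer`/`AsyncProducer` types it is checked against live in `langgraph/scheduler/kafka/types.py`, which this diff does not touch; the following is a hypothetical sketch of those protocols, inferred purely from how `executor.py` calls the clients:

```python
# Hypothetical sketch: types.py is not part of this diff, so member names and
# signatures below are inferences from executor.py's call sites, not the
# actual definitions.
from typing import Any, Optional, Protocol, Sequence


class AsyncConsumer(Protocol):
    async def getmany(
        self, timeout_ms: int, max_records: int
    ) -> dict[str, Sequence[dict[str, Any]]]:
        """Fetch up to max_records messages, as plain dicts grouped by topic."""
        ...

    async def commit(self) -> None:
        """Commit offsets for the records already returned."""
        ...


class AsyncProducer(Protocol):
    async def send(
        self,
        topic: str,
        *,
        value: Optional[bytes] = None,
        key: Optional[bytes] = None,
    ) -> Any:
        """Enqueue a message and return an awaitable that resolves on broker ack."""
        ...
```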
128 changes: 72 additions & 56 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py
```diff
@@ -3,7 +3,6 @@
 from functools import partial
 from typing import Any, Optional, Sequence
 
-import aiokafka
 import orjson
 from langchain_core.runnables import RunnableConfig
 from typing_extensions import Self
```
```diff
@@ -19,6 +18,8 @@
 from langgraph.pregel.types import RetryPolicy
 from langgraph.scheduler.kafka.retry import aretry
 from langgraph.scheduler.kafka.types import (
+    AsyncConsumer,
+    AsyncProducer,
     ErrorMessage,
     MessageToExecutor,
     MessageToOrchestrator,
```
```diff
@@ -28,51 +29,56 @@
 
 
 class KafkaExecutor(AbstractAsyncContextManager):
+    consumer: AsyncConsumer
+
+    producer: AsyncProducer
+
     def __init__(
         self,
         graph: Pregel,
         topics: Topics,
         *,
-        group_id: str = "executor",
         batch_max_n: int = 10,
         batch_max_ms: int = 1000,
         retry_policy: Optional[RetryPolicy] = None,
-        consumer_kwargs: Optional[dict[str, Any]] = None,
-        producer_kwargs: Optional[dict[str, Any]] = None,
+        consumer: Optional[AsyncConsumer] = None,
+        producer: Optional[AsyncProducer] = None,
         **kwargs: Any,
     ) -> None:
         self.graph = graph
         self.topics = topics
         self.stack = AsyncExitStack()
         self.kwargs = kwargs
-        self.consumer_kwargs = consumer_kwargs or {}
-        self.producer_kwargs = producer_kwargs or {}
-        self.group_id = group_id
+        self.consumer = consumer
+        self.producer = producer
         self.batch_max_n = batch_max_n
         self.batch_max_ms = batch_max_ms
         self.retry_policy = retry_policy
 
     async def __aenter__(self) -> Self:
-        self.consumer = await self.stack.enter_async_context(
-            aiokafka.AIOKafkaConsumer(
-                self.topics.executor,
-                value_deserializer=serde.loads,
-                auto_offset_reset="earliest",
-                group_id=self.group_id,
-                enable_auto_commit=False,
-                **self.kwargs,
-            )
-        )
-        self.producer = await self.stack.enter_async_context(
-            aiokafka.AIOKafkaProducer(
-                key_serializer=serde.dumps,
-                value_serializer=serde.dumps,
-                **self.kwargs,
-            )
-        )
         self.subgraphs = {
             k: v async for k, v in self.graph.aget_subgraphs(recurse=True)
         }
+        if self.consumer is None:
+            from langgraph.scheduler.kafka.default_async import DefaultAsyncConsumer
+
+            self.consumer = await self.stack.enter_async_context(
+                DefaultAsyncConsumer(
+                    self.topics.executor,
+                    auto_offset_reset="earliest",
+                    group_id="executor",
+                    enable_auto_commit=False,
+                    **self.kwargs,
+                )
+            )
+        if self.producer is None:
+            from langgraph.scheduler.kafka.default_async import DefaultAsyncProducer
+
+            self.producer = await self.stack.enter_async_context(
+                DefaultAsyncProducer(
+                    **self.kwargs,
+                )
+            )
         return self
 
     async def __aexit__(self, *args: Any) -> None:
```
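Note the two-step default: clients are created lazily in `__aenter__`, and only when nothing was injected, so `aiokafka` is imported solely on the default path and the constructor stays broker-free. That makes in-memory substitutes easy; for example, a test could drive the executor with stubs shaped like these (a sketch against the interface inferred above, not code from this PR):

```python
from typing import Any, Optional, Sequence


class StubConsumer:
    """In-memory AsyncConsumer stand-in that pops pre-canned batches."""

    def __init__(self, batches: list[dict[str, Sequence[dict[str, Any]]]]) -> None:
        self.batches = batches

    async def getmany(
        self, timeout_ms: int, max_records: int
    ) -> dict[str, Sequence[dict[str, Any]]]:
        return self.batches.pop(0) if self.batches else {}

    async def commit(self) -> None:
        pass  # nothing durable to commit in memory


class StubProducer:
    """Records every send instead of talking to a broker."""

    def __init__(self) -> None:
        self.sent: list[tuple[str, Optional[bytes], Optional[bytes]]] = []

    async def send(
        self, topic: str, *, value: Optional[bytes] = None, key: Optional[bytes] = None
    ) -> Any:
        self.sent.append((topic, key, value))

        async def ack() -> None:  # stand-in for the broker-acknowledgement future
            return None

        return ack()
```

A `KafkaExecutor(graph, topics, consumer=StubConsumer([...]), producer=StubProducer())` then runs entirely in process.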
Expand All @@ -83,15 +89,12 @@ def __aiter__(self) -> Self:

async def __anext__(self) -> Sequence[MessageToExecutor]:
# wait for next batch
try:
recs = await self.consumer.getmany(
timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
)
msgs: list[MessageToExecutor] = [
msg.value for msgs in recs.values() for msg in msgs
]
except aiokafka.ConsumerStoppedError:
raise StopAsyncIteration from None
recs = await self.consumer.getmany(
timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
)
msgs: list[MessageToExecutor] = [
serde.loads(msg["value"]) for msgs in recs.values() for msg in msgs
]
# process batch
await asyncio.gather(*(self.each(msg) for msg in msgs))
# commit offsets
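Deserialization also moves out of the Kafka client here: the old code configured `value_deserializer=serde.loads` on the consumer and had to catch `aiokafka.ConsumerStoppedError`, whereas the new code gets plain dicts from `getmany` and calls `serde.loads` itself, which is what lets the module drop its `aiokafka` import entirely. Roughly the shape a batch takes under `DefaultAsyncConsumer` (field names are an assumption based on `dataclasses.asdict` over aiokafka's `ConsumerRecord`; the executor only reads `"value"`):

```python
# Assumed shape of one getmany() batch under DefaultAsyncConsumer; the field
# set mirrors aiokafka's ConsumerRecord dataclass, abridged for illustration.
batch = {
    "executor-topic": [
        {
            "topic": "executor-topic",
            "partition": 0,
            "offset": 42,
            "key": b"...",
            "value": b"...",  # still serialized; the executor applies serde.loads
            "timestamp": 1726000000000,
        },
    ],
}
```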
```diff
@@ -106,30 +109,38 @@ async def each(self, msg: MessageToExecutor) -> None:
             pass
         except GraphDelegate as exc:
             for arg in exc.args:
-                await self.producer.send_and_wait(
+                fut = await self.producer.send(
                     self.topics.orchestrator,
-                    value=MessageToOrchestrator(
-                        config=arg["config"],
-                        input=orjson.Fragment(
-                            self.graph.checkpointer.serde.dumps(arg["input"])
-                        ),
-                        finally_executor=[msg],
+                    value=serde.dumps(
+                        MessageToOrchestrator(
+                            config=arg["config"],
+                            input=orjson.Fragment(
+                                self.graph.checkpointer.serde.dumps(arg["input"])
+                            ),
+                            finally_executor=[msg],
+                        )
                     ),
                     # use thread_id, checkpoint_ns as partition key
-                    key=(
-                        arg["config"]["configurable"]["thread_id"],
-                        arg["config"]["configurable"].get("checkpoint_ns"),
+                    key=serde.dumps(
+                        (
+                            arg["config"]["configurable"]["thread_id"],
+                            arg["config"]["configurable"].get("checkpoint_ns"),
+                        )
                     ),
                 )
+                await fut
         except Exception as exc:
-            await self.producer.send_and_wait(
+            fut = await self.producer.send(
                 self.topics.error,
-                value=ErrorMessage(
-                    topic=self.topics.executor,
-                    msg=msg,
-                    error=repr(exc),
+                value=serde.dumps(
+                    ErrorMessage(
+                        topic=self.topics.executor,
+                        msg=msg,
+                        error=repr(exc),
+                    )
                 ),
             )
+            await fut
 
     async def attempt(self, msg: MessageToExecutor) -> None:
         # find graph
```
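The switch from `send_and_wait` to `send` plus an explicit `await fut` shrinks what a producer must implement: in aiokafka, `send_and_wait` is a convenience wrapper that sends and then awaits the delivery future. Spelled out, the executor's pattern is equivalent to this sketch (names here are illustrative):

```python
from typing import Any


async def send_and_wait_equivalent(
    producer: Any, topic: str, value: bytes, key: bytes
) -> Any:
    # What aiokafka's send_and_wait() bundles, written out the way the
    # executor now does it:
    fut = await producer.send(topic, value=value, key=key)  # enqueue the message
    return await fut  # resolves once the broker acknowledges the write
```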
```diff
@@ -182,19 +193,24 @@ async def attempt(self, msg: MessageToExecutor) -> None:
                 msg["config"], [(ERROR, TaskNotFound())]
             )
         # notify orchestrator
-        await self.producer.send_and_wait(
+        fut = await self.producer.send(
             self.topics.orchestrator,
-            value=MessageToOrchestrator(
-                input=None,
-                config=msg["config"],
-                finally_executor=msg.get("finally_executor"),
+            value=serde.dumps(
+                MessageToOrchestrator(
+                    input=None,
+                    config=msg["config"],
+                    finally_executor=msg.get("finally_executor"),
+                )
             ),
             # use thread_id, checkpoint_ns as partition key
-            key=(
-                msg["config"]["configurable"]["thread_id"],
-                msg["config"]["configurable"].get("checkpoint_ns"),
+            key=serde.dumps(
+                (
+                    msg["config"]["configurable"]["thread_id"],
+                    msg["config"]["configurable"].get("checkpoint_ns"),
+                )
             ),
         )
+        await fut
 
     def _put_writes(
         self,
```