Merge pull request #1693 from langchain-ai/nc/11sep/kafka-finally-send
kafka: Expose finally_send as public api
nfcampos authored Sep 11, 2024
2 parents 33017b5 + 5716234 commit 73d1051
Showing 8 changed files with 333 additions and 277 deletions.
13 changes: 9 additions & 4 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py
@@ -35,6 +35,7 @@
     MessageToExecutor,
     MessageToOrchestrator,
     Producer,
+    Sendable,
     Topics,
 )
 from langgraph.utils.config import patch_configurable
@@ -129,7 +130,9 @@ async def each(self, msg: MessageToExecutor) -> None:
                         input=orjson.Fragment(
                             self.graph.checkpointer.serde.dumps(arg["input"])
                         ),
-                        finally_executor=[msg],
+                        finally_send=[
+                            Sendable(topic=self.topics.executor, value=msg)
+                        ],
                     )
                 ),
                 # use thread_id, checkpoint_ns as partition key
@@ -211,7 +214,7 @@ async def attempt(self, msg: MessageToExecutor) -> None:
                     MessageToOrchestrator(
                         input=None,
                         config=msg["config"],
-                        finally_executor=msg.get("finally_executor"),
+                        finally_send=msg.get("finally_send"),
                     )
                 ),
                 # use thread_id, checkpoint_ns as partition key
@@ -322,7 +325,9 @@ def each(self, msg: MessageToExecutor) -> None:
                         input=orjson.Fragment(
                             self.graph.checkpointer.serde.dumps(arg["input"])
                         ),
-                        finally_executor=[msg],
+                        finally_send=[
+                            Sendable(topic=self.topics.executor, value=msg)
+                        ],
                     )
                 ),
                 # use thread_id, checkpoint_ns as partition key
@@ -403,7 +408,7 @@ def attempt(self, msg: MessageToExecutor) -> None:
                     MessageToOrchestrator(
                         input=None,
                         config=msg["config"],
-                        finally_executor=msg.get("finally_executor"),
+                        finally_send=msg.get("finally_send"),
                     )
                 ),
                 # use thread_id, checkpoint_ns as partition key
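
Net effect in executor.py: when the executor hands work off to a subgraph it no longer stashes the raw MessageToExecutor under finally_executor; instead it wraps it in a Sendable addressed to its own topic, so the orchestrator can replay the original message once the subgraph run completes. A simplified sketch of that pattern (not the module's actual code; the topic names, the Topics field names, and the resume_message helper are illustrative assumptions):

    from langgraph.scheduler.kafka.types import (
        MessageToExecutor,
        MessageToOrchestrator,
        Sendable,
        Topics,
    )

    # Assumed topic wiring for illustration; field names inferred from the diff.
    topics = Topics(orchestrator="orchestrator", executor="executor", error="error")

    def resume_message(msg: MessageToExecutor, subgraph_input: dict) -> MessageToOrchestrator:
        # Ask the orchestrator to run the subgraph, and to re-send the original
        # executor message to the executor topic once that run reaches "done".
        return MessageToOrchestrator(
            input=subgraph_input,
            config=msg["config"],
            finally_send=[Sendable(topic=topics.executor, value=msg)],
        )
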
29 changes: 18 additions & 11 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py
@@ -187,7 +187,7 @@ async def attempt(self, msg: MessageToOrchestrator) -> None:
                         },
                     ),
                     task=ExecutorTask(id=task.id, path=task.path),
-                    finally_executor=msg.get("finally_executor"),
+                    finally_send=msg.get("finally_send"),
                 )
             ),
         )
@@ -212,12 +212,16 @@ async def attempt(self, msg: MessageToOrchestrator) -> None:
                     )
                 ],
             )
-        elif loop.status == "done" and msg.get("finally_executor"):
-            # schedule any finally_executor tasks
+        elif loop.status == "done" and msg.get("finally_send"):
+            # send any finally_send messages
             futs = await asyncio.gather(
                 *(
-                    self.producer.send(self.topics.executor, value=serde.dumps(m))
-                    for m in msg["finally_executor"]
+                    self.producer.send(
+                        m["topic"],
+                        value=serde.dumps(m["value"]) if m.get("value") else None,
+                        key=serde.dumps(m["key"]) if m.get("key") else None,
+                    )
+                    for m in msg["finally_send"]
                 )
             )
             # wait for messages to be sent
@@ -288,7 +292,6 @@ def __next__(self) -> list[MessageToOrchestrator]:
         recs = self.consumer.getmany(
             timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
         )
-        print("orch.__next__", recs)
         # dedupe messages, eg. if multiple nodes finish around same time
         uniq = set(msg.value for msgs in recs.values() for msg in msgs)
         msgs: list[MessageToOrchestrator] = [serde.loads(msg) for msg in uniq]
@@ -370,7 +373,7 @@ def attempt(self, msg: MessageToOrchestrator) -> None:
                         },
                     ),
                     task=ExecutorTask(id=task.id, path=task.path),
-                    finally_executor=msg.get("finally_executor"),
+                    finally_send=msg.get("finally_send"),
                 )
             ),
         )
@@ -394,11 +397,15 @@ def attempt(self, msg: MessageToOrchestrator) -> None:
                     )
                 ],
             )
-        elif loop.status == "done" and msg.get("finally_executor"):
-            # schedule any finally_executor tasks
+        elif loop.status == "done" and msg.get("finally_send"):
+            # schedule any finally_send msgs
             futs = [
-                self.producer.send(self.topics.executor, value=serde.dumps(m))
-                for m in msg["finally_executor"]
+                self.producer.send(
+                    m["topic"],
+                    value=serde.dumps(m["value"]) if m.get("value") else None,
+                    key=serde.dumps(m["key"]) if m.get("key") else None,
+                )
+                for m in msg["finally_send"]
             ]
             # wait for messages to be sent
             concurrent.futures.wait(futs)
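
Net effect in orchestrator.py: when a run reaches status "done", follow-up messages are no longer hard-wired to the executor topic; each Sendable names its own topic, and value/key are serialized only when set. A sketch of the async send path under the same logic as the diff (the producer is assumed to behave like aiokafka's, where send() returns a delivery future, and serde is passed in as a parameter because its import path is not shown here):

    import asyncio
    from typing import Sequence

    from langgraph.scheduler.kafka.types import Sendable

    async def send_finally(producer, serde, messages: Sequence[Sendable]) -> None:
        # Fan each message out to the topic it names...
        futs = await asyncio.gather(
            *(
                producer.send(
                    m["topic"],
                    value=serde.dumps(m["value"]) if m.get("value") else None,
                    key=serde.dumps(m["key"]) if m.get("key") else None,
                )
                for m in messages
            )
        )
        # ...then wait for the broker to acknowledge delivery.
        await asyncio.gather(*futs)
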
10 changes: 8 additions & 2 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/types.py
@@ -11,10 +11,16 @@ class Topics(NamedTuple):
     error: str


+class Sendable(TypedDict):
+    topic: str
+    value: Optional[Any]
+    key: Optional[Any]
+
+
 class MessageToOrchestrator(TypedDict):
     input: Optional[dict[str, Any]]
     config: RunnableConfig
-    finally_executor: Optional[Sequence["MessageToExecutor"]]
+    finally_send: Optional[Sequence[Sendable]]


 class ExecutorTask(TypedDict):
@@ -25,7 +31,7 @@ class ExecutorTask(TypedDict):
 class MessageToExecutor(TypedDict):
     config: RunnableConfig
     task: ExecutorTask
-    finally_executor: Optional[Sequence["MessageToExecutor"]]
+    finally_send: Optional[Sequence[Sendable]]


 class ErrorMessage(TypedDict):
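
With Sendable exported from types.py, callers can attach arbitrary follow-up sends to either message type rather than only re-queued executor messages. A small usage sketch (the topic name, key, and config values are placeholders, not taken from the repo):

    from langgraph.scheduler.kafka.types import MessageToOrchestrator, Sendable

    msg = MessageToOrchestrator(
        input={"messages": ["hi"]},
        config={"configurable": {"thread_id": "thread-1"}},
        finally_send=[
            # Sent by the orchestrator once this run reaches status "done".
            Sendable(topic="my-app.run-finished", value={"thread_id": "thread-1"}, key="thread-1"),
        ],
    )
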
6 changes: 4 additions & 2 deletions libs/scheduler-kafka/tests/drain.py
@@ -32,7 +32,9 @@ async def drain_topics_async(
     def done() -> bool:
         return (
             len(orch_msgs) > 0
+            and any(orch_msgs)
             and len(exec_msgs) > 0
+            and any(exec_msgs)
             and not orch_msgs[-1]
             and not exec_msgs[-1]
         )
@@ -97,7 +99,9 @@ def drain_topics(
     def done() -> bool:
         return (
             len(orch_msgs) > 0
+            and any(orch_msgs)
             and len(exec_msgs) > 0
+            and any(exec_msgs)
             and not orch_msgs[-1]
             and not exec_msgs[-1]
         )
@@ -110,7 +114,6 @@ def orchestrator() -> None:
             if debug:
                 print("\n---\norch", len(msgs), msgs)
             if done():
-                print("am i done? orchestrator")
                 event.set()
             if event.is_set():
                 break
@@ -126,7 +129,6 @@ def executor() -> None:
             if debug:
                 print("\n---\nexec", len(msgs), msgs)
             if done():
-                print("am i done? executor")
                 event.set()
             if event.is_set():
                 break
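
The drain helpers previously declared themselves done as soon as the latest batch on both topics was empty, which could fire before any messages had been observed at all; the added any() checks require at least one non-empty batch per topic first. The predicate restated outside its original closure (a sketch, assuming each poll appends its batch, possibly empty, to the corresponding list):

    def make_done(orch_msgs: list[list], exec_msgs: list[list]):
        def done() -> bool:
            return (
                len(orch_msgs) > 0
                and any(orch_msgs)     # at least one orchestrator batch carried messages
                and len(exec_msgs) > 0
                and any(exec_msgs)     # at least one executor batch carried messages
                and not orch_msgs[-1]  # and both topics have gone quiet again
                and not exec_msgs[-1]
            )
        return done
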
8 changes: 4 additions & 4 deletions libs/scheduler-kafka/tests/test_fanout.py
@@ -136,7 +136,7 @@ async def test_fanout_graph(topics: Topics, acheckpointer: BaseCheckpointSaver)
                 "tags": [],
             },
             "input": None,
-            "finally_executor": None,
+            "finally_send": None,
         }
         for c in reversed(history)
         for _ in c.tasks
@@ -161,7 +161,7 @@ async def test_fanout_graph(topics: Topics, acheckpointer: BaseCheckpointSaver)
                 "id": t.id,
                 "path": list(t.path),
             },
-            "finally_executor": None,
+            "finally_send": None,
         }
         for c in reversed(history)
         for t in c.tasks
@@ -218,7 +218,7 @@ async def test_fanout_graph_w_interrupt(
                 "tags": [],
             },
             "input": None,
-            "finally_executor": None,
+            "finally_send": None,
         }
         for c in reversed(history[1:])  # the last one wasn't executed
         # orchestrator messages appear only after tasks for that checkpoint
@@ -245,7 +245,7 @@ async def test_fanout_graph_w_interrupt(
                 "id": t.id,
                 "path": list(t.path),
             },
-            "finally_executor": None,
+            "finally_send": None,
         }
         for c in reversed(history[1:])  # the last one wasn't executed
         for t in c.tasks
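
The fan-out tests assert the full wire format of every orchestrator and executor message, so the rename shows up as an expected-value change: each expected dict now carries a finally_send key, set to None when no follow-up sends are requested. A stripped-down illustration of the executor-message shape being compared (config and task values are placeholders; the real tests derive them from checkpoint history):

    expected_executor_msg = {
        "config": {"configurable": {"thread_id": "1", "checkpoint_ns": ""}},
        "task": {"id": "<task-id>", "path": ["<task-path>"]},
        "finally_send": None,  # renamed from "finally_executor"
    }
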
8 changes: 4 additions & 4 deletions libs/scheduler-kafka/tests/test_fanout_sync.py
@@ -133,7 +133,7 @@ def test_fanout_graph(topics: Topics, checkpointer: BaseCheckpointSaver) -> None
                 "tags": [],
             },
             "input": None,
-            "finally_executor": None,
+            "finally_send": None,
         }
         for c in reversed(history)
         for _ in c.tasks
@@ -158,7 +158,7 @@ def test_fanout_graph(topics: Topics, checkpointer: BaseCheckpointSaver) -> None
                 "id": t.id,
                 "path": list(t.path),
             },
-            "finally_executor": None,
+            "finally_send": None,
         }
         for c in reversed(history)
         for t in c.tasks
@@ -216,7 +216,7 @@ def test_fanout_graph_w_interrupt(
                 "tags": [],
             },
             "input": None,
-            "finally_executor": None,
+            "finally_send": None,
         }
         for c in reversed(history[1:])  # the last one wasn't executed
         # orchestrator messages appear only after tasks for that checkpoint
@@ -243,7 +243,7 @@ def test_fanout_graph_w_interrupt(
                 "id": t.id,
                 "path": list(t.path),
             },
-            "finally_executor": None,
+            "finally_send": None,
         }
         for c in reversed(history[1:])  # the last one wasn't executed
         for t in c.tasks

(2 more changed files not rendered in this view)
