From d973a9b5dbdf4452d2924bb71ce7ce79f6f0ff9f Mon Sep 17 00:00:00 2001 From: Mingwei Samuel Date: Thu, 5 Dec 2024 16:19:39 -0800 Subject: [PATCH] feat(hydroflow): add `repeat_n()` windowing operator, modify scheduler --- hydroflow/src/scheduled/context.rs | 17 ++-- hydroflow/src/scheduled/graph.rs | 8 ++ ...surface_loop__flo_nested@graphvis_dot.snap | 19 +++-- ...ace_loop__flo_nested@graphvis_mermaid.snap | 19 +++-- ...rface_loop__flo_repeat_n@graphvis_dot.snap | 75 +++++++++++++++++ ...e_loop__flo_repeat_n@graphvis_mermaid.snap | 62 ++++++++++++++ ...surface_loop__flo_syntax@graphvis_dot.snap | 16 ++-- ...ace_loop__flo_syntax@graphvis_mermaid.snap | 16 ++-- hydroflow/tests/surface_loop.rs | 74 ++++++++++++++++- hydroflow_lang/src/graph/ops/batch.rs | 4 +- hydroflow_lang/src/graph/ops/repeat_n.rs | 81 +++++++------------ 11 files changed, 302 insertions(+), 89 deletions(-) create mode 100644 hydroflow/tests/snapshots/surface_loop__flo_repeat_n@graphvis_dot.snap create mode 100644 hydroflow/tests/snapshots/surface_loop__flo_repeat_n@graphvis_mermaid.snap diff --git a/hydroflow/src/scheduled/context.rs b/hydroflow/src/scheduled/context.rs index 0d83a715b39d..bc469a50ff95 100644 --- a/hydroflow/src/scheduled/context.rs +++ b/hydroflow/src/scheduled/context.rs @@ -3,12 +3,14 @@ //! Provides APIs for state and scheduling. use std::any::Any; +use std::cell::RefCell; use std::collections::VecDeque; use std::future::Future; use std::marker::PhantomData; use std::ops::DerefMut; use std::pin::Pin; +use smallvec::SmallVec; use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; use tokio::task::JoinHandle; use web_time::SystemTime; @@ -36,10 +38,13 @@ pub struct Context { /// If the events have been received for this tick. pub(super) events_received_tick: bool, - // TODO(mingwei): as long as this is here, it's impossible to know when all work is done. - // Second field (bool) is for if the event is an external "important" event (true). + // TODO(mingwei): as long as this is unclosed, it's impossible to know when all work is done. + /// Second field (bool) is for if the event is an external "important" event (true). pub(super) event_queue_send: UnboundedSender<(SubgraphId, bool)>, + /// Subgraphs rescheduled in the current stratum. + pub(super) rescheduled_subgraphs: RefCell>, + pub(super) current_tick: TickInstant, pub(super) current_stratum: usize, @@ -51,7 +56,6 @@ pub struct Context { pub(super) subgraph_id: SubgraphId, tasks_to_spawn: Vec + 'static>>>, - /// Join handles for spawned tasks. task_join_handles: Vec>, } @@ -95,8 +99,10 @@ impl Context { } /// Schedules the current subgraph to run again _this tick_. - pub fn reschedule_current_subgraph(&mut self) { - self.stratum_queues[self.current_stratum].push_back(self.subgraph_id); + pub fn reschedule_current_subgraph(&self) { + self.rescheduled_subgraphs + .borrow_mut() + .push(self.subgraph_id); } /// Returns a `Waker` for interacting with async Rust. @@ -240,6 +246,7 @@ impl Default for Context { events_received_tick: false, event_queue_send, + rescheduled_subgraphs: Default::default(), current_stratum: 0, current_tick: TickInstant::default(), diff --git a/hydroflow/src/scheduled/graph.rs b/hydroflow/src/scheduled/graph.rs index 8c36f3bf7fb2..4913fe28c42c 100644 --- a/hydroflow/src/scheduled/graph.rs +++ b/hydroflow/src/scheduled/graph.rs @@ -285,6 +285,14 @@ impl<'a> Hydroflow<'a> { } } } + + for sg_id in self.context.rescheduled_subgraphs.borrow_mut().drain(..) { + let sg_data = &self.subgraphs[sg_id.0]; + assert_eq!(sg_data.stratum, self.context.current_stratum); + if !sg_data.is_scheduled.replace(true) { + self.context.stratum_queues[sg_data.stratum].push_back(sg_id); + } + } } work_done } diff --git a/hydroflow/tests/snapshots/surface_loop__flo_nested@graphvis_dot.snap b/hydroflow/tests/snapshots/surface_loop__flo_nested@graphvis_dot.snap index e566a3c1a17b..b12d8afaa9f7 100644 --- a/hydroflow/tests/snapshots/surface_loop__flo_nested@graphvis_dot.snap +++ b/hydroflow/tests/snapshots/surface_loop__flo_nested@graphvis_dot.snap @@ -13,21 +13,23 @@ digraph { n6v1 [label="(n6v1) flatten()", shape=invhouse, fillcolor="#88aaff"] n7v1 [label="(n7v1) cross_join::<'static, 'tick>()", shape=invhouse, fillcolor="#88aaff"] n8v1 [label="(n8v1) all_once()", shape=invhouse, fillcolor="#88aaff"] - n9v1 [label="(n9v1) for_each(|all| println!(\"{}: {:?}\", context.current_tick(), all))", shape=house, fillcolor="#ffff88"] - n10v1 [label="(n10v1) handoff", shape=parallelogram, fillcolor="#ddddff"] + n9v1 [label="(n9v1) map(|vec| (context.current_tick().0, vec))", shape=invhouse, fillcolor="#88aaff"] + n10v1 [label="(n10v1) assert_eq([\l (\l 0,\l vec![\l (\"alice\", 0),\l (\"alice\", 1),\l (\"alice\", 2),\l (\"bob\", 0),\l (\"bob\", 1),\l (\"bob\", 2),\l ],\l ),\l (\l 1,\l vec![\l (\"alice\", 3),\l (\"alice\", 4),\l (\"alice\", 5),\l (\"bob\", 3),\l (\"bob\", 4),\l (\"bob\", 5),\l ],\l ),\l (\l 2,\l vec![\l (\"alice\", 6),\l (\"alice\", 7),\l (\"alice\", 8),\l (\"bob\", 6),\l (\"bob\", 7),\l (\"bob\", 8),\l ],\l ),\l (\l 3,\l vec![\l (\"alice\", 9),\l (\"alice\", 10),\l (\"alice\", 11),\l (\"bob\", 9),\l (\"bob\", 10),\l (\"bob\", 11),\l ],\l ),\l])\l", shape=house, fillcolor="#ffff88"] n11v1 [label="(n11v1) handoff", shape=parallelogram, fillcolor="#ddddff"] n12v1 [label="(n12v1) handoff", shape=parallelogram, fillcolor="#ddddff"] + n13v1 [label="(n13v1) handoff", shape=parallelogram, fillcolor="#ddddff"] n4v1 -> n7v1 [label="0"] n3v1 -> n4v1 - n1v1 -> n10v1 + n1v1 -> n11v1 n6v1 -> n7v1 [label="1"] n5v1 -> n6v1 - n2v1 -> n11v1 + n2v1 -> n12v1 + n9v1 -> n10v1 n8v1 -> n9v1 - n7v1 -> n12v1 - n10v1 -> n3v1 - n11v1 -> n5v1 - n12v1 -> n8v1 [color=red] + n7v1 -> n13v1 + n11v1 -> n3v1 + n12v1 -> n5v1 + n13v1 -> n8v1 [color=red] subgraph "cluster n1v1" { fillcolor="#dddddd" style=filled @@ -68,5 +70,6 @@ digraph { label = "sg_4v1\nstratum 1" n8v1 n9v1 + n10v1 } } diff --git a/hydroflow/tests/snapshots/surface_loop__flo_nested@graphvis_mermaid.snap b/hydroflow/tests/snapshots/surface_loop__flo_nested@graphvis_mermaid.snap index 68f67faccaf8..24d1482c96b6 100644 --- a/hydroflow/tests/snapshots/surface_loop__flo_nested@graphvis_mermaid.snap +++ b/hydroflow/tests/snapshots/surface_loop__flo_nested@graphvis_mermaid.snap @@ -16,21 +16,23 @@ linkStyle default stroke:#aaa 6v1[\"(6v1) flatten()"/]:::pullClass 7v1[\"(7v1) cross_join::<'static, 'tick>()"/]:::pullClass 8v1[\"(8v1) all_once()"/]:::pullClass -9v1[/"(9v1) for_each(|all| println!("{}: {:?}", context.current_tick(), all))"\]:::pushClass -10v1["(10v1) handoff"]:::otherClass +9v1[\"(9v1) map(|vec| (context.current_tick().0, vec))"/]:::pullClass +10v1[/"
(10v1)
assert_eq([
(
0,
vec![
("alice", 0),
("alice", 1),
("alice", 2),
("bob", 0),
("bob", 1),
("bob", 2),
],
),
(
1,
vec![
("alice", 3),
("alice", 4),
("alice", 5),
("bob", 3),
("bob", 4),
("bob", 5),
],
),
(
2,
vec![
("alice", 6),
("alice", 7),
("alice", 8),
("bob", 6),
("bob", 7),
("bob", 8),
],
),
(
3,
vec![
("alice", 9),
("alice", 10),
("alice", 11),
("bob", 9),
("bob", 10),
("bob", 11),
],
),
])
"\]:::pushClass 11v1["(11v1) handoff"]:::otherClass 12v1["(12v1) handoff"]:::otherClass +13v1["(13v1) handoff"]:::otherClass 4v1-->|0|7v1 3v1-->4v1 -1v1-->10v1 +1v1-->11v1 6v1-->|1|7v1 5v1-->6v1 -2v1-->11v1 +2v1-->12v1 +9v1-->10v1 8v1-->9v1 -7v1-->12v1 -10v1-->3v1 -11v1-->5v1 -12v1--x8v1; linkStyle 10 stroke:red +7v1-->13v1 +11v1-->3v1 +12v1-->5v1 +13v1--x8v1; linkStyle 11 stroke:red subgraph sg_1v1 ["sg_1v1 stratum 0"] 1v1 subgraph sg_1v1_var_users ["var users"] @@ -56,4 +58,5 @@ end subgraph sg_4v1 ["sg_4v1 stratum 1"] 8v1 9v1 + 10v1 end diff --git a/hydroflow/tests/snapshots/surface_loop__flo_repeat_n@graphvis_dot.snap b/hydroflow/tests/snapshots/surface_loop__flo_repeat_n@graphvis_dot.snap new file mode 100644 index 000000000000..363d51e1e1ef --- /dev/null +++ b/hydroflow/tests/snapshots/surface_loop__flo_repeat_n@graphvis_dot.snap @@ -0,0 +1,75 @@ +--- +source: hydroflow/tests/surface_loop.rs +expression: "df.meta_graph().unwrap().to_dot(& Default :: default())" +--- +digraph { + node [fontname="Monaco,Menlo,Consolas,"Droid Sans Mono",Inconsolata,"Courier New",monospace", style=filled]; + edge [fontname="Monaco,Menlo,Consolas,"Droid Sans Mono",Inconsolata,"Courier New",monospace"]; + n1v1 [label="(n1v1) source_iter([\"alice\", \"bob\"])", shape=invhouse, fillcolor="#88aaff"] + n2v1 [label="(n2v1) source_stream(iter_batches_stream(0..12, 3))", shape=invhouse, fillcolor="#88aaff"] + n3v1 [label="(n3v1) batch()", shape=invhouse, fillcolor="#88aaff"] + n4v1 [label="(n4v1) flatten()", shape=invhouse, fillcolor="#88aaff"] + n5v1 [label="(n5v1) batch()", shape=invhouse, fillcolor="#88aaff"] + n6v1 [label="(n6v1) flatten()", shape=invhouse, fillcolor="#88aaff"] + n7v1 [label="(n7v1) cross_join::<'static, 'tick>()", shape=invhouse, fillcolor="#88aaff"] + n8v1 [label="(n8v1) repeat_n(3)", shape=invhouse, fillcolor="#88aaff"] + n9v1 [label="(n9v1) map(|vec| (context.current_tick().0, vec))", shape=invhouse, fillcolor="#88aaff"] + n10v1 [label="(n10v1) assert_eq([\l (\l 0,\l vec![\l (\"alice\", 0),\l (\"alice\", 1),\l (\"alice\", 2),\l (\"bob\", 0),\l (\"bob\", 1),\l (\"bob\", 2),\l ],\l ),\l (\l 0,\l vec![\l (\"alice\", 0),\l (\"alice\", 1),\l (\"alice\", 2),\l (\"bob\", 0),\l (\"bob\", 1),\l (\"bob\", 2),\l ],\l ),\l (\l 0,\l vec![\l (\"alice\", 0),\l (\"alice\", 1),\l (\"alice\", 2),\l (\"bob\", 0),\l (\"bob\", 1),\l (\"bob\", 2),\l ],\l ),\l (\l 1,\l vec![\l (\"alice\", 3),\l (\"alice\", 4),\l (\"alice\", 5),\l (\"bob\", 3),\l (\"bob\", 4),\l (\"bob\", 5),\l ],\l ),\l (\l 1,\l vec![\l (\"alice\", 3),\l (\"alice\", 4),\l (\"alice\", 5),\l (\"bob\", 3),\l (\"bob\", 4),\l (\"bob\", 5),\l ],\l ),\l (\l 1,\l vec![\l (\"alice\", 3),\l (\"alice\", 4),\l (\"alice\", 5),\l (\"bob\", 3),\l (\"bob\", 4),\l (\"bob\", 5),\l ],\l ),\l (\l 2,\l vec![\l (\"alice\", 6),\l (\"alice\", 7),\l (\"alice\", 8),\l (\"bob\", 6),\l (\"bob\", 7),\l (\"bob\", 8),\l ],\l ),\l (\l 2,\l vec![\l (\"alice\", 6),\l (\"alice\", 7),\l (\"alice\", 8),\l (\"bob\", 6),\l (\"bob\", 7),\l (\"bob\", 8),\l ],\l ),\l (\l 2,\l vec![\l (\"alice\", 6),\l (\"alice\", 7),\l (\"alice\", 8),\l (\"bob\", 6),\l (\"bob\", 7),\l (\"bob\", 8),\l ],\l ),\l (\l 3,\l vec![\l (\"alice\", 9),\l (\"alice\", 10),\l (\"alice\", 11),\l (\"bob\", 9),\l (\"bob\", 10),\l (\"bob\", 11),\l ],\l ),\l (\l 3,\l vec![\l (\"alice\", 9),\l (\"alice\", 10),\l (\"alice\", 11),\l (\"bob\", 9),\l (\"bob\", 10),\l (\"bob\", 11),\l ],\l ),\l (\l 3,\l vec![\l (\"alice\", 9),\l (\"alice\", 10),\l (\"alice\", 11),\l (\"bob\", 9),\l (\"bob\", 10),\l (\"bob\", 11),\l ],\l ),\l])\l", shape=house, fillcolor="#ffff88"] + n11v1 [label="(n11v1) handoff", shape=parallelogram, fillcolor="#ddddff"] + n12v1 [label="(n12v1) handoff", shape=parallelogram, fillcolor="#ddddff"] + n13v1 [label="(n13v1) handoff", shape=parallelogram, fillcolor="#ddddff"] + n4v1 -> n7v1 [label="0"] + n3v1 -> n4v1 + n1v1 -> n11v1 + n6v1 -> n7v1 [label="1"] + n5v1 -> n6v1 + n2v1 -> n12v1 + n9v1 -> n10v1 + n8v1 -> n9v1 + n7v1 -> n13v1 + n11v1 -> n3v1 + n12v1 -> n5v1 + n13v1 -> n8v1 [color=red] + subgraph "cluster n1v1" { + fillcolor="#dddddd" + style=filled + label = "sg_1v1\nstratum 0" + n1v1 + subgraph "cluster_sg_1v1_var_users" { + label="var users" + n1v1 + } + } + subgraph "cluster n2v1" { + fillcolor="#dddddd" + style=filled + label = "sg_2v1\nstratum 0" + n2v1 + subgraph "cluster_sg_2v1_var_messages" { + label="var messages" + n2v1 + } + } + subgraph "cluster n3v1" { + fillcolor="#dddddd" + style=filled + label = "sg_3v1\nstratum 0" + n3v1 + n4v1 + n5v1 + n6v1 + n7v1 + subgraph "cluster_sg_3v1_var_cp" { + label="var cp" + n7v1 + } + } + subgraph "cluster n4v1" { + fillcolor="#dddddd" + style=filled + label = "sg_4v1\nstratum 1" + n8v1 + n9v1 + n10v1 + } +} diff --git a/hydroflow/tests/snapshots/surface_loop__flo_repeat_n@graphvis_mermaid.snap b/hydroflow/tests/snapshots/surface_loop__flo_repeat_n@graphvis_mermaid.snap new file mode 100644 index 000000000000..cba8226d9544 --- /dev/null +++ b/hydroflow/tests/snapshots/surface_loop__flo_repeat_n@graphvis_mermaid.snap @@ -0,0 +1,62 @@ +--- +source: hydroflow/tests/surface_loop.rs +expression: "df.meta_graph().unwrap().to_mermaid(& Default :: default())" +--- +%%{init:{'theme':'base','themeVariables':{'clusterBkg':'#ddd','clusterBorder':'#888'}}}%% +flowchart TD +classDef pullClass fill:#8af,stroke:#000,text-align:left,white-space:pre +classDef pushClass fill:#ff8,stroke:#000,text-align:left,white-space:pre +classDef otherClass fill:#fdc,stroke:#000,text-align:left,white-space:pre +linkStyle default stroke:#aaa +1v1[\"(1v1) source_iter(["alice", "bob"])"/]:::pullClass +2v1[\"(2v1) source_stream(iter_batches_stream(0..12, 3))"/]:::pullClass +3v1[\"(3v1) batch()"/]:::pullClass +4v1[\"(4v1) flatten()"/]:::pullClass +5v1[\"(5v1) batch()"/]:::pullClass +6v1[\"(6v1) flatten()"/]:::pullClass +7v1[\"(7v1) cross_join::<'static, 'tick>()"/]:::pullClass +8v1[\"(8v1) repeat_n(3)"/]:::pullClass +9v1[\"(9v1) map(|vec| (context.current_tick().0, vec))"/]:::pullClass +10v1[/"
(10v1)
assert_eq([
(
0,
vec![
("alice", 0),
("alice", 1),
("alice", 2),
("bob", 0),
("bob", 1),
("bob", 2),
],
),
(
0,
vec![
("alice", 0),
("alice", 1),
("alice", 2),
("bob", 0),
("bob", 1),
("bob", 2),
],
),
(
0,
vec![
("alice", 0),
("alice", 1),
("alice", 2),
("bob", 0),
("bob", 1),
("bob", 2),
],
),
(
1,
vec![
("alice", 3),
("alice", 4),
("alice", 5),
("bob", 3),
("bob", 4),
("bob", 5),
],
),
(
1,
vec![
("alice", 3),
("alice", 4),
("alice", 5),
("bob", 3),
("bob", 4),
("bob", 5),
],
),
(
1,
vec![
("alice", 3),
("alice", 4),
("alice", 5),
("bob", 3),
("bob", 4),
("bob", 5),
],
),
(
2,
vec![
("alice", 6),
("alice", 7),
("alice", 8),
("bob", 6),
("bob", 7),
("bob", 8),
],
),
(
2,
vec![
("alice", 6),
("alice", 7),
("alice", 8),
("bob", 6),
("bob", 7),
("bob", 8),
],
),
(
2,
vec![
("alice", 6),
("alice", 7),
("alice", 8),
("bob", 6),
("bob", 7),
("bob", 8),
],
),
(
3,
vec![
("alice", 9),
("alice", 10),
("alice", 11),
("bob", 9),
("bob", 10),
("bob", 11),
],
),
(
3,
vec![
("alice", 9),
("alice", 10),
("alice", 11),
("bob", 9),
("bob", 10),
("bob", 11),
],
),
(
3,
vec![
("alice", 9),
("alice", 10),
("alice", 11),
("bob", 9),
("bob", 10),
("bob", 11),
],
),
])
"\]:::pushClass +11v1["(11v1) handoff"]:::otherClass +12v1["(12v1) handoff"]:::otherClass +13v1["(13v1) handoff"]:::otherClass +4v1-->|0|7v1 +3v1-->4v1 +1v1-->11v1 +6v1-->|1|7v1 +5v1-->6v1 +2v1-->12v1 +9v1-->10v1 +8v1-->9v1 +7v1-->13v1 +11v1-->3v1 +12v1-->5v1 +13v1--x8v1; linkStyle 11 stroke:red +subgraph sg_1v1 ["sg_1v1 stratum 0"] + 1v1 + subgraph sg_1v1_var_users ["var users"] + 1v1 + end +end +subgraph sg_2v1 ["sg_2v1 stratum 0"] + 2v1 + subgraph sg_2v1_var_messages ["var messages"] + 2v1 + end +end +subgraph sg_3v1 ["sg_3v1 stratum 0"] + 3v1 + 4v1 + 5v1 + 6v1 + 7v1 + subgraph sg_3v1_var_cp ["var cp"] + 7v1 + end +end +subgraph sg_4v1 ["sg_4v1 stratum 1"] + 8v1 + 9v1 + 10v1 +end diff --git a/hydroflow/tests/snapshots/surface_loop__flo_syntax@graphvis_dot.snap b/hydroflow/tests/snapshots/surface_loop__flo_syntax@graphvis_dot.snap index 57f118a3fee3..5af92a321679 100644 --- a/hydroflow/tests/snapshots/surface_loop__flo_syntax@graphvis_dot.snap +++ b/hydroflow/tests/snapshots/surface_loop__flo_syntax@graphvis_dot.snap @@ -12,18 +12,20 @@ digraph { n5v1 [label="(n5v1) batch()", shape=invhouse, fillcolor="#88aaff"] n6v1 [label="(n6v1) flatten()", shape=invhouse, fillcolor="#88aaff"] n7v1 [label="(n7v1) cross_join::<'static, 'tick>()", shape=invhouse, fillcolor="#88aaff"] - n8v1 [label="(n8v1) for_each(|(user, message)| {\l println!(\"{}: notify {} of {}\", context.current_tick(), user, message)\l})\l", shape=house, fillcolor="#ffff88"] - n9v1 [label="(n9v1) handoff", shape=parallelogram, fillcolor="#ddddff"] + n8v1 [label="(n8v1) map(|item| (context.current_tick().0, item))", shape=invhouse, fillcolor="#88aaff"] + n9v1 [label="(n9v1) assert_eq([\l (0, (\"alice\", 0)),\l (0, (\"alice\", 1)),\l (0, (\"alice\", 2)),\l (0, (\"bob\", 0)),\l (0, (\"bob\", 1)),\l (0, (\"bob\", 2)),\l (1, (\"alice\", 3)),\l (1, (\"alice\", 4)),\l (1, (\"alice\", 5)),\l (1, (\"bob\", 3)),\l (1, (\"bob\", 4)),\l (1, (\"bob\", 5)),\l (2, (\"alice\", 6)),\l (2, (\"alice\", 7)),\l (2, (\"alice\", 8)),\l (2, (\"bob\", 6)),\l (2, (\"bob\", 7)),\l (2, (\"bob\", 8)),\l (3, (\"alice\", 9)),\l (3, (\"alice\", 10)),\l (3, (\"alice\", 11)),\l (3, (\"bob\", 9)),\l (3, (\"bob\", 10)),\l (3, (\"bob\", 11)),\l])\l", shape=house, fillcolor="#ffff88"] n10v1 [label="(n10v1) handoff", shape=parallelogram, fillcolor="#ddddff"] + n11v1 [label="(n11v1) handoff", shape=parallelogram, fillcolor="#ddddff"] n4v1 -> n7v1 [label="0"] n3v1 -> n4v1 - n1v1 -> n9v1 + n1v1 -> n10v1 n6v1 -> n7v1 [label="1"] n5v1 -> n6v1 - n2v1 -> n10v1 + n2v1 -> n11v1 + n8v1 -> n9v1 n7v1 -> n8v1 - n9v1 -> n3v1 - n10v1 -> n5v1 + n10v1 -> n3v1 + n11v1 -> n5v1 subgraph "cluster n1v1" { fillcolor="#dddddd" style=filled @@ -54,10 +56,12 @@ digraph { n6v1 n7v1 n8v1 + n9v1 subgraph "cluster_sg_3v1_var_cp" { label="var cp" n7v1 n8v1 + n9v1 } } } diff --git a/hydroflow/tests/snapshots/surface_loop__flo_syntax@graphvis_mermaid.snap b/hydroflow/tests/snapshots/surface_loop__flo_syntax@graphvis_mermaid.snap index 3fdd6bd94591..943fc8a6c278 100644 --- a/hydroflow/tests/snapshots/surface_loop__flo_syntax@graphvis_mermaid.snap +++ b/hydroflow/tests/snapshots/surface_loop__flo_syntax@graphvis_mermaid.snap @@ -15,18 +15,20 @@ linkStyle default stroke:#aaa 5v1[\"(5v1) batch()"/]:::pullClass 6v1[\"(6v1) flatten()"/]:::pullClass 7v1[\"(7v1) cross_join::<'static, 'tick>()"/]:::pullClass -8v1[/"
(8v1)
for_each(|(user, message)| {
println!("{}: notify {} of {}", context.current_tick(), user, message)
})
"\]:::pushClass -9v1["(9v1) handoff"]:::otherClass +8v1[\"(8v1) map(|item| (context.current_tick().0, item))"/]:::pullClass +9v1[/"
(9v1)
assert_eq([
(0, ("alice", 0)),
(0, ("alice", 1)),
(0, ("alice", 2)),
(0, ("bob", 0)),
(0, ("bob", 1)),
(0, ("bob", 2)),
(1, ("alice", 3)),
(1, ("alice", 4)),
(1, ("alice", 5)),
(1, ("bob", 3)),
(1, ("bob", 4)),
(1, ("bob", 5)),
(2, ("alice", 6)),
(2, ("alice", 7)),
(2, ("alice", 8)),
(2, ("bob", 6)),
(2, ("bob", 7)),
(2, ("bob", 8)),
(3, ("alice", 9)),
(3, ("alice", 10)),
(3, ("alice", 11)),
(3, ("bob", 9)),
(3, ("bob", 10)),
(3, ("bob", 11)),
])
"\]:::pushClass 10v1["(10v1) handoff"]:::otherClass +11v1["(11v1) handoff"]:::otherClass 4v1-->|0|7v1 3v1-->4v1 -1v1-->9v1 +1v1-->10v1 6v1-->|1|7v1 5v1-->6v1 -2v1-->10v1 +2v1-->11v1 +8v1-->9v1 7v1-->8v1 -9v1-->3v1 -10v1-->5v1 +10v1-->3v1 +11v1-->5v1 subgraph sg_1v1 ["sg_1v1 stratum 0"] 1v1 subgraph sg_1v1_var_users ["var users"] @@ -46,8 +48,10 @@ subgraph sg_3v1 ["sg_3v1 stratum 0"] 6v1 7v1 8v1 + 9v1 subgraph sg_3v1_var_cp ["var cp"] 7v1 8v1 + 9v1 end end diff --git a/hydroflow/tests/surface_loop.rs b/hydroflow/tests/surface_loop.rs index ec026fd9d8a3..5374d3602271 100644 --- a/hydroflow/tests/surface_loop.rs +++ b/hydroflow/tests/surface_loop.rs @@ -11,7 +11,34 @@ pub fn test_flo_syntax() { // TODO(mingwei): cross_join type negotion should allow us to eliminate `flatten()`. users -> batch() -> flatten() -> [0]cp; messages -> batch() -> flatten() -> [1]cp; - cp = cross_join::<'static, 'tick>() -> for_each(|(user, message)| println!("{}: notify {} of {}", context.current_tick(), user, message)); + cp = cross_join::<'static, 'tick>() + -> map(|item| (context.current_tick().0, item)) + -> assert_eq([ + (0, ("alice", 0)), + (0, ("alice", 1)), + (0, ("alice", 2)), + (0, ("bob", 0)), + (0, ("bob", 1)), + (0, ("bob", 2)), + (1, ("alice", 3)), + (1, ("alice", 4)), + (1, ("alice", 5)), + (1, ("bob", 3)), + (1, ("bob", 4)), + (1, ("bob", 5)), + (2, ("alice", 6)), + (2, ("alice", 7)), + (2, ("alice", 8)), + (2, ("bob", 6)), + (2, ("bob", 7)), + (2, ("bob", 8)), + (3, ("alice", 9)), + (3, ("alice", 10)), + (3, ("alice", 11)), + (3, ("bob", 9)), + (3, ("bob", 10)), + (3, ("bob", 11)), + ]); } }; assert_graphvis_snapshots!(df); @@ -29,7 +56,50 @@ pub fn test_flo_nested() { messages -> batch() -> flatten() -> [1]cp; cp = cross_join::<'static, 'tick>(); loop { - cp -> all_once() -> for_each(|all| println!("{}: {:?}", context.current_tick(), all)); + cp + -> all_once() + -> map(|vec| (context.current_tick().0, vec)) + -> assert_eq([ + (0, vec![("alice", 0), ("alice", 1), ("alice", 2), ("bob", 0), ("bob", 1), ("bob", 2)]), + (1, vec![("alice", 3), ("alice", 4), ("alice", 5), ("bob", 3), ("bob", 4), ("bob", 5)]), + (2, vec![("alice", 6), ("alice", 7), ("alice", 8), ("bob", 6), ("bob", 7), ("bob", 8)]), + (3, vec![("alice", 9), ("alice", 10), ("alice", 11), ("bob", 9), ("bob", 10), ("bob", 11)]), + ]); + } + } + }; + assert_graphvis_snapshots!(df); + df.run_available(); +} + +#[multiplatform_test] +pub fn test_flo_repeat_n() { + let mut df = hydroflow_syntax! { + users = source_iter(["alice", "bob"]); + messages = source_stream(iter_batches_stream(0..12, 3)); + loop { + // TODO(mingwei): cross_join type negotion should allow us to eliminate `flatten()`. + users -> batch() -> flatten() -> [0]cp; + messages -> batch() -> flatten() -> [1]cp; + cp = cross_join::<'static, 'tick>(); + loop { + cp + -> repeat_n(3) + -> map(|vec| (context.current_tick().0, vec)) + -> assert_eq([ + (0, vec![("alice", 0), ("alice", 1), ("alice", 2), ("bob", 0), ("bob", 1), ("bob", 2)]), + (0, vec![("alice", 0), ("alice", 1), ("alice", 2), ("bob", 0), ("bob", 1), ("bob", 2)]), + (0, vec![("alice", 0), ("alice", 1), ("alice", 2), ("bob", 0), ("bob", 1), ("bob", 2)]), + (1, vec![("alice", 3), ("alice", 4), ("alice", 5), ("bob", 3), ("bob", 4), ("bob", 5)]), + (1, vec![("alice", 3), ("alice", 4), ("alice", 5), ("bob", 3), ("bob", 4), ("bob", 5)]), + (1, vec![("alice", 3), ("alice", 4), ("alice", 5), ("bob", 3), ("bob", 4), ("bob", 5)]), + (2, vec![("alice", 6), ("alice", 7), ("alice", 8), ("bob", 6), ("bob", 7), ("bob", 8)]), + (2, vec![("alice", 6), ("alice", 7), ("alice", 8), ("bob", 6), ("bob", 7), ("bob", 8)]), + (2, vec![("alice", 6), ("alice", 7), ("alice", 8), ("bob", 6), ("bob", 7), ("bob", 8)]), + (3, vec![("alice", 9), ("alice", 10), ("alice", 11), ("bob", 9), ("bob", 10), ("bob", 11)]), + (3, vec![("alice", 9), ("alice", 10), ("alice", 11), ("bob", 9), ("bob", 10), ("bob", 11)]), + (3, vec![("alice", 9), ("alice", 10), ("alice", 11), ("bob", 9), ("bob", 10), ("bob", 11)]), + ]); } } }; diff --git a/hydroflow_lang/src/graph/ops/batch.rs b/hydroflow_lang/src/graph/ops/batch.rs index faccafdbd0c7..9e3666068b9e 100644 --- a/hydroflow_lang/src/graph/ops/batch.rs +++ b/hydroflow_lang/src/graph/ops/batch.rs @@ -53,7 +53,9 @@ pub const BATCH: OperatorConstraints = OperatorConstraints { let input = &inputs[0]; quote_spanned! {op_span=> let mut #vec_ident = #context.state_ref(#singleton_output_ident).borrow_mut(); - *#vec_ident = #input.collect::<::std::vec::Vec<_>>(); + if #context.is_first_run_this_tick() { + *#vec_ident = #input.collect::<::std::vec::Vec<_>>(); + } let #ident = ::std::iter::once(::std::clone::Clone::clone(&*#vec_ident)); } } else if let Some(_output) = outputs.first() { diff --git a/hydroflow_lang/src/graph/ops/repeat_n.rs b/hydroflow_lang/src/graph/ops/repeat_n.rs index 978eeec29351..ff13682cada1 100644 --- a/hydroflow_lang/src/graph/ops/repeat_n.rs +++ b/hydroflow_lang/src/graph/ops/repeat_n.rs @@ -1,79 +1,54 @@ use quote::quote_spanned; -use super::{ - FloType, OperatorCategory, OperatorConstraints, OperatorWriteOutput, WriteContextArgs, RANGE_0, - RANGE_1, -}; +use super::{OperatorConstraints, OperatorWriteOutput, WriteContextArgs}; /// TODO(mingwei): docs pub const REPEAT_N: OperatorConstraints = OperatorConstraints { name: "repeat_n", - categories: &[OperatorCategory::Fold, OperatorCategory::Windowing], - hard_range_inn: RANGE_1, - soft_range_inn: RANGE_1, - hard_range_out: &(0..=1), - soft_range_out: &(0..=1), - num_args: 0, - persistence_args: RANGE_0, - type_args: RANGE_0, - is_external_input: false, - has_singleton_output: true, - flo_type: Some(FloType::Windowing), - ports_inn: None, - ports_out: None, - input_delaytype_fn: |_| None, + num_args: 1, write_fn: |wc @ &WriteContextArgs { - root, context, hydroflow, op_span, - ident, - is_pull, - inputs, - outputs, - singleton_output_ident, + arguments, .. }, - _diagnostics| { + diagnostics| { + let OperatorWriteOutput { + write_prologue, + write_iterator, + write_iterator_after, + } = (super::all_once::ALL_ONCE.write_fn)(wc, diagnostics)?; + + let count_ident = wc.make_ident("count"); + let write_prologue = quote_spanned! {op_span=> - #[allow(clippy::redundant_closure_call)] - let #singleton_output_ident = #hydroflow.add_state( - ::std::cell::RefCell::new(::std::vec::Vec::new()) - ); + #write_prologue - // TODO(mingwei): Is this needed? - // Reset the value to the initializer fn if it is a new tick. - #hydroflow.set_state_tick_hook(#singleton_output_ident, move |rcell| { rcell.take(); }); + let #count_ident = #hydroflow.add_state(::std::cell::Cell::new(0_usize)); + #hydroflow.set_state_tick_hook(#count_ident, move |cell| { cell.take(); }); }; - let vec_ident = wc.make_ident("vec"); + // Reschedule, to repeat. + let count_arg = &arguments[0]; + let write_iterator_after = quote_spanned! {op_span=> + #write_iterator_after - let write_iterator = if is_pull { - // Pull. - let input = &inputs[0]; - quote_spanned! {op_span=> - let mut #vec_ident = #context.state_ref(#singleton_output_ident).borrow_mut(); - *#vec_ident = #input.collect::<::std::vec::Vec<_>>(); - let #ident = ::std::iter::once(::std::clone::Clone::clone(&*#vec_ident)); - } - } else if let Some(_output) = outputs.first() { - // Push with output. - // TODO(mingwei): Not supported - cannot tell EOS for pusherators. - panic!("Should not happen - batch must be at ingress to a loop, therefore ingress to a subgraph, so would be pull-based."); - } else { - // Push with no output. - quote_spanned! {op_span=> - let mut #vec_ident = #context.state_ref(#singleton_output_ident).borrow_mut(); - let #ident = #root::pusherator::for_each::ForEach::new(|item| { - ::std::vec::Vec::push(#vec_ident, item); - }); + { + let count_ref = #context.state_ref(#count_ident); + let count = count_ref.get() + 1; + if count < #count_arg { + count_ref.set(count); + #context.reschedule_current_subgraph(); + } } }; Ok(OperatorWriteOutput { write_prologue, write_iterator, - ..Default::default() + write_iterator_after, }) }, + ..super::all_once::ALL_ONCE };