From 8cff049dbce978606d495160b0ef5058837ca6d9 Mon Sep 17 00:00:00 2001 From: Winter Zhang Date: Thu, 12 Sep 2024 23:33:12 +0800 Subject: [PATCH] refactor(executor): use event cause to refactor shuffle processor (#16445) * refactor(executor): use event cause to refactor shuffle processor * refactor(executor): use event cause to refactor shuffle processor * refactor(executor): use event cause to refactor shuffle processor --- .../core/src/processors/shuffle_processor.rs | 104 ++++++++--- .../tests/it/pipelines/processors/shuffle.rs | 176 ++++++++++++++++-- 2 files changed, 247 insertions(+), 33 deletions(-) diff --git a/src/query/pipeline/core/src/processors/shuffle_processor.rs b/src/query/pipeline/core/src/processors/shuffle_processor.rs index ed545680ff65..ebfeefb6479e 100644 --- a/src/query/pipeline/core/src/processors/shuffle_processor.rs +++ b/src/query/pipeline/core/src/processors/shuffle_processor.rs @@ -18,34 +18,49 @@ use std::sync::Arc; use databend_common_exception::Result; use crate::processors::Event; +use crate::processors::EventCause; use crate::processors::InputPort; use crate::processors::OutputPort; use crate::processors::Processor; pub struct ShuffleProcessor { - channel: Vec<(Arc, Arc)>, + input2output: Vec, + output2input: Vec, + + finished_port: usize, + inputs: Vec<(bool, Arc)>, + outputs: Vec<(bool, Arc)>, } impl ShuffleProcessor { pub fn create( inputs: Vec>, outputs: Vec>, - rule: Vec, + edges: Vec, ) -> Self { - let len = rule.len(); + let len = edges.len(); debug_assert!({ - let mut sorted = rule.clone(); + let mut sorted = edges.clone(); sorted.sort(); let expected = (0..len).collect::>(); sorted == expected }); - let mut channel = Vec::with_capacity(len); - for (i, input) in inputs.into_iter().enumerate() { - let output = outputs[rule[i]].clone(); - channel.push((input, output)); + let mut input2output = vec![0_usize; edges.len()]; + let mut output2input = vec![0_usize; edges.len()]; + + for (input, output) in edges.into_iter().enumerate() { + input2output[input] = output; + output2input[output] = input; + } + + ShuffleProcessor { + input2output, + output2input, + finished_port: 0, + inputs: inputs.into_iter().map(|x| (false, x)).collect(), + outputs: outputs.into_iter().map(|x| (false, x)).collect(), } - ShuffleProcessor { channel } } } @@ -59,23 +74,68 @@ impl Processor for ShuffleProcessor { self } - fn event(&mut self) -> Result { - let mut finished = true; - for (input, output) in self.channel.iter() { - if output.is_finished() || input.is_finished() { - input.finish(); - output.finish(); - continue; + fn event_with_cause(&mut self, cause: EventCause) -> Result { + let ((input_finished, input), (output_finished, output)) = match cause { + EventCause::Other => unreachable!(), + EventCause::Input(index) => ( + &mut self.inputs[index], + &mut self.outputs[self.input2output[index]], + ), + EventCause::Output(index) => ( + &mut self.inputs[self.output2input[index]], + &mut self.outputs[index], + ), + }; + + if output.is_finished() { + input.finish(); + + if !*input_finished { + *input_finished = true; + self.finished_port += 1; } - finished = false; - input.set_need_data(); - if output.can_push() && input.has_data() { - output.push_data(input.pull_data().unwrap()); + + if !*output_finished { + *output_finished = true; + self.finished_port += 1; } + + return match self.finished_port == (self.inputs.len() + self.outputs.len()) { + true => Ok(Event::Finished), + false => Ok(Event::NeedConsume), + }; } - if finished { - return Ok(Event::Finished); + + if !output.can_push() { + input.set_not_need_data(); + return Ok(Event::NeedConsume); } + + if input.has_data() { + output.push_data(input.pull_data().unwrap()); + return Ok(Event::NeedConsume); + } + + if input.is_finished() { + output.finish(); + + if !*input_finished { + *input_finished = true; + self.finished_port += 1; + } + + if !*output_finished { + *output_finished = true; + self.finished_port += 1; + } + + return match self.finished_port == (self.inputs.len() + self.outputs.len()) { + true => Ok(Event::Finished), + false => Ok(Event::NeedConsume), + }; + } + + input.set_need_data(); Ok(Event::NeedData) } } diff --git a/src/query/pipeline/core/tests/it/pipelines/processors/shuffle.rs b/src/query/pipeline/core/tests/it/pipelines/processors/shuffle.rs index 980d775231e9..a785a9bbe3c5 100644 --- a/src/query/pipeline/core/tests/it/pipelines/processors/shuffle.rs +++ b/src/query/pipeline/core/tests/it/pipelines/processors/shuffle.rs @@ -18,6 +18,7 @@ use databend_common_expression::DataBlock; use databend_common_expression::FromData; use databend_common_pipeline_core::processors::connect; use databend_common_pipeline_core::processors::Event; +use databend_common_pipeline_core::processors::EventCause; use databend_common_pipeline_core::processors::InputPort; use databend_common_pipeline_core::processors::OutputPort; use databend_common_pipeline_core::processors::Processor; @@ -51,8 +52,19 @@ async fn test_shuffle_output_finish() -> Result<()> { downstream_input1.finish(); downstream_input2.finish(); - assert!(matches!(processor.event()?, Event::Finished)); - assert!(input1.is_finished() && input2.is_finished()); + assert!(matches!( + processor.event_with_cause(EventCause::Output(0))?, + Event::NeedConsume + )); + assert!(input1.is_finished()); + assert!(!input2.is_finished()); + + assert!(matches!( + processor.event_with_cause(EventCause::Output(1))?, + Event::Finished + )); + assert!(input1.is_finished()); + assert!(input2.is_finished()); Ok(()) } @@ -122,17 +134,159 @@ async fn test_shuffle_processor() -> Result<()> { upstream_output3.push_data(Ok(block3)); upstream_output4.push_data(Ok(block4)); - assert!(matches!(processor.event()?, Event::NeedData)); + // 0 input and 0 output + assert!(matches!( + processor.event_with_cause(EventCause::Output(0))?, + Event::NeedConsume + )); + + assert!(downstream_input1.has_data()); + assert!( + !downstream_input2.has_data() + && !downstream_input3.has_data() + && !downstream_input4.has_data() + ); + assert!( + !upstream_output1.can_push() + && !upstream_output2.can_push() + && !upstream_output3.can_push() + && !upstream_output4.can_push() + ); + + let block = downstream_input1.pull_data().unwrap()?; + downstream_input1.set_need_data(); + assert!(block.columns()[0].value.as_column().unwrap().eq(&col1)); + assert!(matches!( + processor.event_with_cause(EventCause::Output(0))?, + Event::NeedData + )); + + assert!(upstream_output1.can_push()); + assert!( + !upstream_output2.can_push() + && !upstream_output3.can_push() + && !upstream_output4.can_push() + ); + assert!( + !downstream_input1.has_data() + && !downstream_input2.has_data() + && !downstream_input3.has_data() + && !downstream_input4.has_data() + ); + + // 2 input and 1 output + assert!(matches!( + processor.event_with_cause(EventCause::Output(1))?, + Event::NeedConsume + )); + + assert!(downstream_input2.has_data()); + assert!( + !downstream_input1.has_data() + && !downstream_input3.has_data() + && !downstream_input4.has_data() + ); + assert!( + upstream_output1.can_push() + && !upstream_output2.can_push() + && !upstream_output3.can_push() + && !upstream_output4.can_push() + ); + + let block = downstream_input2.pull_data().unwrap()?; + downstream_input2.set_need_data(); + assert!(block.columns()[0].value.as_column().unwrap().eq(&col3)); + assert!(matches!( + processor.event_with_cause(EventCause::Output(1))?, + Event::NeedData + )); + + assert!(upstream_output3.can_push()); + assert!( + upstream_output1.can_push() && !upstream_output2.can_push() && !upstream_output4.can_push() + ); + assert!( + !downstream_input1.has_data() + && !downstream_input2.has_data() + && !downstream_input3.has_data() + && !downstream_input4.has_data() + ); + + // 1 input and 2 output + assert!(matches!( + processor.event_with_cause(EventCause::Output(2))?, + Event::NeedConsume + )); + + assert!(downstream_input3.has_data()); + assert!( + !downstream_input1.has_data() + && !downstream_input2.has_data() + && !downstream_input4.has_data() + ); + assert!( + upstream_output1.can_push() + && !upstream_output2.can_push() + && upstream_output3.can_push() + && !upstream_output4.can_push() + ); - let out1 = downstream_input1.pull_data().unwrap()?; - let out2 = downstream_input2.pull_data().unwrap()?; - let out3 = downstream_input3.pull_data().unwrap()?; - let out4 = downstream_input4.pull_data().unwrap()?; + let block = downstream_input3.pull_data().unwrap()?; + downstream_input3.set_need_data(); + assert!(block.columns()[0].value.as_column().unwrap().eq(&col2)); + assert!(matches!( + processor.event_with_cause(EventCause::Output(2))?, + Event::NeedData + )); + + assert!(upstream_output2.can_push()); + assert!( + upstream_output1.can_push() && upstream_output3.can_push() && !upstream_output4.can_push() + ); + assert!( + !downstream_input1.has_data() + && !downstream_input2.has_data() + && !downstream_input3.has_data() + && !downstream_input4.has_data() + ); - assert!(out1.columns()[0].value.as_column().unwrap().eq(&col1)); - assert!(out2.columns()[0].value.as_column().unwrap().eq(&col3)); - assert!(out3.columns()[0].value.as_column().unwrap().eq(&col2)); - assert!(out4.columns()[0].value.as_column().unwrap().eq(&col4)); + // 3 input and 3 output + assert!(matches!( + processor.event_with_cause(EventCause::Output(3))?, + Event::NeedConsume + )); + + assert!(downstream_input4.has_data()); + assert!( + !downstream_input1.has_data() + && !downstream_input2.has_data() + && !downstream_input3.has_data() + ); + assert!( + upstream_output1.can_push() + && upstream_output2.can_push() + && upstream_output3.can_push() + && !upstream_output4.can_push() + ); + + let block = downstream_input4.pull_data().unwrap()?; + downstream_input4.set_need_data(); + assert!(block.columns()[0].value.as_column().unwrap().eq(&col4)); + assert!(matches!( + processor.event_with_cause(EventCause::Output(3))?, + Event::NeedData + )); + + assert!(upstream_output4.can_push()); + assert!( + upstream_output1.can_push() && upstream_output3.can_push() && upstream_output2.can_push() + ); + assert!( + !downstream_input1.has_data() + && !downstream_input2.has_data() + && !downstream_input3.has_data() + && !downstream_input4.has_data() + ); Ok(()) }