Skip to content

Commit

Permalink
refactor(executor): use event cause to refactor shuffle processor (#1โ€ฆ
Browse files Browse the repository at this point in the history
โ€ฆ6445)

* refactor(executor): use event cause to refactor shuffle processor

* refactor(executor): use event cause to refactor shuffle processor

* refactor(executor): use event cause to refactor shuffle processor
  • Loading branch information
zhang2014 authored Sep 12, 2024
1 parent a620190 commit 8cff049
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 33 deletions.
104 changes: 82 additions & 22 deletions src/query/pipeline/core/src/processors/shuffle_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,49 @@ use std::sync::Arc;
use databend_common_exception::Result;

use crate::processors::Event;
use crate::processors::EventCause;
use crate::processors::InputPort;
use crate::processors::OutputPort;
use crate::processors::Processor;

pub struct ShuffleProcessor {
channel: Vec<(Arc<InputPort>, Arc<OutputPort>)>,
input2output: Vec<usize>,
output2input: Vec<usize>,

finished_port: usize,
inputs: Vec<(bool, Arc<InputPort>)>,
outputs: Vec<(bool, Arc<OutputPort>)>,
}

impl ShuffleProcessor {
pub fn create(
inputs: Vec<Arc<InputPort>>,
outputs: Vec<Arc<OutputPort>>,
rule: Vec<usize>,
edges: Vec<usize>,
) -> Self {
let len = rule.len();
let len = edges.len();
debug_assert!({
let mut sorted = rule.clone();
let mut sorted = edges.clone();
sorted.sort();
let expected = (0..len).collect::<Vec<_>>();
sorted == expected
});

let mut channel = Vec::with_capacity(len);
for (i, input) in inputs.into_iter().enumerate() {
let output = outputs[rule[i]].clone();
channel.push((input, output));
let mut input2output = vec![0_usize; edges.len()];
let mut output2input = vec![0_usize; edges.len()];

for (input, output) in edges.into_iter().enumerate() {
input2output[input] = output;
output2input[output] = input;
}

ShuffleProcessor {
input2output,
output2input,
finished_port: 0,
inputs: inputs.into_iter().map(|x| (false, x)).collect(),
outputs: outputs.into_iter().map(|x| (false, x)).collect(),
}
ShuffleProcessor { channel }
}
}

Expand All @@ -59,23 +74,68 @@ impl Processor for ShuffleProcessor {
self
}

fn event(&mut self) -> Result<Event> {
let mut finished = true;
for (input, output) in self.channel.iter() {
if output.is_finished() || input.is_finished() {
input.finish();
output.finish();
continue;
fn event_with_cause(&mut self, cause: EventCause) -> Result<Event> {
let ((input_finished, input), (output_finished, output)) = match cause {
EventCause::Other => unreachable!(),
EventCause::Input(index) => (
&mut self.inputs[index],
&mut self.outputs[self.input2output[index]],
),
EventCause::Output(index) => (
&mut self.inputs[self.output2input[index]],
&mut self.outputs[index],
),
};

if output.is_finished() {
input.finish();

if !*input_finished {
*input_finished = true;
self.finished_port += 1;
}
finished = false;
input.set_need_data();
if output.can_push() && input.has_data() {
output.push_data(input.pull_data().unwrap());

if !*output_finished {
*output_finished = true;
self.finished_port += 1;
}

return match self.finished_port == (self.inputs.len() + self.outputs.len()) {
true => Ok(Event::Finished),
false => Ok(Event::NeedConsume),
};
}
if finished {
return Ok(Event::Finished);

if !output.can_push() {
input.set_not_need_data();
return Ok(Event::NeedConsume);
}

if input.has_data() {
output.push_data(input.pull_data().unwrap());
return Ok(Event::NeedConsume);
}

if input.is_finished() {
output.finish();

if !*input_finished {
*input_finished = true;
self.finished_port += 1;
}

if !*output_finished {
*output_finished = true;
self.finished_port += 1;
}

return match self.finished_port == (self.inputs.len() + self.outputs.len()) {
true => Ok(Event::Finished),
false => Ok(Event::NeedConsume),
};
}

input.set_need_data();
Ok(Event::NeedData)
}
}
176 changes: 165 additions & 11 deletions src/query/pipeline/core/tests/it/pipelines/processors/shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use databend_common_expression::DataBlock;
use databend_common_expression::FromData;
use databend_common_pipeline_core::processors::connect;
use databend_common_pipeline_core::processors::Event;
use databend_common_pipeline_core::processors::EventCause;
use databend_common_pipeline_core::processors::InputPort;
use databend_common_pipeline_core::processors::OutputPort;
use databend_common_pipeline_core::processors::Processor;
Expand Down Expand Up @@ -51,8 +52,19 @@ async fn test_shuffle_output_finish() -> Result<()> {
downstream_input1.finish();
downstream_input2.finish();

assert!(matches!(processor.event()?, Event::Finished));
assert!(input1.is_finished() && input2.is_finished());
assert!(matches!(
processor.event_with_cause(EventCause::Output(0))?,
Event::NeedConsume
));
assert!(input1.is_finished());
assert!(!input2.is_finished());

assert!(matches!(
processor.event_with_cause(EventCause::Output(1))?,
Event::Finished
));
assert!(input1.is_finished());
assert!(input2.is_finished());

Ok(())
}
Expand Down Expand Up @@ -122,17 +134,159 @@ async fn test_shuffle_processor() -> Result<()> {
upstream_output3.push_data(Ok(block3));
upstream_output4.push_data(Ok(block4));

assert!(matches!(processor.event()?, Event::NeedData));
// 0 input and 0 output
assert!(matches!(
processor.event_with_cause(EventCause::Output(0))?,
Event::NeedConsume
));

assert!(downstream_input1.has_data());
assert!(
!downstream_input2.has_data()
&& !downstream_input3.has_data()
&& !downstream_input4.has_data()
);
assert!(
!upstream_output1.can_push()
&& !upstream_output2.can_push()
&& !upstream_output3.can_push()
&& !upstream_output4.can_push()
);

let block = downstream_input1.pull_data().unwrap()?;
downstream_input1.set_need_data();
assert!(block.columns()[0].value.as_column().unwrap().eq(&col1));
assert!(matches!(
processor.event_with_cause(EventCause::Output(0))?,
Event::NeedData
));

assert!(upstream_output1.can_push());
assert!(
!upstream_output2.can_push()
&& !upstream_output3.can_push()
&& !upstream_output4.can_push()
);
assert!(
!downstream_input1.has_data()
&& !downstream_input2.has_data()
&& !downstream_input3.has_data()
&& !downstream_input4.has_data()
);

// 2 input and 1 output
assert!(matches!(
processor.event_with_cause(EventCause::Output(1))?,
Event::NeedConsume
));

assert!(downstream_input2.has_data());
assert!(
!downstream_input1.has_data()
&& !downstream_input3.has_data()
&& !downstream_input4.has_data()
);
assert!(
upstream_output1.can_push()
&& !upstream_output2.can_push()
&& !upstream_output3.can_push()
&& !upstream_output4.can_push()
);

let block = downstream_input2.pull_data().unwrap()?;
downstream_input2.set_need_data();
assert!(block.columns()[0].value.as_column().unwrap().eq(&col3));
assert!(matches!(
processor.event_with_cause(EventCause::Output(1))?,
Event::NeedData
));

assert!(upstream_output3.can_push());
assert!(
upstream_output1.can_push() && !upstream_output2.can_push() && !upstream_output4.can_push()
);
assert!(
!downstream_input1.has_data()
&& !downstream_input2.has_data()
&& !downstream_input3.has_data()
&& !downstream_input4.has_data()
);

// 1 input and 2 output
assert!(matches!(
processor.event_with_cause(EventCause::Output(2))?,
Event::NeedConsume
));

assert!(downstream_input3.has_data());
assert!(
!downstream_input1.has_data()
&& !downstream_input2.has_data()
&& !downstream_input4.has_data()
);
assert!(
upstream_output1.can_push()
&& !upstream_output2.can_push()
&& upstream_output3.can_push()
&& !upstream_output4.can_push()
);

let out1 = downstream_input1.pull_data().unwrap()?;
let out2 = downstream_input2.pull_data().unwrap()?;
let out3 = downstream_input3.pull_data().unwrap()?;
let out4 = downstream_input4.pull_data().unwrap()?;
let block = downstream_input3.pull_data().unwrap()?;
downstream_input3.set_need_data();
assert!(block.columns()[0].value.as_column().unwrap().eq(&col2));
assert!(matches!(
processor.event_with_cause(EventCause::Output(2))?,
Event::NeedData
));

assert!(upstream_output2.can_push());
assert!(
upstream_output1.can_push() && upstream_output3.can_push() && !upstream_output4.can_push()
);
assert!(
!downstream_input1.has_data()
&& !downstream_input2.has_data()
&& !downstream_input3.has_data()
&& !downstream_input4.has_data()
);

assert!(out1.columns()[0].value.as_column().unwrap().eq(&col1));
assert!(out2.columns()[0].value.as_column().unwrap().eq(&col3));
assert!(out3.columns()[0].value.as_column().unwrap().eq(&col2));
assert!(out4.columns()[0].value.as_column().unwrap().eq(&col4));
// 3 input and 3 output
assert!(matches!(
processor.event_with_cause(EventCause::Output(3))?,
Event::NeedConsume
));

assert!(downstream_input4.has_data());
assert!(
!downstream_input1.has_data()
&& !downstream_input2.has_data()
&& !downstream_input3.has_data()
);
assert!(
upstream_output1.can_push()
&& upstream_output2.can_push()
&& upstream_output3.can_push()
&& !upstream_output4.can_push()
);

let block = downstream_input4.pull_data().unwrap()?;
downstream_input4.set_need_data();
assert!(block.columns()[0].value.as_column().unwrap().eq(&col4));
assert!(matches!(
processor.event_with_cause(EventCause::Output(3))?,
Event::NeedData
));

assert!(upstream_output4.can_push());
assert!(
upstream_output1.can_push() && upstream_output3.can_push() && upstream_output2.can_push()
);
assert!(
!downstream_input1.has_data()
&& !downstream_input2.has_data()
&& !downstream_input3.has_data()
&& !downstream_input4.has_data()
);

Ok(())
}

0 comments on commit 8cff049

Please sign in to comment.