Skip to content

Commit

Permalink
CrossJoin Refactor (#9830)
Browse files Browse the repository at this point in the history
* First iteration

* Wrap the logic inside function

* Send batches in the size of left batches

* Update cross_join.rs

* fuzz tests

* Update cross_join_fuzz.rs

* Update cross_join_fuzz.rs

* Test version 2

* Minor changes

* Minor changes

* Stateful implementation of CJ

* Adding comments

* Update cross_join_fuzz.rs

* Update cross_join.rs

* collect until batch size

* tmp

* revert changes

* Preserve the join strategy, clean the algorithm and states

* Update cross_join.rs

* Review

* Update cross_join.rs

---------

Co-authored-by: Mustafa Akur <[email protected]>
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
  • Loading branch information
3 people authored Apr 4, 2024
1 parent 2f55003 commit 4bd7c13
Showing 1 changed file with 95 additions and 47 deletions.
142 changes: 95 additions & 47 deletions datafusion/physical-plan/src/joins/cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,23 @@ use std::{any::Any, sync::Arc, task::Poll};

use super::utils::{
adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut,
StatefulStreamResult,
};
use crate::coalesce_batches::concat_batches;
use crate::coalesce_partitions::CoalescePartitionsExec;
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::ExecutionPlanProperties;
use crate::{
execution_mode_from_children, ColumnStatistics, DisplayAs, DisplayFormatType,
Distribution, ExecutionMode, ExecutionPlan, PlanProperties, RecordBatchStream,
execution_mode_from_children, handle_state, ColumnStatistics, DisplayAs,
DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan,
ExecutionPlanProperties, PlanProperties, RecordBatchStream,
SendableRecordBatchStream, Statistics,
};

use arrow::datatypes::{Fields, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_array::RecordBatchOptions;
use datafusion_common::stats::Precision;
use datafusion_common::{JoinType, Result, ScalarValue};
use datafusion_common::{internal_err, JoinType, Result, ScalarValue};
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
Expand Down Expand Up @@ -257,9 +258,10 @@ impl ExecutionPlan for CrossJoinExec {
schema: self.schema.clone(),
left_fut,
right: stream,
right_batch: Arc::new(parking_lot::Mutex::new(None)),
left_index: 0,
join_metrics,
state: CrossJoinStreamState::WaitBuildSide,
left_data: RecordBatch::new_empty(self.left().schema()),
}))
}

Expand Down Expand Up @@ -319,16 +321,18 @@ fn stats_cartesian_product(
struct CrossJoinStream {
/// Input schema
schema: Arc<Schema>,
/// future for data from left side
/// Future for data from left side
left_fut: OnceFut<JoinLeftData>,
/// right
/// Right side stream
right: SendableRecordBatchStream,
/// Current value on the left
left_index: usize,
/// Current batch being processed from the right side
right_batch: Arc<parking_lot::Mutex<Option<RecordBatch>>>,
/// join execution metrics
/// Join execution metrics
join_metrics: BuildProbeJoinMetrics,
/// State of the stream
state: CrossJoinStreamState,
/// Left data
left_data: RecordBatch,
}

impl RecordBatchStream for CrossJoinStream {
Expand All @@ -337,6 +341,25 @@ impl RecordBatchStream for CrossJoinStream {
}
}

/// Represents states of CrossJoinStream
///
/// The stream starts in `WaitBuildSide`, moves to `FetchProbeBatch` once the
/// left data has been stored, and then alternates between `FetchProbeBatch`
/// and `BuildBatches` for each batch read from the right side.
enum CrossJoinStreamState {
/// Initial state: the left (build) side data has not been collected yet.
WaitBuildSide,
/// The next batch must be polled from the right (probe) side stream.
FetchProbeBatch,
/// Holds the currently processed right side batch
BuildBatches(RecordBatch),
}

impl CrossJoinStreamState {
    /// Borrows the right side [`RecordBatch`] held by the `BuildBatches` state.
    ///
    /// Only a shared borrow of the state is needed because nothing is mutated
    /// here; callers holding `&mut` access can still call this through auto-ref.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the state is anything other than
    /// `BuildBatches`.
    fn try_as_record_batch(&self) -> Result<&RecordBatch> {
        // The match is exhaustive on purpose: adding a new state variant must
        // force a decision here instead of silently falling into the error arm.
        match self {
            CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
            CrossJoinStreamState::WaitBuildSide
            | CrossJoinStreamState::FetchProbeBatch => {
                internal_err!("Expected RecordBatch in BuildBatches state")
            }
        }
    }
}

fn build_batch(
left_index: usize,
batch: &RecordBatch,
Expand Down Expand Up @@ -384,58 +407,83 @@ impl CrossJoinStream {
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<RecordBatch>>> {
loop {
return match self.state {
CrossJoinStreamState::WaitBuildSide => {
handle_state!(ready!(self.collect_build_side(cx)))
}
CrossJoinStreamState::FetchProbeBatch => {
handle_state!(ready!(self.fetch_probe_batch(cx)))
}
CrossJoinStreamState::BuildBatches(_) => {
handle_state!(self.build_batches())
}
};
}
}

/// Collects build (left) side of the join into the state. In case of an empty build batch,
/// the execution terminates. Otherwise, the state is updated to fetch probe (right) batch.
fn collect_build_side(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let build_timer = self.join_metrics.build_time.timer();
let (left_data, _) = match ready!(self.left_fut.get(cx)) {
Ok(left_data) => left_data,
Err(e) => return Poll::Ready(Some(Err(e))),
Err(e) => return Poll::Ready(Err(e)),
};
build_timer.done();

if left_data.num_rows() == 0 {
return Poll::Ready(None);
}
let result = if left_data.num_rows() == 0 {
StatefulStreamResult::Ready(None)
} else {
self.left_data = left_data.clone();
self.state = CrossJoinStreamState::FetchProbeBatch;
StatefulStreamResult::Continue
};
Poll::Ready(Ok(result))
}

/// Polls the probe (right) stream for its next batch.
///
/// When a batch arrives, the input metrics are updated, the batch is saved in
/// the state, and the state switches to building result batches. An error from
/// the probe stream is forwarded as-is, and an exhausted probe stream ends the
/// join output.
fn fetch_probe_batch(
    &mut self,
    cx: &mut std::task::Context<'_>,
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
    // Every probe batch is paired with the left rows starting from the first one.
    self.left_index = 0;

    let outcome = match ready!(self.right.poll_next_unpin(cx)) {
        None => Ok(StatefulStreamResult::Ready(None)),
        Some(Err(e)) => Err(e),
        Some(Ok(probe_batch)) => {
            self.join_metrics.input_batches.add(1);
            self.join_metrics.input_rows.add(probe_batch.num_rows());

            self.state = CrossJoinStreamState::BuildBatches(probe_batch);
            Ok(StatefulStreamResult::Continue)
        }
    };
    Poll::Ready(outcome)
}

if self.left_index > 0 && self.left_index < left_data.num_rows() {
/// Joins the indexed row of left data with the current probe batch.
/// If all the results are produced, the state is set to fetch new probe batch.
fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
let right_batch = self.state.try_as_record_batch()?;
if self.left_index < self.left_data.num_rows() {
let join_timer = self.join_metrics.join_time.timer();
let right_batch = {
let right_batch = self.right_batch.lock();
right_batch.clone().unwrap()
};
let result =
build_batch(self.left_index, &right_batch, left_data, &self.schema);
self.join_metrics.input_rows.add(right_batch.num_rows());
build_batch(self.left_index, right_batch, &self.left_data, &self.schema);
join_timer.done();

if let Ok(ref batch) = result {
join_timer.done();
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
}
self.left_index += 1;
return Poll::Ready(Some(result));
result.map(|r| StatefulStreamResult::Ready(Some(r)))
} else {
self.state = CrossJoinStreamState::FetchProbeBatch;
Ok(StatefulStreamResult::Continue)
}
self.left_index = 0;
self.right
.poll_next_unpin(cx)
.map(|maybe_batch| match maybe_batch {
Some(Ok(batch)) => {
let join_timer = self.join_metrics.join_time.timer();
let result =
build_batch(self.left_index, &batch, left_data, &self.schema);
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
if let Ok(ref batch) = result {
join_timer.done();
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
}
self.left_index = 1;

let mut right_batch = self.right_batch.lock();
*right_batch = Some(batch);

Some(result)
}
other => other,
})
}
}

Expand Down

0 comments on commit 4bd7c13

Please sign in to comment.