From f514e12ec4b73a4e6a417c5756152dd1ddceac10 Mon Sep 17 00:00:00 2001
From: Alihan Çelikcan
Date: Wed, 18 Sep 2024 16:17:51 +0300
Subject: [PATCH] Preserve the order of right table in NestedLoopJoinExec
 (#12504)

* Maintain right child's order in NestedLoopJoinExec

* Format

* Refactor monotonicity check

* Update sqllogictest according to new behavior

* Check output ordering properties

* Parameterize only batch sizes for left and right tables

* Document maintains_input_order
---
 .../src/joins/nested_loop_join.rs            | 260 +++++++++++++++++-
 datafusion/sqllogictest/test_files/join.slt  |  12 +-
 datafusion/sqllogictest/test_files/joins.slt |   2 +-
 3 files changed, 254 insertions(+), 20 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index c6f1833c13e0..b30e5184f0f7 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -221,7 +221,7 @@ impl NestedLoopJoinExec {
             right.equivalence_properties().clone(),
             &join_type,
             schema,
-            &[false, false],
+            &Self::maintains_input_order(join_type),
             None,
             // No on columns in nested loop join
             &[],
@@ -238,6 +238,31 @@ impl NestedLoopJoinExec {
         PlanProperties::new(eq_properties, output_partitioning, mode)
     }
+
+    /// Returns a vector indicating whether the left and right inputs maintain their order.
+    /// The first element corresponds to the left input, and the second to the right.
+    ///
+    /// The left (build-side) input's order may change, but the right (probe-side) input's
+    /// order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI joins.
+    ///
+    /// Maintaining the right input's order helps optimize the nodes down the pipeline
+    /// (See [`ExecutionPlan::maintains_input_order`]).
+    ///
+    /// This is a separate method because it is also called when computing properties, before
+    /// a [`NestedLoopJoinExec`] is created. It also takes [`JoinType`] as an argument, as
+    /// opposed to `Self`, for the same reason.
+    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
+        vec![
+            false,
+            matches!(
+                join_type,
+                JoinType::Inner
+                    | JoinType::Right
+                    | JoinType::RightAnti
+                    | JoinType::RightSemi
+            ),
+        ]
+    }
 }
 
 impl DisplayAs for NestedLoopJoinExec {
@@ -278,6 +303,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
         ]
     }
 
+    fn maintains_input_order(&self) -> Vec<bool> {
+        Self::maintains_input_order(self.join_type)
+    }
+
     fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
         vec![&self.left, &self.right]
     }
@@ -430,17 +459,17 @@ struct NestedLoopJoinStream {
 }
 
 fn build_join_indices(
-    left_row_index: usize,
-    right_batch: &RecordBatch,
+    right_row_index: usize,
     left_batch: &RecordBatch,
+    right_batch: &RecordBatch,
     filter: Option<&JoinFilter>,
 ) -> Result<(UInt64Array, UInt32Array)> {
-    // left indices: [left_index, left_index, ...., left_index]
-    // right indices: [0, 1, 2, 3, 4,....,right_row_count]
+    // left indices: [0, 1, 2, 3, 4, ..., left_row_count]
+    // right indices: [right_index, right_index, ..., right_index]
 
-    let right_row_count = right_batch.num_rows();
-    let left_indices = UInt64Array::from(vec![left_row_index as u64; right_row_count]);
-    let right_indices = UInt32Array::from_iter_values(0..(right_row_count as u32));
+    let left_row_count = left_batch.num_rows();
+    let left_indices = UInt64Array::from_iter_values(0..(left_row_count as u64));
+    let right_indices = UInt32Array::from(vec![right_row_index as u32; left_row_count]);
     // in the nested loop join, the filter can contain non-equal and equal condition.
     if let Some(filter) = filter {
         apply_join_filter_to_indices(
@@ -567,9 +596,9 @@ fn join_left_and_right_batch(
     schema: &Schema,
     visited_left_side: &SharedBitmapBuilder,
 ) -> Result<RecordBatch> {
-    let indices = (0..left_batch.num_rows())
-        .map(|left_row_index| {
-            build_join_indices(left_row_index, right_batch, left_batch, filter)
+    let indices = (0..right_batch.num_rows())
+        .map(|right_row_index| {
+            build_join_indices(right_row_index, left_batch, right_batch, filter)
         })
         .collect::<Result<Vec<_>>>()
         .map_err(|e| {
@@ -601,7 +630,7 @@
         right_side,
         0..right_batch.num_rows(),
         join_type,
-        false,
+        true,
     );
 
     build_batch_from_indices(
@@ -649,20 +678,59 @@ mod tests {
     };
     use arrow::datatypes::{DataType, Field};
+    use arrow_array::Int32Array;
+    use arrow_schema::SortOptions;
     use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue};
     use datafusion_execution::runtime_env::RuntimeEnvBuilder;
     use datafusion_expr::Operator;
     use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
     use datafusion_physical_expr::{Partitioning, PhysicalExpr};
+    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
+
+    use rstest::rstest;
 
     fn build_table(
         a: (&str, &Vec<i32>),
         b: (&str, &Vec<i32>),
         c: (&str, &Vec<i32>),
+        batch_size: Option<usize>,
+        sorted_column_names: Vec<&str>,
     ) -> Arc<dyn ExecutionPlan> {
         let batch = build_table_i32(a, b, c);
         let schema = batch.schema();
-        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
+
+        let batches = if let Some(batch_size) = batch_size {
+            let num_batches = batch.num_rows().div_ceil(batch_size);
+            (0..num_batches)
+                .map(|i| {
+                    let start = i * batch_size;
+                    let remaining_rows = batch.num_rows() - start;
+                    batch.slice(start, batch_size.min(remaining_rows))
+                })
+                .collect::<Vec<_>>()
+        } else {
+            vec![batch]
+        };
+
+        let mut exec =
+            MemoryExec::try_new(&[batches], Arc::clone(&schema), None).unwrap();
+        if !sorted_column_names.is_empty() {
+            let mut sort_info = Vec::new();
+            for name in sorted_column_names {
+                let index = schema.index_of(name).unwrap();
+                let sort_expr = PhysicalSortExpr {
+                    expr: Arc::new(Column::new(name, index)),
+                    options: SortOptions {
+                        descending: false,
+                        nulls_first: false,
+                    },
+                };
+                sort_info.push(sort_expr);
+            }
+            exec = exec.with_sort_information(vec![sort_info]);
+        }
+
+        Arc::new(exec)
     }
 
     fn build_left_table() -> Arc<dyn ExecutionPlan> {
         build_table(
             ("a1", &vec![5, 9, 11]),
             ("b1", &vec![5, 8, 8]),
             ("c1", &vec![50, 90, 110]),
+            None,
+            Vec::new(),
         )
     }
 
     fn build_right_table() -> Arc<dyn ExecutionPlan> {
         build_table(
             ("a2", &vec![12, 2, 10]),
             ("b2", &vec![10, 2, 10]),
             ("c2", &vec![40, 80, 100]),
+            None,
+            Vec::new(),
         )
     }
 
@@ -1005,11 +1077,15 @@
             ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
             ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
             ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            None,
+            Vec::new(),
         );
         let right = build_table(
             ("a2", &vec![10, 11]),
             ("b2", &vec![12, 13]),
             ("c2", &vec![14, 15]),
+            None,
+            Vec::new(),
         );
 
         let filter = prepare_join_filter();
@@ -1050,6 +1126,164 @@
         Ok(())
     }
 
+    fn prepare_mod_join_filter() -> JoinFilter {
+        let column_indices = vec![
+            ColumnIndex {
+                index: 1,
+                side: JoinSide::Left,
+            },
+            ColumnIndex {
+                index: 1,
+                side: JoinSide::Right,
+            },
+        ];
+        let intermediate_schema = Schema::new(vec![
+            Field::new("x", DataType::Int32, true),
+            Field::new("x", DataType::Int32, true),
+        ]);
+
+        // left.b1 % 3
+        let left_mod = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("x", 0)),
+            Operator::Modulo,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(3)))),
+        )) as Arc<dyn PhysicalExpr>;
+        // left.b1 % 3 != 0
+        let left_filter = Arc::new(BinaryExpr::new(
+            left_mod,
+            Operator::NotEq,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
+        )) as Arc<dyn PhysicalExpr>;
+
+        // right.b2 % 5
+        let right_mod = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("x", 1)),
+            Operator::Modulo,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
+        )) as Arc<dyn PhysicalExpr>;
+        // right.b2 % 5 != 0
+        let right_filter = Arc::new(BinaryExpr::new(
+            right_mod,
+            Operator::NotEq,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
+        )) as Arc<dyn PhysicalExpr>;
+        // filter = left.b1 % 3 != 0 and right.b2 % 5 != 0
+        let filter_expression =
+            Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
+                as Arc<dyn PhysicalExpr>;
+
+        JoinFilter::new(filter_expression, column_indices, intermediate_schema)
+    }
+
+    fn generate_columns(num_columns: usize, num_rows: usize) -> Vec<Vec<i32>> {
+        let column = (1..=num_rows).map(|x| x as i32).collect();
+        vec![column; num_columns]
+    }
+
+    #[rstest]
+    #[tokio::test]
+    async fn join_maintains_right_order(
+        #[values(
+            JoinType::Inner,
+            JoinType::Right,
+            JoinType::RightAnti,
+            JoinType::RightSemi
+        )]
+        join_type: JoinType,
+        #[values(1, 100, 1000)] left_batch_size: usize,
+        #[values(1, 100, 1000)] right_batch_size: usize,
+    ) -> Result<()> {
+        let left_columns = generate_columns(3, 1000);
+        let left = build_table(
+            ("a1", &left_columns[0]),
+            ("b1", &left_columns[1]),
+            ("c1", &left_columns[2]),
+            Some(left_batch_size),
+            Vec::new(),
+        );
+
+        let right_columns = generate_columns(3, 1000);
+        let right = build_table(
+            ("a2", &right_columns[0]),
+            ("b2", &right_columns[1]),
+            ("c2", &right_columns[2]),
+            Some(right_batch_size),
+            vec!["a2", "b2", "c2"],
+        );
+
+        let filter = prepare_mod_join_filter();
+
+        let nested_loop_join = Arc::new(NestedLoopJoinExec::try_new(
+            left,
+            Arc::clone(&right),
+            Some(filter),
+            &join_type,
+        )?) as Arc<dyn ExecutionPlan>;
+        assert_eq!(nested_loop_join.maintains_input_order(), vec![false, true]);
+
+        let right_column_indices = match join_type {
+            JoinType::Inner | JoinType::Right => vec![3, 4, 5],
+            JoinType::RightAnti | JoinType::RightSemi => vec![0, 1, 2],
+            _ => unreachable!(),
+        };
+
+        let right_ordering = right.output_ordering().unwrap();
+        let join_ordering = nested_loop_join.output_ordering().unwrap();
+        for (right, join) in right_ordering.iter().zip(join_ordering.iter()) {
+            let right_column = right.expr.as_any().downcast_ref::<Column>().unwrap();
+            let join_column = join.expr.as_any().downcast_ref::<Column>().unwrap();
+            assert_eq!(right_column.name(), join_column.name());
+            assert_eq!(
+                right_column_indices[right_column.index()],
+                join_column.index()
+            );
+            assert_eq!(right.options, join.options);
+        }
+
+        let batches = nested_loop_join
+            .execute(0, Arc::new(TaskContext::default()))?
+            .try_collect::<Vec<_>>()
+            .await?;
+
+        // Make sure that the order of the right side is maintained
+        let mut prev_values = [i32::MIN, i32::MIN, i32::MIN];
+
+        for (batch_index, batch) in batches.iter().enumerate() {
+            let columns: Vec<_> = right_column_indices
+                .iter()
+                .map(|&i| {
+                    batch
+                        .column(i)
+                        .as_any()
+                        .downcast_ref::<Int32Array>()
+                        .unwrap()
+                })
+                .collect();
+
+            for row in 0..batch.num_rows() {
+                let current_values = [
+                    columns[0].value(row),
+                    columns[1].value(row),
+                    columns[2].value(row),
+                ];
+                assert!(
+                    current_values
+                        .into_iter()
+                        .zip(prev_values)
+                        .all(|(current, prev)| current >= prev),
+                    "batch_index: {} row: {} current: {:?}, prev: {:?}",
+                    batch_index,
+                    row,
+                    current_values,
+                    prev_values
+                );
+                prev_values = current_values;
+            }
+        }
+
+        Ok(())
+    }
+
     /// Returns the column names on the schema
     fn columns(schema: &Schema) -> Vec<String> {
         schema.fields().iter().map(|f| f.name().clone()).collect()
diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt
index 3e7a08981eac..2f505c9fc71c 100644
--- a/datafusion/sqllogictest/test_files/join.slt
+++ b/datafusion/sqllogictest/test_files/join.slt
@@ -838,10 +838,10 @@ LEFT JOIN department AS d
 ON (e.name = 'Alice' OR e.name = 'Bob');
 ----
 1 Alice HR
-2 Bob HR
 1 Alice Engineering
-2 Bob Engineering
 1 Alice Sales
+2 Bob HR
+2 Bob Engineering
 2 Bob Sales
 3 Carol NULL
 
@@ -853,10 +853,10 @@ RIGHT JOIN employees AS e
 ON (e.name = 'Alice' OR e.name = 'Bob');
 ----
 1 Alice HR
-2 Bob HR
 1 Alice Engineering
-2 Bob Engineering
 1 Alice Sales
+2 Bob HR
+2 Bob Engineering
 2 Bob Sales
 3 Carol NULL
 
@@ -868,10 +868,10 @@ FULL JOIN employees AS e
 ON (e.name = 'Alice' OR e.name = 'Bob');
 ----
 1 Alice HR
-2 Bob HR
 1 Alice Engineering
-2 Bob Engineering
 1 Alice Sales
+2 Bob HR
+2 Bob Engineering
 2 Bob Sales
 3 Carol NULL
 
diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt
index 7d0262952b31..679c2eee10a4 100644
--- a/datafusion/sqllogictest/test_files/joins.slt
+++ b/datafusion/sqllogictest/test_files/joins.slt
@@ -2136,10 +2136,10 @@ FROM (select t1_id from join_t1 where join_t1.t1_id > 22) as join_t1
 RIGHT JOIN (select t2_id from join_t2 where join_t2.t2_id > 11) as join_t2
 ON join_t1.t1_id < join_t2.t2_id
 ----
+NULL 22
 33 44
 33 55
 44 55
-NULL 22
 
 #####
 # Configuration teardown