Skip to content

Commit

Permalink
Preserve the order of right table in NestedLoopJoinExec (#12504)
Browse files Browse the repository at this point in the history
* Maintain right child's order in NestedLoopJoinExec

* Format

* Refactor monotonicity check

* Update sqllogictest according to new behavior

* Check output ordering properties

* Parameterize only batch sizes for left and right tables

* Document maintains_input_order
  • Loading branch information
alihan-synnada authored Sep 18, 2024
1 parent 9781aef commit f514e12
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 20 deletions.
260 changes: 247 additions & 13 deletions datafusion/physical-plan/src/joins/nested_loop_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ impl NestedLoopJoinExec {
right.equivalence_properties().clone(),
&join_type,
schema,
&[false, false],
&Self::maintains_input_order(join_type),
None,
// No on columns in nested loop join
&[],
Expand All @@ -238,6 +238,31 @@ impl NestedLoopJoinExec {

PlanProperties::new(eq_properties, output_partitioning, mode)
}

/// Reports, per child, whether this join preserves that child's row order.
/// Index 0 corresponds to the left input, index 1 to the right input.
///
/// The left (build-side) input's order may change, but the right (probe-side)
/// input's order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI
/// joins. Maintaining the right input's order helps optimize the nodes down
/// the pipeline (see [`ExecutionPlan::maintains_input_order`]).
///
/// This is an associated function (taking [`JoinType`] rather than `&self`)
/// because it is also needed while computing plan properties, i.e. before a
/// [`NestedLoopJoinExec`] instance exists.
fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
    // Only the probe side can stream through unchanged, and only for join
    // types whose output rows follow the right input's scan order.
    let right_preserved = matches!(
        join_type,
        JoinType::Inner | JoinType::Right | JoinType::RightAnti | JoinType::RightSemi
    );
    vec![false, right_preserved]
}
}

impl DisplayAs for NestedLoopJoinExec {
Expand Down Expand Up @@ -278,6 +303,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
]
}

fn maintains_input_order(&self) -> Vec<bool> {
Self::maintains_input_order(self.join_type)
}

fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.left, &self.right]
}
Expand Down Expand Up @@ -430,17 +459,17 @@ struct NestedLoopJoinStream {
}

fn build_join_indices(
left_row_index: usize,
right_batch: &RecordBatch,
right_row_index: usize,
left_batch: &RecordBatch,
right_batch: &RecordBatch,
filter: Option<&JoinFilter>,
) -> Result<(UInt64Array, UInt32Array)> {
// left indices: [left_index, left_index, ...., left_index]
// right indices: [0, 1, 2, 3, 4,....,right_row_count]
// left indices: [0, 1, 2, 3, 4, ..., left_row_count]
// right indices: [right_index, right_index, ..., right_index]

let right_row_count = right_batch.num_rows();
let left_indices = UInt64Array::from(vec![left_row_index as u64; right_row_count]);
let right_indices = UInt32Array::from_iter_values(0..(right_row_count as u32));
let left_row_count = left_batch.num_rows();
let left_indices = UInt64Array::from_iter_values(0..(left_row_count as u64));
let right_indices = UInt32Array::from(vec![right_row_index as u32; left_row_count]);
// in the nested loop join, the filter can contain non-equal and equal condition.
if let Some(filter) = filter {
apply_join_filter_to_indices(
Expand Down Expand Up @@ -567,9 +596,9 @@ fn join_left_and_right_batch(
schema: &Schema,
visited_left_side: &SharedBitmapBuilder,
) -> Result<RecordBatch> {
let indices = (0..left_batch.num_rows())
.map(|left_row_index| {
build_join_indices(left_row_index, right_batch, left_batch, filter)
let indices = (0..right_batch.num_rows())
.map(|right_row_index| {
build_join_indices(right_row_index, left_batch, right_batch, filter)
})
.collect::<Result<Vec<(UInt64Array, UInt32Array)>>>()
.map_err(|e| {
Expand Down Expand Up @@ -601,7 +630,7 @@ fn join_left_and_right_batch(
right_side,
0..right_batch.num_rows(),
join_type,
false,
true,
);

build_batch_from_indices(
Expand Down Expand Up @@ -649,27 +678,68 @@ mod tests {
};

use arrow::datatypes::{DataType, Field};
use arrow_array::Int32Array;
use arrow_schema::SortOptions;
use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue};
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
use datafusion_physical_expr::{Partitioning, PhysicalExpr};
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;

use rstest::rstest;

/// Builds a `MemoryExec` over three `Int32` columns.
///
/// When `batch_size` is `Some(n)`, the generated batch is sliced into
/// consecutive chunks of at most `n` rows; otherwise a single batch is used.
/// A non-empty `sorted_column_names` attaches ascending, nulls-last sort
/// information for those columns to the exec.
fn build_table(
    a: (&str, &Vec<i32>),
    b: (&str, &Vec<i32>),
    c: (&str, &Vec<i32>),
    batch_size: Option<usize>,
    sorted_column_names: Vec<&str>,
) -> Arc<dyn ExecutionPlan> {
    let batch = build_table_i32(a, b, c);
    let schema = batch.schema();

    // Split the single batch into fixed-size slices when requested.
    let batches = match batch_size {
        Some(size) => {
            let total_rows = batch.num_rows();
            let mut chunks = Vec::new();
            let mut start = 0;
            while start < total_rows {
                let len = size.min(total_rows - start);
                chunks.push(batch.slice(start, len));
                start += len;
            }
            chunks
        }
        None => vec![batch],
    };

    let mut exec =
        MemoryExec::try_new(&[batches], Arc::clone(&schema), None).unwrap();
    if !sorted_column_names.is_empty() {
        let sort_info = sorted_column_names
            .into_iter()
            .map(|name| {
                let index = schema.index_of(name).unwrap();
                PhysicalSortExpr {
                    expr: Arc::new(Column::new(name, index)),
                    options: SortOptions {
                        descending: false,
                        nulls_first: false,
                    },
                }
            })
            .collect::<Vec<_>>();
        exec = exec.with_sort_information(vec![sort_info]);
    }

    Arc::new(exec)
}

/// Three-row left-side fixture (columns `a1`, `b1`, `c1`): one unsorted,
/// unsliced batch.
fn build_left_table() -> Arc<dyn ExecutionPlan> {
    let a1 = vec![5, 9, 11];
    let b1 = vec![5, 8, 8];
    let c1 = vec![50, 90, 110];
    build_table(("a1", &a1), ("b1", &b1), ("c1", &c1), None, Vec::new())
}

Expand All @@ -678,6 +748,8 @@ mod tests {
("a2", &vec![12, 2, 10]),
("b2", &vec![10, 2, 10]),
("c2", &vec![40, 80, 100]),
None,
Vec::new(),
)
}

Expand Down Expand Up @@ -1005,11 +1077,15 @@ mod tests {
("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
None,
Vec::new(),
);
let right = build_table(
("a2", &vec![10, 11]),
("b2", &vec![12, 13]),
("c2", &vec![14, 15]),
None,
Vec::new(),
);
let filter = prepare_join_filter();

Expand Down Expand Up @@ -1050,6 +1126,164 @@ mod tests {
Ok(())
}

/// Builds the join filter `left.b1 % 3 != 0 AND right.b2 % 5 != 0`.
///
/// The intermediate schema exposes the two `b` columns (left first, right
/// second), both named `x`, which the filter expression references by index.
fn prepare_mod_join_filter() -> JoinFilter {
    let column_indices = vec![
        ColumnIndex {
            index: 1,
            side: JoinSide::Left,
        },
        ColumnIndex {
            index: 1,
            side: JoinSide::Right,
        },
    ];
    let intermediate_schema = Schema::new(vec![
        Field::new("x", DataType::Int32, true),
        Field::new("x", DataType::Int32, true),
    ]);

    // Builds `column(col_index) % modulus != 0` over the intermediate schema.
    let mod_ne_zero = |col_index: usize, modulus: i32| -> Arc<dyn PhysicalExpr> {
        let remainder = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("x", col_index)),
            Operator::Modulo,
            Arc::new(Literal::new(ScalarValue::Int32(Some(modulus)))),
        )) as Arc<dyn PhysicalExpr>;
        Arc::new(BinaryExpr::new(
            remainder,
            Operator::NotEq,
            Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
        )) as Arc<dyn PhysicalExpr>
    };

    // left.b1 % 3 != 0
    let left_filter = mod_ne_zero(0, 3);
    // right.b2 % 5 != 0
    let right_filter = mod_ne_zero(1, 5);

    // filter = left.b1 % 3 != 0 and right.b2 % 5 != 0
    let filter_expression =
        Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
            as Arc<dyn PhysicalExpr>;

    JoinFilter::new(filter_expression, column_indices, intermediate_schema)
}

/// Produces `num_columns` identical columns, each holding `1..=num_rows`
/// as `i32` values.
fn generate_columns(num_columns: usize, num_rows: usize) -> Vec<Vec<i32>> {
    let template: Vec<i32> = (1..=num_rows).map(|row| row as i32).collect();
    (0..num_columns).map(|_| template.clone()).collect()
}

/// Verifies that `NestedLoopJoinExec` preserves the right (probe-side)
/// input's ordering for the order-maintaining join types, across several
/// batch-size combinations for both inputs.
///
/// The left table is unsorted; the right table is sorted ascending on all
/// of its columns. The test checks (a) the reported
/// `maintains_input_order` / output-ordering metadata and (b) that the
/// executed output rows are non-decreasing in the right-side columns.
#[rstest]
#[tokio::test]
async fn join_maintains_right_order(
    #[values(
        JoinType::Inner,
        JoinType::Right,
        JoinType::RightAnti,
        JoinType::RightSemi
    )]
    join_type: JoinType,
    #[values(1, 100, 1000)] left_batch_size: usize,
    #[values(1, 100, 1000)] right_batch_size: usize,
) -> Result<()> {
    let left_columns = generate_columns(3, 1000);
    let left = build_table(
        ("a1", &left_columns[0]),
        ("b1", &left_columns[1]),
        ("c1", &left_columns[2]),
        Some(left_batch_size),
        Vec::new(),
    );

    let right_columns = generate_columns(3, 1000);
    let right = build_table(
        ("a2", &right_columns[0]),
        ("b2", &right_columns[1]),
        ("c2", &right_columns[2]),
        Some(right_batch_size),
        vec!["a2", "b2", "c2"],
    );

    let filter = prepare_mod_join_filter();

    let nested_loop_join = Arc::new(NestedLoopJoinExec::try_new(
        left,
        Arc::clone(&right),
        Some(filter),
        &join_type,
    )?) as Arc<dyn ExecutionPlan>;
    assert_eq!(nested_loop_join.maintains_input_order(), vec![false, true]);

    // Where the right table's columns land in the join output: after the
    // three left columns for INNER/RIGHT, or at the front for the semi/anti
    // joins (whose output contains only right-side columns).
    let right_column_indices = match join_type {
        JoinType::Inner | JoinType::Right => vec![3, 4, 5],
        JoinType::RightAnti | JoinType::RightSemi => vec![0, 1, 2],
        _ => unreachable!(),
    };

    // The join's reported output ordering must mirror the right input's
    // ordering, with column indices remapped into the output schema.
    let right_ordering = right.output_ordering().unwrap();
    let join_ordering = nested_loop_join.output_ordering().unwrap();
    for (right, join) in right_ordering.iter().zip(join_ordering.iter()) {
        let right_column = right.expr.as_any().downcast_ref::<Column>().unwrap();
        let join_column = join.expr.as_any().downcast_ref::<Column>().unwrap();
        // Bug fix: previously this compared `join_column.name()` to itself,
        // which is vacuously true; compare against the right input's column.
        assert_eq!(right_column.name(), join_column.name());
        assert_eq!(
            right_column_indices[right_column.index()],
            join_column.index()
        );
        assert_eq!(right.options, join.options);
    }

    let batches = nested_loop_join
        .execute(0, Arc::new(TaskContext::default()))?
        .try_collect::<Vec<_>>()
        .await?;

    // Make sure that the order of the right side is maintained across the
    // whole output stream, not just within each batch.
    let mut prev_values = [i32::MIN, i32::MIN, i32::MIN];

    for (batch_index, batch) in batches.iter().enumerate() {
        let columns: Vec<_> = right_column_indices
            .iter()
            .map(|&i| {
                batch
                    .column(i)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap()
            })
            .collect();

        for row in 0..batch.num_rows() {
            let current_values = [
                columns[0].value(row),
                columns[1].value(row),
                columns[2].value(row),
            ];
            assert!(
                current_values
                    .into_iter()
                    .zip(prev_values)
                    .all(|(current, prev)| current >= prev),
                "batch_index: {} row: {} current: {:?}, prev: {:?}",
                batch_index,
                row,
                current_values,
                prev_values
            );
            prev_values = current_values;
        }
    }

    Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
Expand Down
12 changes: 6 additions & 6 deletions datafusion/sqllogictest/test_files/join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -838,10 +838,10 @@ LEFT JOIN department AS d
ON (e.name = 'Alice' OR e.name = 'Bob');
----
1 Alice HR
2 Bob HR
1 Alice Engineering
2 Bob Engineering
1 Alice Sales
2 Bob HR
2 Bob Engineering
2 Bob Sales
3 Carol NULL

Expand All @@ -853,10 +853,10 @@ RIGHT JOIN employees AS e
ON (e.name = 'Alice' OR e.name = 'Bob');
----
1 Alice HR
2 Bob HR
1 Alice Engineering
2 Bob Engineering
1 Alice Sales
2 Bob HR
2 Bob Engineering
2 Bob Sales
3 Carol NULL

Expand All @@ -868,10 +868,10 @@ FULL JOIN employees AS e
ON (e.name = 'Alice' OR e.name = 'Bob');
----
1 Alice HR
2 Bob HR
1 Alice Engineering
2 Bob Engineering
1 Alice Sales
2 Bob HR
2 Bob Engineering
2 Bob Sales
3 Carol NULL

Expand Down
Loading

0 comments on commit f514e12

Please sign in to comment.