From 2d71cbe87d09a8fdc399976b425d75ed73e58bbb Mon Sep 17 00:00:00 2001 From: irenjj Date: Fri, 13 Dec 2024 23:20:24 +0800 Subject: [PATCH] add unit tests --- .../src/joins/sort_merge_join.rs | 295 +++++++++++++++++- 1 file changed, 293 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index f8a0122be17c..378cee0a7b71 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -2378,6 +2378,7 @@ mod tests { use arrow_array::builder::{BooleanBuilder, UInt64Builder}; use arrow_array::{BooleanArray, UInt64Array}; + use datafusion_common::JoinSide; use datafusion_common::JoinType::*; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, @@ -2386,10 +2387,12 @@ mod tests { use datafusion_execution::disk_manager::DiskManagerConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::TaskContext; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::BinaryExpr; use crate::expressions::Column; use crate::joins::sort_merge_join::{get_corrected_filter_mask, JoinedRecordBatches}; - use crate::joins::utils::JoinOn; + use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn}; use crate::joins::SortMergeJoinExec; use crate::memory::MemoryExec; use crate::test::{build_table_i32, build_table_i32_two_cols}; @@ -2521,6 +2524,26 @@ mod tests { ) } + fn join_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: JoinType, + sort_options: Vec, + null_equals_null: bool, + ) -> Result { + SortMergeJoinExec::try_new( + left, + right, + on, + Some(filter), + join_type, + sort_options, + null_equals_null, + ) + } + async fn join_collect( left: Arc, right: Arc, @@ -2531,6 +2554,25 @@ mod tests { join_collect_with_options(left, right, on, join_type, sort_options, false).await } + async fn join_collect_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + let sort_options = vec![SortOptions::default(); on.len()]; + + let task_ctx = Arc::new(TaskContext::default()); + let join = + join_with_filter(left, right, on, filter, join_type, sort_options, false)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + async fn join_collect_with_options( left: Arc, right: Arc, @@ -2943,7 +2985,7 @@ mod tests { } #[tokio::test] - async fn join_right_anti() -> Result<()> { + async fn join_right_anti_one() -> Result<()> { let left = build_table( ("a1", &vec![1, 2, 2]), ("b1", &vec![4, 5, 5]), @@ -2997,6 +3039,255 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_right_anti_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = + build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6])); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + let expected = [ + "+----+----+", + "| a2 | b1 |", + "+----+----+", + "| 10 | 4 |", + "| 20 | 5 |", + "| 30 | 6 |", + "+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + let expected = [ + "+----+----+----+", + "| a2 | b1 | c2 |", + "+----+----+----+", + "| 10 | 4 | 70 |", + "| 20 | 5 | 80 |", + "| 30 | 6 | 90 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_two_with_filter() -> Result<()> { + let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30])); + let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20])); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c2", 1)), + Operator::Gt, + Arc::new(Column::new("c1", 0)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ]), + ); + let (_, batches) = + join_collect_with_filter(left, right, on, filter, RightAnti).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 10 | 20 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_with_nulls() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]), + ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field + ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 2 | | 8 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_with_nulls_with_options() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(1), Some(0), Some(2)]), + ("b1", &vec![Some(4), Some(5), Some(5), None, Some(5)]), + ("c1", &vec![Some(7), Some(8), Some(8), Some(60), None]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]), + ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field + ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect_with_options( + left, + right, + on, + RightAnti, + vec![ + SortOptions { + descending: true, + nulls_first: false, + }; + 2 + ], + true, + ) + .await?; + + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 3 | | 9 |", + "| 2 | 5 | |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_output_two_batches() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = + join_collect_batch_size_equals_two(left, right, on, LeftAnti).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 2); + assert_eq!(batches[1].num_rows(), 1); + assert_batches_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn join_semi() -> Result<()> { let left = build_table(