From b3c94a9d89745baa8448eacc18b7bfe988dc6230 Mon Sep 17 00:00:00 2001 From: irenjj Date: Sat, 7 Dec 2024 17:04:42 +0800 Subject: [PATCH 1/6] feat: support `RightAnti` for `SortMergeJoin` --- .../src/joins/sort_merge_join.rs | 72 ++++++++++++++++++- datafusion/physical-plan/src/test.rs | 20 ++++++ 2 files changed, 90 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 43f698c24d05..ffb5c2eefa66 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1434,6 +1434,7 @@ impl SortMergeJoinStream { | JoinType::RightSemi | JoinType::Full | JoinType::LeftAnti + | JoinType::RightAnti | JoinType::LeftMark ) { join_streamed = !self.streamed_joined; @@ -1650,7 +1651,10 @@ impl SortMergeJoinStream { let right_indices: UInt64Array = chunk.buffered_indices.finish(); let mut right_columns = if matches!(self.join_type, JoinType::LeftMark) { vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef] - } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + } else if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightAnti + ) { vec![] } else if let Some(buffered_idx) = chunk.buffered_batch_idx { fetch_right_columns_by_idxs( @@ -2369,7 +2373,7 @@ mod tests { use crate::joins::utils::JoinOn; use crate::joins::SortMergeJoinExec; use crate::memory::MemoryExec; - use crate::test::build_table_i32; + use crate::test::{build_table_i32, build_table_i32_two_cols}; use crate::{common, ExecutionPlan}; fn build_table( @@ -2460,6 +2464,15 @@ mod tests { Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) } + pub fn build_table_two_cols( + a: (&str, &Vec), + b: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32_two_cols(a, b); + let schema = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + fn join( left: Arc, right: Arc, @@ -2910,6 +2923,61 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_right_anti() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = + build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6])); + let mut on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + let expected = [ + "+----+----+", + "| a2 | b1 |", + "+----+----+", + "| 30 | 6 |", + "+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + + let left2 = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right2 = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left2.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right2.schema())?) as _, + )]; + + let (_, batches2) = join_collect(left2, right2, on, RightAnti).await?; + let expected2 = [ + "+----+----+----+", + "| a2 | b1 | c2 |", + "+----+----+----+", + "| 30 | 6 | 90 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected2, &batches2); + + Ok(()) + } + #[tokio::test] async fn join_semi() -> Result<()> { let left = build_table( diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index 90ec9b106850..b7bbfd116954 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -88,6 +88,26 @@ pub fn build_table_i32( .unwrap() } +/// Returns record batch with 2 columns of i32 in memory +pub fn build_table_i32_two_cols( + a: (&str, &Vec), + b: (&str, &Vec), +) -> RecordBatch { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Int32, false), + Field::new(b.0, DataType::Int32, false), + ]); + + RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + ], + ) + .unwrap() +} + /// Returns memory table scan wrapped around record batch with 3 columns of i32 pub fn build_table_scan_i32( a: (&str, &Vec), From 7976c0d65e998526e306ea153cc47e4e320b020a Mon Sep 17 00:00:00 2001 From: irenjj Date: Mon, 9 Dec 2024 21:20:57 +0800 Subject: [PATCH 2/6] add filter --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 28 +++++++++++++++++-- .../src/joins/sort_merge_join.rs | 24 ++++++++++++---- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index cf1742a30e66..b331388f4f3f 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -210,7 +210,7 @@ async fn test_semi_join_1k_filtered() { } #[tokio::test] -async fn test_anti_join_1k() { +async fn test_left_anti_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), @@ -222,7 +222,7 @@ async fn test_anti_join_1k() { } #[tokio::test] -async fn test_anti_join_1k_filtered() { +async fn test_left_anti_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), @@ -233,6 +233,30 @@ async fn test_anti_join_1k_filtered() { .await } +#[tokio::test] +async fn test_right_anti_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_anti_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + #[tokio::test] async fn test_left_mark_join_1k() { JoinFuzzTestCase::new( diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index ffb5c2eefa66..f464ded1d34b 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -901,7 +901,7 @@ fn get_corrected_filter_mask( Some(corrected_mask.finish()) } - JoinType::LeftAnti => { + JoinType::LeftAnti | JoinType::RightAnti => { for i in 0..row_indices_length { let last_index = last_index_for_row(i, row_indices, batch_ids, row_indices_length); @@ -1013,6 +1013,7 @@ impl Stream for SortMergeJoinStream { | JoinType::LeftMark | JoinType::Right | JoinType::LeftAnti + | JoinType::RightAnti | JoinType::Full ) { @@ -1095,6 +1096,7 @@ impl Stream for SortMergeJoinStream { | JoinType::LeftSemi | JoinType::Right | JoinType::LeftAnti + | JoinType::RightAnti | JoinType::LeftMark | JoinType::Full ) @@ -1118,6 +1120,7 @@ impl Stream for SortMergeJoinStream { | JoinType::LeftSemi | JoinType::Right | JoinType::LeftAnti + | JoinType::RightAnti | JoinType::Full | JoinType::LeftMark ) @@ -1468,7 +1471,9 @@ impl SortMergeJoinStream { join_buffered = true; }; - if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() { + if matches!(self.join_type, JoinType::LeftAnti | JoinType::RightAnti) + && self.filter.is_some() + { join_streamed = !self.streamed_joined; join_buffered = join_streamed; } @@ -1678,7 +1683,10 @@ impl SortMergeJoinStream { if !matches!(self.join_type, JoinType::Right) { if matches!( self.join_type, - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark ) { let right_cols = fetch_right_columns_by_idxs( &self.buffered_data, @@ -1743,6 +1751,7 @@ impl SortMergeJoinStream { | JoinType::LeftSemi | JoinType::Right | JoinType::LeftAnti + | JoinType::RightAnti | JoinType::LeftMark | JoinType::Full ) { @@ -1826,6 +1835,7 @@ impl SortMergeJoinStream { | JoinType::LeftSemi | JoinType::Right | JoinType::LeftAnti + | JoinType::RightAnti | JoinType::LeftMark | JoinType::Full )) @@ -1924,6 +1934,10 @@ impl SortMergeJoinStream { let output_column_indices = (0..left_columns_length).collect::>(); filtered_record_batch = filtered_record_batch.project(&output_column_indices)?; + } else if matches!(self.join_type, JoinType::RightAnti) { + let output_column_indices = (0..right_columns_length).collect::>(); + filtered_record_batch = + filtered_record_batch.project(&output_column_indices)?; } else if matches!(self.join_type, JoinType::Full) && corrected_mask.false_count() > 0 { @@ -2893,7 +2907,7 @@ mod tests { } #[tokio::test] - async fn join_anti() -> Result<()> { + async fn join_left_anti() -> Result<()> { let left = build_table( ("a1", &vec![1, 2, 2, 3, 5]), ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right @@ -2932,7 +2946,7 @@ mod tests { ); let right = build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6])); - let mut on = vec![( + let on = vec![( Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; From a275521cbd2f2c701be7560d9c579ba4c82d36a0 Mon Sep 17 00:00:00 2001 From: irenjj Date: Wed, 11 Dec 2024 20:36:12 +0800 Subject: [PATCH 3/6] fix filter --- .../src/joins/sort_merge_join.rs | 325 +++++++++--------- 1 file changed, 166 insertions(+), 159 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index f464ded1d34b..f8a0122be17c 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1683,10 +1683,7 @@ impl SortMergeJoinStream { if !matches!(self.join_type, JoinType::Right) { if matches!( self.join_type, - JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftMark + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark ) { let right_cols = fetch_right_columns_by_idxs( &self.buffered_data, @@ -1695,6 +1692,14 @@ impl SortMergeJoinStream { )?; get_filter_column(&self.filter, &left_columns, &right_cols) + } else if matches!(self.join_type, JoinType::RightAnti) { + let right_cols = fetch_right_columns_by_idxs( + &self.buffered_data, + chunk.buffered_batch_idx.unwrap(), + &right_indices, + )?; + + get_filter_column(&self.filter, &right_cols, &left_columns) } else { get_filter_column(&self.filter, &left_columns, &right_columns) } @@ -4186,174 +4191,176 @@ mod tests { } #[tokio::test] - async fn test_left_anti_join_filtered_mask() -> Result<()> { - let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); - - let output = concat_batches(&schema, &joined_batches.batches)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); + async fn test_anti_join_filtered_mask() -> Result<()> { + for join_type in [LeftAnti, RightAnti] { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftAnti, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![true]), - 1 - ) - .unwrap(), - BooleanArray::from(vec![None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![Some(true)]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftAnti, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![false]), - 1 - ) - .unwrap(), - BooleanArray::from(vec![Some(true)]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + 2 + ) + .unwrap(), + BooleanArray::from(vec![None, None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0]), - &[0usize; 2], - &BooleanArray::from(vec![true, true]), - 2 - ) - .unwrap(), - BooleanArray::from(vec![None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, true, true]), - 3 - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, false, true]), - 3 - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, true]), - 3 - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, true, true]), - 3 - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true)]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, false]), - 3 + let corrected_mask = get_corrected_filter_mask( + join_type, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), ) - .unwrap(), - BooleanArray::from(vec![None, None, Some(true)]) - ); - - let corrected_mask = get_corrected_filter_mask( - LeftAnti, - &out_indices, - &joined_batches.batch_ids, - &out_mask, - output.num_rows(), - ) - .unwrap(); - - assert_eq!( - corrected_mask, - BooleanArray::from(vec![ - None, - None, - None, - None, - None, - Some(true), - None, - Some(true) - ]) - ); - - let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(true), + None, + Some(true) + ]) + ); - assert_batches_eq!( - &[ - "+---+----+---+----+", - "| a | b | x | y |", - "+---+----+---+----+", - "| 1 | 13 | 1 | 12 |", - "| 1 | 14 | 1 | 11 |", - "+---+----+---+----+", - ], - &[filtered_rb] - ); + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 13 | 1 | 12 |", + "| 1 | 14 | 1 | 11 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); - // output null rows - let null_mask = arrow::compute::not(&corrected_mask)?; - assert_eq!( - null_mask, - BooleanArray::from(vec![ - None, - None, - None, - None, - None, - Some(false), - None, - Some(false), - ]) - ); + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(false), + None, + Some(false), + ]) + ); - let null_joined_batch = filter_record_batch(&output, &null_mask)?; + let null_joined_batch = filter_record_batch(&output, &null_mask)?; - assert_batches_eq!( - &[ - "+---+---+---+---+", - "| a | b | x | y |", - "+---+---+---+---+", - "+---+---+---+---+", - ], - &[null_joined_batch] - ); + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); + } Ok(()) } From 04a74435c237df7b1ddc81b39d67190f3b415595 Mon Sep 17 00:00:00 2001 From: irenjj Date: Wed, 11 Dec 2024 22:48:44 +0800 Subject: [PATCH 4/6] add sqllogic test --- .../test_files/sort_merge_join.slt | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 9a20e7987ff6..1df52dd1eb3d 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -647,6 +647,54 @@ NULL NULL 7 9 NULL NULL 8 10 NULL NULL 9 11 +query II +select * from ( +with +t1 as ( + select 31 a, 32 b union all + select 31 a, 33 b +), +t2 as ( + select 31 a, 32 b union all + select 31 a, 35 b +) +select t2.* from t1 right anti join t2 on t1.a = t2.a and t1.b = t2.b +) order by 1, 2; +---- +31 35 + +query II +select * from ( +with +t1 as ( + select 41 a, 42 b union all + select 41 a, 43 b +), +t2 as ( + select 41 a, 42 b union all + select 41 a, 45 b +) +select t2.* from t1 right anti join t2 on t1.a = t2.a and t1.b = t2.b +) order by 1, 2; +---- +41 45 + +query II +select * from ( +with +t1 as ( + select 51 a, 52 b union all + select 51 a, 53 b +), +t2 as ( + select 51 a, 52 b union all + select 51 a, 54 b +) +select t2.* from t1 right anti join t2 on t1.a = t2.a and t1.b = t2.b +) order by 1, 2; +---- +51 54 + # return sql params back to default values statement ok set datafusion.optimizer.prefer_hash_join = true; From 2d71cbe87d09a8fdc399976b425d75ed73e58bbb Mon Sep 17 00:00:00 2001 From: irenjj Date: Fri, 13 Dec 2024 23:20:24 +0800 Subject: [PATCH 5/6] add unit tests --- .../src/joins/sort_merge_join.rs | 295 +++++++++++++++++- 1 file changed, 293 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index f8a0122be17c..378cee0a7b71 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -2378,6 +2378,7 @@ mod tests { use arrow_array::builder::{BooleanBuilder, UInt64Builder}; use arrow_array::{BooleanArray, UInt64Array}; + use datafusion_common::JoinSide; use datafusion_common::JoinType::*; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, @@ -2386,10 +2387,12 @@ mod tests { use datafusion_execution::disk_manager::DiskManagerConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::TaskContext; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::BinaryExpr; use crate::expressions::Column; use crate::joins::sort_merge_join::{get_corrected_filter_mask, JoinedRecordBatches}; - use crate::joins::utils::JoinOn; + use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn}; use crate::joins::SortMergeJoinExec; use crate::memory::MemoryExec; use crate::test::{build_table_i32, build_table_i32_two_cols}; @@ -2521,6 +2524,26 @@ mod tests { ) } + fn join_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: JoinType, + sort_options: Vec, + null_equals_null: bool, + ) -> Result { + SortMergeJoinExec::try_new( + left, + right, + on, + Some(filter), + join_type, + sort_options, + null_equals_null, + ) + } + async fn join_collect( left: Arc, right: Arc, @@ -2531,6 +2554,25 @@ mod tests { join_collect_with_options(left, right, on, join_type, sort_options, false).await } + async fn join_collect_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + let sort_options = vec![SortOptions::default(); on.len()]; + + let task_ctx = Arc::new(TaskContext::default()); + let join = + join_with_filter(left, right, on, filter, join_type, sort_options, false)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + async fn join_collect_with_options( left: Arc, right: Arc, @@ -2943,7 +2985,7 @@ mod tests { } #[tokio::test] - async fn join_right_anti() -> Result<()> { + async fn join_right_anti_one() -> Result<()> { let left = build_table( ("a1", &vec![1, 2, 2]), ("b1", &vec![4, 5, 5]), @@ -2997,6 +3039,255 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_right_anti_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = + build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6])); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + let expected = [ + "+----+----+", + "| a2 | b1 |", + "+----+----+", + "| 10 | 4 |", + "| 20 | 5 |", + "| 30 | 6 |", + "+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + let expected = [ + "+----+----+----+", + "| a2 | b1 | c2 |", + "+----+----+----+", + "| 10 | 4 | 70 |", + "| 20 | 5 | 80 |", + "| 30 | 6 | 90 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_two_with_filter() -> Result<()> { + let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30])); + let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20])); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c2", 1)), + Operator::Gt, + Arc::new(Column::new("c1", 0)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ]), + ); + let (_, batches) = + join_collect_with_filter(left, right, on, filter, RightAnti).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 10 | 20 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_with_nulls() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]), + ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field + ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 2 | | 8 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_with_nulls_with_options() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(1), Some(0), Some(2)]), + ("b1", &vec![Some(4), Some(5), Some(5), None, Some(5)]), + ("c1", &vec![Some(7), Some(8), Some(8), Some(60), None]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]), + ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field + ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect_with_options( + left, + right, + on, + RightAnti, + vec![ + SortOptions { + descending: true, + nulls_first: false, + }; + 2 + ], + true, + ) + .await?; + + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 3 | | 9 |", + "| 2 | 5 | |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_output_two_batches() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = + join_collect_batch_size_equals_two(left, right, on, LeftAnti).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 2); + assert_eq!(batches[1].num_rows(), 1); + assert_batches_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn join_semi() -> Result<()> { let left = build_table( From 0e16d57c4e90a64e7498ed56241176771a4ce59a Mon Sep 17 00:00:00 2001 From: irenjj Date: Mon, 30 Dec 2024 22:49:47 +0800 Subject: [PATCH 6/6] fix issues --- .../physical-plan/src/joins/sort_merge_join.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 378cee0a7b71..7c0613f95e3a 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1887,6 +1887,14 @@ impl SortMergeJoinStream { &out_mask }; + self.filter_record_batch_by_join_type(record_batch, corrected_mask) + } + + fn filter_record_batch_by_join_type( + &mut self, + record_batch: RecordBatch, + corrected_mask: &BooleanArray, + ) -> Result { let mut filtered_record_batch = filter_record_batch(&record_batch, corrected_mask)?; let left_columns_length = self.streamed_schema.fields.len(); @@ -2985,7 +2993,7 @@ mod tests { } #[tokio::test] - async fn join_right_anti_one() -> Result<()> { + async fn join_right_anti_one_one() -> Result<()> { let left = build_table( ("a1", &vec![1, 2, 2]), ("b1", &vec![4, 5, 5]), @@ -3040,7 +3048,7 @@ mod tests { } #[tokio::test] - async fn join_right_anti_two() -> Result<()> { + async fn join_right_anti_two_two() -> Result<()> { let left = build_table( ("a1", &vec![1, 2, 2]), ("b1", &vec![4, 5, 5]),