Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
irenjj committed Dec 13, 2024
1 parent 04a7443 commit 2d71cbe
Showing 1 changed file with 293 additions and 2 deletions.
295 changes: 293 additions & 2 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2378,6 +2378,7 @@ mod tests {
use arrow_array::builder::{BooleanBuilder, UInt64Builder};
use arrow_array::{BooleanArray, UInt64Array};

use datafusion_common::JoinSide;
use datafusion_common::JoinType::*;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result,
Expand All @@ -2386,10 +2387,12 @@ mod tests {
use datafusion_execution::disk_manager::DiskManagerConfig;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_execution::TaskContext;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::BinaryExpr;

use crate::expressions::Column;
use crate::joins::sort_merge_join::{get_corrected_filter_mask, JoinedRecordBatches};
use crate::joins::utils::JoinOn;
use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn};
use crate::joins::SortMergeJoinExec;
use crate::memory::MemoryExec;
use crate::test::{build_table_i32, build_table_i32_two_cols};
Expand Down Expand Up @@ -2521,6 +2524,26 @@ mod tests {
)
}

fn join_with_filter(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
filter: JoinFilter,
join_type: JoinType,
sort_options: Vec<SortOptions>,
null_equals_null: bool,
) -> Result<SortMergeJoinExec> {
SortMergeJoinExec::try_new(
left,
right,
on,
Some(filter),
join_type,
sort_options,
null_equals_null,
)
}

async fn join_collect(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
Expand All @@ -2531,6 +2554,25 @@ mod tests {
join_collect_with_options(left, right, on, join_type, sort_options, false).await
}

async fn join_collect_with_filter(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
filter: JoinFilter,
join_type: JoinType,
) -> Result<(Vec<String>, Vec<RecordBatch>)> {
let sort_options = vec![SortOptions::default(); on.len()];

let task_ctx = Arc::new(TaskContext::default());
let join =
join_with_filter(left, right, on, filter, join_type, sort_options, false)?;
let columns = columns(&join.schema());

let stream = join.execute(0, task_ctx)?;
let batches = common::collect(stream).await?;
Ok((columns, batches))
}

async fn join_collect_with_options(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
Expand Down Expand Up @@ -2943,7 +2985,7 @@ mod tests {
}

#[tokio::test]
async fn join_right_anti() -> Result<()> {
async fn join_right_anti_one() -> Result<()> {
let left = build_table(
("a1", &vec![1, 2, 2]),
("b1", &vec![4, 5, 5]),
Expand Down Expand Up @@ -2997,6 +3039,255 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn join_right_anti_two() -> Result<()> {
let left = build_table(
("a1", &vec![1, 2, 2]),
("b1", &vec![4, 5, 5]),
("c1", &vec![7, 8, 8]),
);
let right =
build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]));
let on = vec![
(
Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
),
(
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
),
];

let (_, batches) = join_collect(left, right, on, RightAnti).await?;
let expected = [
"+----+----+",
"| a2 | b1 |",
"+----+----+",
"| 10 | 4 |",
"| 20 | 5 |",
"| 30 | 6 |",
"+----+----+",
];
// The output order is important as SMJ preserves sortedness
assert_batches_eq!(expected, &batches);

let left = build_table(
("a1", &vec![1, 2, 2]),
("b1", &vec![4, 5, 5]),
("c1", &vec![7, 8, 8]),
);
let right = build_table(
("a2", &vec![10, 20, 30]),
("b1", &vec![4, 5, 6]),
("c2", &vec![70, 80, 90]),
);

let on = vec![
(
Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
),
(
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
),
];

let (_, batches) = join_collect(left, right, on, RightAnti).await?;
let expected = [
"+----+----+----+",
"| a2 | b1 | c2 |",
"+----+----+----+",
"| 10 | 4 | 70 |",
"| 20 | 5 | 80 |",
"| 30 | 6 | 90 |",
"+----+----+----+",
];
// The output order is important as SMJ preserves sortedness
assert_batches_eq!(expected, &batches);

Ok(())
}

#[tokio::test]
async fn join_right_anti_two_with_filter() -> Result<()> {
let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30]));
let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20]));
let on = vec![
(
Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
),
(
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
),
];
let filter = JoinFilter::new(
Arc::new(BinaryExpr::new(
Arc::new(Column::new("c2", 1)),
Operator::Gt,
Arc::new(Column::new("c1", 0)),
)),
vec![
ColumnIndex {
index: 2,
side: JoinSide::Left,
},
ColumnIndex {
index: 2,
side: JoinSide::Right,
},
],
Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Int32, true),
]),
);
let (_, batches) =
join_collect_with_filter(left, right, on, filter, RightAnti).await?;
let expected = [
"+----+----+----+",
"| a1 | b1 | c2 |",
"+----+----+----+",
"| 1 | 10 | 20 |",
"+----+----+----+",
];
assert_batches_eq!(expected, &batches);
Ok(())
}

#[tokio::test]
async fn join_right_anti_with_nulls() -> Result<()> {
let left = build_table_i32_nullable(
("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]),
("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]),
("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]),
);
let right = build_table_i32_nullable(
("a1", &vec![Some(1), Some(2), Some(2), Some(3)]),
("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field
("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field
);
let on = vec![
(
Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
),
(
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
),
];

let (_, batches) = join_collect(left, right, on, RightAnti).await?;
let expected = [
"+----+----+----+",
"| a1 | b1 | c2 |",
"+----+----+----+",
"| 2 | | 8 |",
"+----+----+----+",
];
// The output order is important as SMJ preserves sortedness
assert_batches_eq!(expected, &batches);
Ok(())
}

#[tokio::test]
async fn join_right_anti_with_nulls_with_options() -> Result<()> {
let left = build_table_i32_nullable(
("a1", &vec![Some(1), Some(2), Some(1), Some(0), Some(2)]),
("b1", &vec![Some(4), Some(5), Some(5), None, Some(5)]),
("c1", &vec![Some(7), Some(8), Some(8), Some(60), None]),
);
let right = build_table_i32_nullable(
("a1", &vec![Some(3), Some(2), Some(2), Some(1)]),
("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field
("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field
);
let on = vec![
(
Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
),
(
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
),
];

let (_, batches) = join_collect_with_options(
left,
right,
on,
RightAnti,
vec![
SortOptions {
descending: true,
nulls_first: false,
};
2
],
true,
)
.await?;

let expected = [
"+----+----+----+",
"| a1 | b1 | c2 |",
"+----+----+----+",
"| 3 | | 9 |",
"| 2 | 5 | |",
"| 2 | 5 | 8 |",
"+----+----+----+",
];
// The output order is important as SMJ preserves sortedness
assert_batches_eq!(expected, &batches);
Ok(())
}

#[tokio::test]
async fn join_right_anti_output_two_batches() -> Result<()> {
let left = build_table(
("a1", &vec![1, 2, 2]),
("b1", &vec![4, 5, 5]),
("c1", &vec![7, 8, 8]),
);
let right = build_table(
("a2", &vec![10, 20, 30]),
("b1", &vec![4, 5, 6]),
("c2", &vec![70, 80, 90]),
);
let on = vec![
(
Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
),
(
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
),
];

let (_, batches) =
join_collect_batch_size_equals_two(left, right, on, LeftAnti).await?;
let expected = [
"+----+----+----+",
"| a1 | b1 | c1 |",
"+----+----+----+",
"| 1 | 4 | 7 |",
"| 2 | 5 | 8 |",
"| 2 | 5 | 8 |",
"+----+----+----+",
];
assert_eq!(batches.len(), 2);
assert_eq!(batches[0].num_rows(), 2);
assert_eq!(batches[1].num_rows(), 1);
assert_batches_eq!(expected, &batches);
Ok(())
}

#[tokio::test]
async fn join_semi() -> Result<()> {
let left = build_table(
Expand Down

0 comments on commit 2d71cbe

Please sign in to comment.