Skip to content

Commit

Permalink
Fix bug in swap_hash_join (#278)
Browse files Browse the repository at this point in the history
* Try and fix swap_hash_join

* Only swap projections when join does not have projections

* just backport upstream fix

* remove println
  • Loading branch information
thinkharderdev authored Nov 22, 2024
1 parent 6c4432f commit 47348a1
Showing 1 changed file with 88 additions and 16 deletions.
104 changes: 88 additions & 16 deletions datafusion/core/src/physical_optimizer/join_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
//! pipeline-friendly ones. To achieve the second goal, it selects the proper
//! `PartitionMode` and the build side using the available statistics for hash joins.
use std::sync::Arc;

use crate::config::ConfigOptions;
use crate::error::Result;
use crate::physical_optimizer::PhysicalOptimizerRule;
Expand All @@ -35,6 +33,7 @@ use crate::physical_plan::joins::{
};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use std::sync::Arc;

use arrow_schema::Schema;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
Expand Down Expand Up @@ -140,20 +139,32 @@ fn swap_join_projection(
left_schema_len: usize,
right_schema_len: usize,
projection: Option<&Vec<usize>>,
join_type: &JoinType,
) -> Option<Vec<usize>> {
projection.map(|p| {
p.iter()
.map(|i| {
// If the index is less than the left schema length, it is from the left schema, so we add the right schema length to it.
// Otherwise, it is from the right schema, so we subtract the left schema length from it.
if *i < left_schema_len {
*i + right_schema_len
} else {
*i - left_schema_len
}
})
.collect()
})
match join_type {
// For Anti/Semi join types, the projection should remain unmodified,
// since the output schema of these joins remains the same after the swap
JoinType::LeftAnti
| JoinType::LeftSemi
| JoinType::RightAnti
| JoinType::RightSemi => projection.cloned(),

_ => projection.map(|p| {
p.iter()
.map(|i| {
// If the index is less than the left schema length, it is from
// the left schema, so we add the right schema length to it.
// Otherwise, it is from the right schema, so we subtract the left
// schema length from it.
if *i < left_schema_len {
*i + right_schema_len
} else {
*i - left_schema_len
}
})
.collect()
}),
}
}

/// This function swaps the inputs of the given join operator.
Expand All @@ -179,6 +190,7 @@ pub fn swap_hash_join(
left.schema().fields().len(),
right.schema().fields().len(),
hash_join.projection.as_ref(),
hash_join.join_type(),
),
partition_mode,
hash_join.null_equals_null(),
Expand All @@ -189,7 +201,8 @@ pub fn swap_hash_join(
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
) {
) || hash_join.projection.is_some()
{
Ok(Arc::new(new_join))
} else {
// TODO avoid adding ProjectionExec again and again, only adding Final Projection
Expand Down Expand Up @@ -1158,6 +1171,65 @@ mod tests_statistical {
);
}

#[rstest(
join_type, projection, small_on_right,
case::inner(JoinType::Inner, vec![1], true),
case::left(JoinType::Left, vec![1], true),
case::right(JoinType::Right, vec![1], true),
case::full(JoinType::Full, vec![1], true),
case::left_anti(JoinType::LeftAnti, vec![0], false),
case::left_semi(JoinType::LeftSemi, vec![0], false),
case::right_anti(JoinType::RightAnti, vec![0], true),
case::right_semi(JoinType::RightSemi, vec![0], true),
)]
#[tokio::test]
async fn test_hash_join_swap_on_joins_with_projections(
join_type: JoinType,
projection: Vec<usize>,
small_on_right: bool,
) -> Result<()> {
let (big, small) = create_big_and_small();

let left = if small_on_right { &big } else { &small };
let right = if small_on_right { &small } else { &big };

let left_on = if small_on_right {
"big_col"
} else {
"small_col"
};
let right_on = if small_on_right {
"small_col"
} else {
"big_col"
};

let join = Arc::new(HashJoinExec::try_new(
Arc::clone(left),
Arc::clone(right),
vec![(
Arc::new(Column::new_with_schema(left_on, &left.schema())?),
Arc::new(Column::new_with_schema(right_on, &right.schema())?),
)],
None,
&join_type,
Some(projection),
PartitionMode::Partitioned,
false,
)?);

let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned)
.expect("swap_hash_join must support joins with projections");
let swapped_join = swapped.as_any().downcast_ref::<HashJoinExec>().expect(
"ProjectionExec won't be added above if HashJoinExec contains embedded projection",
);

assert_eq!(swapped_join.projection, Some(vec![0_usize]));
assert_eq!(swapped.schema().fields.len(), 1);
assert_eq!(swapped.schema().fields[0].name(), "small_col");
Ok(())
}

#[rstest(
join_type,
case::inner(JoinType::Inner),
Expand Down

0 comments on commit 47348a1

Please sign in to comment.