diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 063f35059fb8..14835f717ea3 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1583,7 +1583,6 @@ mod tests { use rstest::*; use rstest_reuse::*; - #[cfg(not(feature = "force_hash_collisions"))] fn div_ceil(a: usize, b: usize) -> usize { (a + b - 1) / b } @@ -1931,9 +1930,6 @@ mod tests { Ok(()) } - // FIXME(#TODO) test fails with feature `force_hash_collisions` - // https://github.com/apache/datafusion/issues/11658 - #[cfg(not(feature = "force_hash_collisions"))] #[apply(batch_sizes)] #[tokio::test] async fn join_inner_two(batch_size: usize) -> Result<()> { @@ -1964,12 +1960,20 @@ mod tests { assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); - // expected joined records = 3 - // in case batch_size is 1 - additional empty batch for remaining 3-2 row - let mut expected_batch_count = div_ceil(3, batch_size); - if batch_size == 1 { - expected_batch_count += 1; - } + let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) { + // Expected number of hash table matches = 3 + // in case batch_size is 1 - additional empty batch for remaining 3-2 row + let mut expected_batch_count = div_ceil(3, batch_size); + if batch_size == 1 { + expected_batch_count += 1; + } + expected_batch_count + } else { + // With hash collisions enabled, all records will match each other + // and filtered later. + div_ceil(9, batch_size) + }; + assert_eq!(batches.len(), expected_batch_count); let expected = [ @@ -1989,9 +1993,6 @@ mod tests { } /// Test where the left has 2 parts, the right with 1 part => 1 part - // FIXME(#TODO) test fails with feature `force_hash_collisions` - // https://github.com/apache/datafusion/issues/11658 - #[cfg(not(feature = "force_hash_collisions"))] #[apply(batch_sizes)] #[tokio::test] async fn join_inner_one_two_parts_left(batch_size: usize) -> Result<()> { @@ -2029,12 +2030,20 @@ mod tests { assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); - // expected joined records = 3 - // in case batch_size is 1 - additional empty batch for remaining 3-2 row - let mut expected_batch_count = div_ceil(3, batch_size); - if batch_size == 1 { - expected_batch_count += 1; - } + let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) { + // Expected number of hash table matches = 3 + // in case batch_size is 1 - additional empty batch for remaining 3-2 row + let mut expected_batch_count = div_ceil(3, batch_size); + if batch_size == 1 { + expected_batch_count += 1; + } + expected_batch_count + } else { + // With hash collisions enabled, all records will match each other + // and filtered later. + div_ceil(9, batch_size) + }; + assert_eq!(batches.len(), expected_batch_count); let expected = [ @@ -2104,9 +2113,6 @@ mod tests { } /// Test where the left has 1 part, the right has 2 parts => 2 parts - // FIXME(#TODO) test fails with feature `force_hash_collisions` - // https://github.com/apache/datafusion/issues/11658 - #[cfg(not(feature = "force_hash_collisions"))] #[apply(batch_sizes)] #[tokio::test] async fn join_inner_one_two_parts_right(batch_size: usize) -> Result<()> { @@ -2143,12 +2149,19 @@ mod tests { let stream = join.execute(0, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; - // expected joined records = 1 (first right batch) - // and additional empty batch for non-joined 20-6-80 - let mut expected_batch_count = div_ceil(1, batch_size); - if batch_size == 1 { - expected_batch_count += 1; - } + let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) { + // Expected number of hash table matches for first right batch = 1 + // and additional empty batch for non-joined 20-6-80 + let mut expected_batch_count = div_ceil(1, batch_size); + if batch_size == 1 { + expected_batch_count += 1; + } + expected_batch_count + } else { + // With hash collisions enabled, all records will match each other + // and filtered later. + div_ceil(6, batch_size) + }; assert_eq!(batches.len(), expected_batch_count); let expected = [ @@ -2166,8 +2179,14 @@ mod tests { let stream = join.execute(1, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; - // expected joined records = 2 (second right batch) - let expected_batch_count = div_ceil(2, batch_size); + let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) { + // Expected number of hash table matches for second right batch = 2 + div_ceil(2, batch_size) + } else { + // With hash collisions enabled, all records will match each other + // and filtered later. + div_ceil(3, batch_size) + }; assert_eq!(batches.len(), expected_batch_count); let expected = [ @@ -3732,9 +3751,9 @@ mod tests { | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - (expected_resultset_records + batch_size - 1) / batch_size + div_ceil(expected_resultset_records, batch_size) } - _ => (expected_resultset_records + batch_size - 1) / batch_size + 1, + _ => div_ceil(expected_resultset_records, batch_size) + 1, }; assert_eq!( batches.len(),