From e7ac8434153560816220f0e1492057e61b7ad983 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Sat, 12 Oct 2024 09:02:44 +0300 Subject: [PATCH] Bug-fix: MemoryExec sort expressions do NOT refer to the projected schema (#12876) * Update memory.rs * add assert * Update memory.rs * Update memory.rs * Update memory.rs * address review * Update memory.rs * Update memory.rs * final fix * Fix comments in test_utils.rs --------- Co-authored-by: Mehmet Ozan Kabak --- datafusion/core/src/datasource/memory.rs | 6 +- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 3 +- .../sort_preserving_repartition_fuzz.rs | 3 +- .../core/tests/fuzz_cases/window_fuzz.rs | 4 +- datafusion/core/tests/memory_limit/mod.rs | 2 +- .../src/joins/nested_loop_join.rs | 2 +- .../physical-plan/src/joins/test_utils.rs | 17 +++--- datafusion/physical-plan/src/memory.rs | 58 +++++++++++++++++-- .../physical-plan/src/repartition/mod.rs | 3 +- datafusion/physical-plan/src/union.rs | 4 +- 10 files changed, 78 insertions(+), 24 deletions(-) diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 24a4938e7b2b..3c2d1b0205d6 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -37,14 +37,14 @@ use crate::physical_planner::create_physical_sort_exprs; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion_catalog::Session; use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt}; use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; +use datafusion_expr::SortExpr; use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; -use datafusion_catalog::Session; -use datafusion_expr::SortExpr; use futures::StreamExt; use log::debug; use parking_lot::Mutex; @@ -241,7 +241,7 @@ impl TableProvider for MemTable { ) }) .collect::>>()?; - exec = 
exec.with_sort_information(file_sort_order); + exec = exec.try_with_sort_information(file_sort_order)?; } Ok(Arc::new(exec)) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index b0852501415e..64a7514ebd5e 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -395,7 +395,8 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let running_source = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None) .unwrap() - .with_sort_information(vec![sort_keys]), + .try_with_sort_information(vec![sort_keys]) + .unwrap(), ); let aggregate_expr = diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 0cd702372f7c..a72affc2b079 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -358,7 +358,8 @@ mod sp_repartition_fuzz_tests { let running_source = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None) .unwrap() - .with_sort_information(vec![sort_keys.clone()]), + .try_with_sort_information(vec![sort_keys.clone()]) + .unwrap(), ); let hash_exprs = vec![col("c", &schema).unwrap()]; diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index b9881c9f23cf..feffb11bf700 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -647,7 +647,7 @@ async fn run_window_test( ]; let mut exec1 = Arc::new( MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(), None)? 
- .with_sort_information(vec![source_sort_keys.clone()]), + .try_with_sort_information(vec![source_sort_keys.clone()])?, ) as _; // Table is ordered according to ORDER BY a, b, c In linear test we use PARTITION BY b, ORDER BY a // For WindowAggExec to produce correct result it need table to be ordered by b,a. Hence add a sort. @@ -673,7 +673,7 @@ async fn run_window_test( )?) as _; let exec2 = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None)? - .with_sort_information(vec![source_sort_keys.clone()]), + .try_with_sort_information(vec![source_sort_keys.clone()])?, ); let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![create_window_expr( diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index ec66df45c7ba..fc2fb9afb5f9 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -840,7 +840,7 @@ impl TableProvider for SortedTableProvider { ) -> Result> { let mem_exec = MemoryExec::try_new(&self.batches, self.schema(), projection.cloned())? 
- .with_sort_information(self.sort_information.clone()); + .try_with_sort_information(self.sort_information.clone())?; Ok(Arc::new(mem_exec)) } diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 029003374acc..6068e7526316 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -780,7 +780,7 @@ mod tests { }; sort_info.push(sort_expr); } - exec = exec.with_sort_information(vec![sort_info]); + exec = exec.try_with_sort_information(vec![sort_info]).unwrap(); } Arc::new(exec) diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 264f297ffb4c..090d60f0bac3 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -289,7 +289,7 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(10 as $type)), (Operator::Gt, Operator::Lt), ), - // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10 + // left_col - 1 > right_col + 3 AND left_col + 3 < right_col + 15 1 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -300,9 +300,9 @@ macro_rules! join_expr_tests { Operator::Plus, ), ScalarValue::$SCALAR(Some(1 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), ScalarValue::$SCALAR(Some(3 as $type)), - ScalarValue::$SCALAR(Some(10 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(15 as $type)), (Operator::Gt, Operator::Lt), ), // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10 @@ -353,7 +353,8 @@ macro_rules! 
join_expr_tests { ScalarValue::$SCALAR(Some(3 as $type)), (Operator::Gt, Operator::Lt), ), - // left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3 + // left_col - 2 >= right_col + 5 AND left_col + 7 <= right_col - 3 + // (filters all input rows) 5 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -369,7 +370,7 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(3 as $type)), (Operator::GtEq, Operator::LtEq), ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 + // left_col + 28 > right_col - 11 AND left_col + 21 <= right_col + 39 6 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -385,7 +386,7 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(39 as $type)), (Operator::Gt, Operator::LtEq), ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col + 39 + // left_col + 28 >= right_col - 11 AND left_col - 21 <= right_col + 39 7 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -526,10 +527,10 @@ pub fn create_memory_table( ) -> Result<(Arc, Arc)> { let left_schema = left_partition[0].schema(); let left = MemoryExec::try_new(&[left_partition], left_schema, None)? - .with_sort_information(left_sorted); + .try_with_sort_information(left_sorted)?; let right_schema = right_partition[0].schema(); let right = MemoryExec::try_new(&[right_partition], right_schema, None)? 
- .with_sort_information(right_sorted); + .try_with_sort_information(right_sorted)?; Ok((Arc::new(left), Arc::new(right))) } diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 3aa445d295cb..456f0ef2dcc8 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -33,6 +33,9 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, project_schema, Result}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use futures::Stream; @@ -206,16 +209,63 @@ impl MemoryExec { /// where both `a ASC` and `b DESC` can describe the table ordering. With /// [`EquivalenceProperties`], we can keep track of these equivalences /// and treat `a ASC` and `b DESC` as the same ordering requirement. - pub fn with_sort_information(mut self, sort_information: Vec<LexOrdering>) -> Self { - self.sort_information = sort_information; + /// + /// Note that if there is an internal projection, that projection will + /// also be applied to the given `sort_information`. 
+ pub fn try_with_sort_information( + mut self, + mut sort_information: Vec<LexOrdering>, + ) -> Result<Self> { + // All sort expressions must refer to the original schema + let fields = self.schema.fields(); + let ambiguous_column = sort_information + .iter() + .flatten() + .flat_map(|expr| collect_columns(&expr.expr)) + .find(|col| { + fields + .get(col.index()) + .map(|field| field.name() != col.name()) + .unwrap_or(true) + }); + if let Some(col) = ambiguous_column { + return internal_err!( + "Column {:?} is not found in the original schema of the MemoryExec", + col + ); + } + + // If there is a projection on the source, we also need to project orderings + if let Some(projection) = &self.projection { + let base_eqp = EquivalenceProperties::new_with_orderings( + self.original_schema(), + &sort_information, + ); + let proj_exprs = projection + .iter() + .map(|idx| { + let base_schema = self.original_schema(); + let name = base_schema.field(*idx).name(); + (Arc::new(Column::new(name, *idx)) as _, name.to_string()) + }) + .collect::<Vec<_>>(); + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?; + sort_information = base_eqp + .project(&projection_mapping, self.schema()) + .oeq_class + .orderings; + } + self.sort_information = sort_information; // We need to update equivalence properties when updating sort information. let eq_properties = EquivalenceProperties::new_with_orderings( self.schema(), &self.sort_information, ); self.cache = self.cache.with_eq_properties(eq_properties); - self + + Ok(self) } pub fn original_schema(&self) -> SchemaRef { @@ -347,7 +397,7 @@ mod tests { let sort_information = vec![sort1.clone(), sort2.clone()]; let mem_exec = MemoryExec::try_new(&[vec![]], schema, None)? 
- .with_sort_information(sort_information); + .try_with_sort_information(sort_information)?; assert_eq!( mem_exec.properties().output_ordering().unwrap(), diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index d9368cf86d45..902d9f4477bc 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -1677,7 +1677,8 @@ mod test { Arc::new( MemoryExec::try_new(&[vec![]], Arc::clone(schema), None) .unwrap() - .with_sort_information(vec![sort_exprs]), + .try_with_sort_information(vec![sort_exprs]) + .unwrap(), ) } } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 1cf22060b62a..108e42e7be42 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -809,11 +809,11 @@ mod tests { .collect::>(); let child1 = Arc::new( MemoryExec::try_new(&[], Arc::clone(&schema), None)? - .with_sort_information(first_orderings), + .try_with_sort_information(first_orderings)?, ); let child2 = Arc::new( MemoryExec::try_new(&[], Arc::clone(&schema), None)? - .with_sort_information(second_orderings), + .try_with_sort_information(second_orderings)?, ); let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema));