Skip to content

Commit

Permalink
Bug-fix: MemoryExec sort expressions do NOT refer to the projected sc…
Browse files Browse the repository at this point in the history
…hema (#12876)

* Update memory.rs

* add assert

* Update memory.rs

* Update memory.rs

* Update memory.rs

* address review

* Update memory.rs

* Update memory.rs

* final fix

* Fix comments in test_utils.rs

---------

Co-authored-by: Mehmet Ozan Kabak <[email protected]>
  • Loading branch information
berkaysynnada and ozankabak authored Oct 12, 2024
1 parent 6267ede commit e7ac843
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 24 deletions.
6 changes: 3 additions & 3 deletions datafusion/core/src/datasource/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ use crate::physical_planner::create_physical_sort_exprs;

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_catalog::Session;
use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt};
use datafusion_execution::TaskContext;
use datafusion_expr::dml::InsertOp;
use datafusion_expr::SortExpr;
use datafusion_physical_plan::metrics::MetricsSet;

use async_trait::async_trait;
use datafusion_catalog::Session;
use datafusion_expr::SortExpr;
use futures::StreamExt;
use log::debug;
use parking_lot::Mutex;
Expand Down Expand Up @@ -241,7 +241,7 @@ impl TableProvider for MemTable {
)
})
.collect::<Result<Vec<_>>>()?;
exec = exec.with_sort_information(file_sort_order);
exec = exec.try_with_sort_information(file_sort_order)?;
}

Ok(Arc::new(exec))
Expand Down
3 changes: 2 additions & 1 deletion datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,8 @@ async fn run_aggregate_test(input1: Vec<RecordBatch>, group_by_columns: Vec<&str
let running_source = Arc::new(
MemoryExec::try_new(&[input1.clone()], schema.clone(), None)
.unwrap()
.with_sort_information(vec![sort_keys]),
.try_with_sort_information(vec![sort_keys])
.unwrap(),
);

let aggregate_expr =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,8 @@ mod sp_repartition_fuzz_tests {
let running_source = Arc::new(
MemoryExec::try_new(&[input1.clone()], schema.clone(), None)
.unwrap()
.with_sort_information(vec![sort_keys.clone()]),
.try_with_sort_information(vec![sort_keys.clone()])
.unwrap(),
);
let hash_exprs = vec![col("c", &schema).unwrap()];

Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ async fn run_window_test(
];
let mut exec1 = Arc::new(
MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(), None)?
.with_sort_information(vec![source_sort_keys.clone()]),
.try_with_sort_information(vec![source_sort_keys.clone()])?,
) as _;
// Table is ordered according to ORDER BY a, b, c In linear test we use PARTITION BY b, ORDER BY a
// For WindowAggExec to produce correct result it need table to be ordered by b,a. Hence add a sort.
Expand All @@ -673,7 +673,7 @@ async fn run_window_test(
)?) as _;
let exec2 = Arc::new(
MemoryExec::try_new(&[input1.clone()], schema.clone(), None)?
.with_sort_information(vec![source_sort_keys.clone()]),
.try_with_sort_information(vec![source_sort_keys.clone()])?,
);
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
vec![create_window_expr(
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/memory_limit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ impl TableProvider for SortedTableProvider {
) -> Result<Arc<dyn ExecutionPlan>> {
let mem_exec =
MemoryExec::try_new(&self.batches, self.schema(), projection.cloned())?
.with_sort_information(self.sort_information.clone());
.try_with_sort_information(self.sort_information.clone())?;

Ok(Arc::new(mem_exec))
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-plan/src/joins/nested_loop_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ mod tests {
};
sort_info.push(sort_expr);
}
exec = exec.with_sort_information(vec![sort_info]);
exec = exec.try_with_sort_information(vec![sort_info]).unwrap();
}

Arc::new(exec)
Expand Down
17 changes: 9 additions & 8 deletions datafusion/physical-plan/src/joins/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ macro_rules! join_expr_tests {
ScalarValue::$SCALAR(Some(10 as $type)),
(Operator::Gt, Operator::Lt),
),
// left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10
// left_col - 1 > right_col + 3 AND left_col + 3 < right_col + 15
1 => gen_conjunctive_numerical_expr(
left_col,
right_col,
Expand All @@ -300,9 +300,9 @@ macro_rules! join_expr_tests {
Operator::Plus,
),
ScalarValue::$SCALAR(Some(1 as $type)),
ScalarValue::$SCALAR(Some(5 as $type)),
ScalarValue::$SCALAR(Some(3 as $type)),
ScalarValue::$SCALAR(Some(10 as $type)),
ScalarValue::$SCALAR(Some(3 as $type)),
ScalarValue::$SCALAR(Some(15 as $type)),
(Operator::Gt, Operator::Lt),
),
// left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10
Expand Down Expand Up @@ -353,7 +353,8 @@ macro_rules! join_expr_tests {
ScalarValue::$SCALAR(Some(3 as $type)),
(Operator::Gt, Operator::Lt),
),
// left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3
// left_col - 2 >= right_col + 5 AND left_col + 7 <= right_col - 3
// (filters all input rows)
5 => gen_conjunctive_numerical_expr(
left_col,
right_col,
Expand All @@ -369,7 +370,7 @@ macro_rules! join_expr_tests {
ScalarValue::$SCALAR(Some(3 as $type)),
(Operator::GtEq, Operator::LtEq),
),
// left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39
// left_col + 28 >= right_col - 11 AND left_col + 21 <= right_col + 39
6 => gen_conjunctive_numerical_expr(
left_col,
right_col,
Expand All @@ -385,7 +386,7 @@ macro_rules! join_expr_tests {
ScalarValue::$SCALAR(Some(39 as $type)),
(Operator::Gt, Operator::LtEq),
),
// left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col + 39
// left_col + 28 >= right_col - 11 AND left_col - 21 <= right_col + 39
7 => gen_conjunctive_numerical_expr(
left_col,
right_col,
Expand Down Expand Up @@ -526,10 +527,10 @@ pub fn create_memory_table(
) -> Result<(Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>)> {
let left_schema = left_partition[0].schema();
let left = MemoryExec::try_new(&[left_partition], left_schema, None)?
.with_sort_information(left_sorted);
.try_with_sort_information(left_sorted)?;
let right_schema = right_partition[0].schema();
let right = MemoryExec::try_new(&[right_partition], right_schema, None)?
.with_sort_information(right_sorted);
.try_with_sort_information(right_sorted)?;
Ok((Arc::new(left), Arc::new(right)))
}

Expand Down
58 changes: 54 additions & 4 deletions datafusion/physical-plan/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ use arrow::record_batch::RecordBatch;
use datafusion_common::{internal_err, project_schema, Result};
use datafusion_execution::memory_pool::MemoryReservation;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{EquivalenceProperties, LexOrdering};

use futures::Stream;
Expand Down Expand Up @@ -206,16 +209,63 @@ impl MemoryExec {
/// where both `a ASC` and `b DESC` can describe the table ordering. With
/// [`EquivalenceProperties`], we can keep track of these equivalences
/// and treat `a ASC` and `b DESC` as the same ordering requirement.
pub fn with_sort_information(mut self, sort_information: Vec<LexOrdering>) -> Self {
self.sort_information = sort_information;
///
/// Note that if there is an internal projection, that projection will be
/// also applied to the given `sort_information`.
pub fn try_with_sort_information(
mut self,
mut sort_information: Vec<LexOrdering>,
) -> Result<Self> {
// All sort expressions must refer to the original schema
let fields = self.schema.fields();
let ambiguous_column = sort_information
.iter()
.flatten()
.flat_map(|expr| collect_columns(&expr.expr))
.find(|col| {
fields
.get(col.index())
.map(|field| field.name() != col.name())
.unwrap_or(true)
});
if let Some(col) = ambiguous_column {
return internal_err!(
"Column {:?} is not found in the original schema of the MemoryExec",
col
);
}

// If there is a projection on the source, we also need to project orderings
if let Some(projection) = &self.projection {
let base_eqp = EquivalenceProperties::new_with_orderings(
self.original_schema(),
&sort_information,
);
let proj_exprs = projection
.iter()
.map(|idx| {
let base_schema = self.original_schema();
let name = base_schema.field(*idx).name();
(Arc::new(Column::new(name, *idx)) as _, name.to_string())
})
.collect::<Vec<_>>();
let projection_mapping =
ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?;
sort_information = base_eqp
.project(&projection_mapping, self.schema())
.oeq_class
.orderings;
}

self.sort_information = sort_information;
// We need to update equivalence properties when updating sort information.
let eq_properties = EquivalenceProperties::new_with_orderings(
self.schema(),
&self.sort_information,
);
self.cache = self.cache.with_eq_properties(eq_properties);
self

Ok(self)
}

pub fn original_schema(&self) -> SchemaRef {
Expand Down Expand Up @@ -347,7 +397,7 @@ mod tests {

let sort_information = vec![sort1.clone(), sort2.clone()];
let mem_exec = MemoryExec::try_new(&[vec![]], schema, None)?
.with_sort_information(sort_information);
.try_with_sort_information(sort_information)?;

assert_eq!(
mem_exec.properties().output_ordering().unwrap(),
Expand Down
3 changes: 2 additions & 1 deletion datafusion/physical-plan/src/repartition/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1677,7 +1677,8 @@ mod test {
Arc::new(
MemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
.unwrap()
.with_sort_information(vec![sort_exprs]),
.try_with_sort_information(vec![sort_exprs])
.unwrap(),
)
}
}
4 changes: 2 additions & 2 deletions datafusion/physical-plan/src/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -809,11 +809,11 @@ mod tests {
.collect::<Vec<_>>();
let child1 = Arc::new(
MemoryExec::try_new(&[], Arc::clone(&schema), None)?
.with_sort_information(first_orderings),
.try_with_sort_information(first_orderings)?,
);
let child2 = Arc::new(
MemoryExec::try_new(&[], Arc::clone(&schema), None)?
.with_sort_information(second_orderings),
.try_with_sort_information(second_orderings)?,
);

let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema));
Expand Down

0 comments on commit e7ac843

Please sign in to comment.