From 9d1cf74aa7032fd5cfd8ebd0e3860ccf9ea5f5e8 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 15 Aug 2024 19:05:44 -0400 Subject: [PATCH] Remove physical sort parameters on aggregate window functions (#12009) * Remove order_by on aggregate window functions since that operation is handled by the window function * Add unit test for window functions using udaf with ordering * Resolve clippy warning --- datafusion/core/src/dataframe/mod.rs | 89 ++++++++++++++++++++- datafusion/physical-plan/src/windows/mod.rs | 1 - 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 25a8d1c87f00..3705873ce3bc 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1703,13 +1703,16 @@ mod tests { use arrow::array::{self, Int32Array}; use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; + use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, - ScalarFunctionImplementation, Volatility, WindowFunctionDefinition, + ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; + use sqlparser::ast::NullTreatment; // Get string representation of the plan async fn assert_physical_plan(df: &DataFrame, expected: Vec<&str>) { @@ -2355,6 +2358,90 @@ mod tests { Ok(()) } + #[tokio::test] + async fn window_using_aggregates() -> Result<()> { + // build plan using DataFrame API + let df = test_table().await?.filter(col("c1").eq(lit("a")))?; + let mut aggr_expr = vec![ + ( + datafusion_functions_aggregate::first_last::first_value_udaf(), + "first_value", + ), + ( + 
datafusion_functions_aggregate::first_last::last_value_udaf(), + "last_val", + ), + ( + datafusion_functions_aggregate::approx_distinct::approx_distinct_udaf(), + "approx_distinct", + ), + ( + datafusion_functions_aggregate::approx_median::approx_median_udaf(), + "approx_median", + ), + ( + datafusion_functions_aggregate::median::median_udaf(), + "median", + ), + (datafusion_functions_aggregate::min_max::max_udaf(), "max"), + (datafusion_functions_aggregate::min_max::min_udaf(), "min"), + ] + .into_iter() + .map(|(func, name)| { + let w = WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(func), + vec![col("c3")], + ); + + Expr::WindowFunction(w) + .null_treatment(NullTreatment::IgnoreNulls) + .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Rows, + WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), + )) + .build() + .unwrap() + .alias(name) + }) + .collect::<Vec<_>>(); + aggr_expr.extend_from_slice(&[col("c2"), col("c3")]); + + let df: Vec<RecordBatch> = df.select(aggr_expr)?.collect().await?; + + assert_batches_sorted_eq!( + ["+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", + "| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |", + "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", + "| | | | | | | | 1 | -85 |", + "| -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 |", + "| -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 |", + "| -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 |", + "| -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 |", + "| -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 |", + "| -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 |", + "| -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 |", + "| -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 |", + "| -85 | -5 | 4 | -37 | -5 | 
-5 | -85 | 1 | 83 |", + "| -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 |", + "| -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 |", + "| -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 |", + "| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |", + "| -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 |", + "| -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 |", + "| -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 |", + "| -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 |", + "| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |", + "| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |", + "| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |", + "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+"], + &df + ); + + Ok(()) + } + // Test issue: https://github.com/apache/datafusion/issues/10346 #[tokio::test] async fn test_select_over_aggregate_schema() -> Result<()> { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 1fd0ca36b1eb..03090faf3efd 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -113,7 +113,6 @@ pub fn create_window_expr( let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) .schema(Arc::new(input_schema.clone())) .alias(name) - .order_by(order_by.to_vec()) .with_ignore_nulls(ignore_nulls) .build()?; window_expr_from_aggregate_expr(