From b94f70fdcfb41b5661879df489298c6072335a51 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 16 Feb 2024 22:37:21 +0800 Subject: [PATCH 01/46] first draft Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 14 +- .../user_defined/user_defined_aggregates.rs | 128 ++++++++++++++++- datafusion/expr/src/expr_fn.rs | 130 +++++++++++++++++- datafusion/expr/src/function.rs | 12 +- .../physical-expr/src/aggregate/first_last.rs | 2 +- .../physical-expr/src/expressions/mod.rs | 2 +- datafusion/physical-plan/src/udaf.rs | 9 +- datafusion/physical-plan/src/windows/mod.rs | 11 +- datafusion/proto/src/physical_plan/mod.rs | 4 +- datafusion/sql/src/expr/function.rs | 5 +- 10 files changed, 294 insertions(+), 23 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d348e28ededa..c32443ce0098 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -246,7 +246,6 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { args, filter, order_by, - null_treatment: _, }) => match func_def { AggregateFunctionDefinition::BuiltIn(..) => { create_function_physical_name(func_def.name(), *distinct, args) @@ -258,11 +257,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { "aggregate expression with filter is not supported" ); } - if order_by.is_some() { - return exec_err!( - "aggregate expression with order_by is not supported" - ); - } + let names = args .iter() .map(|e| create_physical_name(e, false)) @@ -1709,13 +1704,16 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( (agg_expr, filter, order_by) } AggregateFunctionDefinition::UDF(fun) => { + let ordering_reqs: Vec = + order_by.clone().unwrap_or(vec![]); let agg_expr = udaf::create_aggregate_expr( fun, &args, + &ordering_reqs, physical_input_schema, name, - ); - (agg_expr?, filter, order_by) + )?; + (agg_expr, filter, order_by) } AggregateFunctionDefinition::Name(_) => { return internal_err!( diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 9e231d25f298..a36aec698438 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -20,7 +20,7 @@ use arrow::{array::AsArray, datatypes::Fields}; use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; -use arrow_schema::Schema; +use arrow_schema::{Schema, SortOptions}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -42,11 +42,15 @@ use datafusion::{ prelude::SessionContext, scalar::ScalarValue, }; -use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; +use datafusion_common::{ + assert_contains, cast::as_primitive_array, exec_err, Column, DataFusionError, +}; use datafusion_expr::{ - create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF, + create_udaf, create_udaf_with_ordering, expr::Sort, AggregateUDFImpl, Expr, + GroupsAccumulator, SimpleAggregateUDF, }; -use datafusion_physical_expr::expressions::AvgAccumulator; +use datafusion_physical_expr::expressions::{self, FirstValueAccumulator}; +use datafusion_physical_expr::{expressions::AvgAccumulator, PhysicalSortExpr}; /// Test to show the contents of the setup #[tokio::test] @@ -208,6 +212,122 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } +#[tokio::test] +async fn simple_udaf_order() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4])), + Arc::new(Int32Array::from(vec![1, 1, 2, 2])), + ], + )?; + + let ctx = SessionContext::new(); + + let provider = MemTable::try_new(Arc::new(schema.clone()), vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + // let expected_result = ctx + // .sql("SELECT FIRST_VALUE(a order by a desc) FROM t group by b order by b") + // .await? + // .collect() + // .await?; + + fn create_accumulator( + data_type: &DataType, + order_by: Vec>, + schema: &Schema, + ) -> Result> { + let mut all_sort_orders = vec![]; + + assert_eq!(order_by.len(), 1); + + for exprs in order_by { + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in exprs { + if let Expr::Sort(sort) = expr { + if let Expr::Column(col) = sort.expr.as_ref() { + let name = &col.name; + let e = expressions::col(name, schema)?; + sort_exprs.push(PhysicalSortExpr { + expr: e, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); + } + } + } + if !sort_exprs.is_empty() { + all_sort_orders.extend(sort_exprs); + } + } + + let ordering_req = all_sort_orders; + + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(schema)) + .collect::>>()?; + + let acc = FirstValueAccumulator::try_new( + data_type, + ordering_types.as_slice(), + ordering_req, + )?; + // let acc = FirstValueAccumulator::try_new(data_type, &[], vec![])?; + Ok(Box::new(acc)) + } + + let order_by = Expr::Sort(Sort { + expr: Box::new(Expr::Column(Column::new(Some("t"), "a"))), + asc: false, + nulls_first: false, + }); + let order_by = vec![vec![order_by]]; + + // define a udaf, using a DataFusion's accumulator + let my_first = create_udaf_with_ordering( + "my_first", + vec![DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + // Arc::new(|d| create_accumulator(d, None, &dfs, &p, &schema)), + Arc::new(|d, order_by, schema| create_accumulator(d, order_by, schema)), + Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]), + order_by, + schema, + ); + + ctx.register_udaf(my_first); + + // Should be the same as `SELECT FIRST_VALUE(a order by a) FROM t group by b order by b` + let result = ctx + .sql("SELECT MY_FIRST(a order by a desc) FROM t group by b order by b") + .await? + .collect() + .await?; + + let expected = [ + "+---------------+", + "| my_first(t.a) |", + "+---------------+", + "| 2 |", + "| 4 |", + "+---------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) +} + /// tests the creation, registration and usage of a UDAF #[tokio::test] async fn simple_udaf() -> Result<()> { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 99f44a73c1dd..201d4fca2814 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -21,7 +21,9 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Placeholder, ScalarFunction, TryCast, }; -use crate::function::PartitionEvaluatorFactory; +use crate::function::{ + AccumulatorFactoryFunctionWithOrdering, PartitionEvaluatorFactory, +}; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, @@ -29,7 +31,7 @@ use crate::{ ScalarUDF, Signature, Volatility, }; use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::{Column, Result}; use std::any::Any; use std::fmt::Debug; @@ -1017,6 +1019,34 @@ pub fn create_udaf( )) } +// TODO: Merge with ordering +/// Creates a new UDAF with a specific signature, state type and return type. +/// The signature and state type must match the `Accumulator's implementation`. +pub fn create_udaf_with_ordering( + name: &str, + input_type: Vec, + return_type: Arc, + volatility: Volatility, + accumulator: AccumulatorFactoryFunctionWithOrdering, + state_type: Arc>, + ordering_req: Vec>, + schema: Schema, +) -> AggregateUDF { + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); + + AggregateUDF::from(SimpleOrderedAggregateUDF::new( + name, + input_type, + return_type, + volatility, + accumulator, + state_type, + ordering_req, + schema, + )) +} + /// Implements [`AggregateUDFImpl`] for functions that have a single signature and /// return type. pub struct SimpleAggregateUDF { @@ -1103,6 +1133,102 @@ impl AggregateUDFImpl for SimpleAggregateUDF { } } +/// Implements [`AggregateUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleOrderedAggregateUDF { + name: String, + signature: Signature, + return_type: DataType, + accumulator: AccumulatorFactoryFunctionWithOrdering, + state_type: Vec, + ordering_req: Vec>, + schema: Schema, +} + +impl Debug for SimpleOrderedAggregateUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("AggregateUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl SimpleOrderedAggregateUDF { + /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and + /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_type: Vec, + return_type: DataType, + volatility: Volatility, + accumulator: AccumulatorFactoryFunctionWithOrdering, + state_type: Vec, + ordering_req: Vec>, + schema: Schema, + ) -> Self { + let name = name.into(); + let signature = Signature::exact(input_type, volatility); + Self { + name, + signature, + return_type, + accumulator, + state_type, + ordering_req, + schema, + } + } + + pub fn new_with_signature( + name: impl Into, + signature: Signature, + return_type: DataType, + accumulator: AccumulatorFactoryFunctionWithOrdering, + state_type: Vec, + ordering_req: Vec>, + schema: Schema, + ) -> Self { + let name = name.into(); + Self { + name, + signature, + return_type, + accumulator, + state_type, + ordering_req, + schema, + } + } +} + +impl AggregateUDFImpl for SimpleOrderedAggregateUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn accumulator(&self, arg: &DataType) -> Result> { + (self.accumulator)(arg, self.ordering_req.clone(), &self.schema) + } + + fn state_type(&self, _return_type: &DataType) -> Result> { + Ok(self.state_type.clone()) + } +} + /// Creates a new UDWF with a specific signature, state type and return type. /// /// The signature and state type must match the [`PartitionEvaluator`]'s implementation`. diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 3e30a5574be0..1c00dcd665f5 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,9 +17,9 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature}; +use crate::{Accumulator, BuiltinScalarFunction, Expr, PartitionEvaluator, Signature}; use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::utils::datafusion_strsim; use datafusion_common::Result; use std::sync::Arc; @@ -45,6 +45,14 @@ pub type ReturnTypeFunction = pub type AccumulatorFactoryFunction = Arc Result> + Send + Sync>; +/// Factory that returns an accumulator for the given aggregate, given +/// its return datatype, the ordering of the input arguments and the schema that are needed for ordering. +pub type AccumulatorFactoryFunctionWithOrdering = Arc< + dyn Fn(&DataType, Vec>, &Schema) -> Result> + + Send + + Sync, +>; + /// Factory that creates a PartitionEvaluator for the given window /// function pub type PartitionEvaluatorFactory = diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 17dd3ef1206d..264f52f93a3a 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -210,7 +210,7 @@ impl PartialEq for FirstValue { } #[derive(Debug)] -struct FirstValueAccumulator { +pub struct FirstValueAccumulator { first: ScalarValue, // At the beginning, `is_set` is false, which means `first` is not seen yet. // Once we see the first value, we set the `is_set` flag and do not update `first` anymore. diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 26d649f57201..d900c00e9db0 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -54,7 +54,7 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; -pub use crate::aggregate::first_last::{FirstValue, LastValue}; +pub use crate::aggregate::first_last::{FirstValue, FirstValueAccumulator, LastValue}; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index fd9279dfd552..f30f95924a82 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -30,7 +30,7 @@ use arrow::{ use super::{expressions::format_state_name, Accumulator, AggregateExpr}; use datafusion_common::{not_impl_err, Result}; pub use datafusion_expr::AggregateUDF; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; use datafusion_physical_expr::aggregate::utils::down_cast_any_ref; use std::sync::Arc; @@ -40,6 +40,7 @@ use std::sync::Arc; pub fn create_aggregate_expr( fun: &AggregateUDF, input_phy_exprs: &[Arc], + ordering_req: &[PhysicalSortExpr], input_schema: &Schema, name: impl Into, ) -> Result> { @@ -53,6 +54,7 @@ pub fn create_aggregate_expr( args: input_phy_exprs.to_vec(), data_type: fun.return_type(&input_exprs_types)?, name: name.into(), + ordering_req: ordering_req.to_vec(), })) } @@ -64,6 +66,7 @@ pub struct AggregateFunctionExpr { /// Output / return type of this aggregate data_type: DataType, name: String, + ordering_req: LexOrdering, } impl AggregateFunctionExpr { @@ -175,6 +178,10 @@ impl AggregateExpr for AggregateFunctionExpr { fn create_groups_accumulator(&self) -> Result> { self.fun.create_groups_accumulator() } + + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + } } impl PartialEq for AggregateFunctionExpr { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 54731f0d812b..4ddcfd38f1db 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -92,8 +92,15 @@ pub fn create_window_expr( )) } WindowFunctionDefinition::AggregateUDF(fun) => { - let aggregate = - udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?; + // TODO: Ordering not supported for Window UDFs + let ordering_req = &[]; + let aggregate = udaf::create_aggregate_expr( + fun.as_ref(), + args, + ordering_req, + input_schema, + name, + )?; window_expr_from_aggregate_expr( partition_by, order_by, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index a4c08d76867d..f6ff2d3efa40 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -473,7 +473,9 @@ impl AsExecutionPlan for PhysicalPlanNode { } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &physical_schema, name) + // TODO: Ordering not supported for UDAF here + let ordering_req = &[]; + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, ordering_req, &physical_schema, name) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index bcf641e4b5a0..bb0a71a7ab5b 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -176,9 +176,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { + let order_by = + self.order_by_to_sort_expr(&order_by, schema, planner_context, true)?; + let order_by = (!order_by.is_empty()).then_some(order_by); let args = self.function_args_to_expr(args, schema, planner_context)?; return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fm, args, false, None, None, + fm, args, false, None, order_by, ))); } From c743d13febd81654ccde0fb8a7ec42ac9e93d45a Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 18 Feb 2024 16:35:27 +0800 Subject: [PATCH 02/46] clippy fix Signed-off-by: jayzhan211 --- .../core/tests/user_defined/user_defined_aggregates.rs | 9 +-------- datafusion/expr/src/expr_fn.rs | 3 ++- datafusion/proto/tests/cases/roundtrip_physical_plan.rs | 1 + 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index a36aec698438..10f1494799a2 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -232,12 +232,6 @@ async fn simple_udaf_order() -> Result<()> { let provider = MemTable::try_new(Arc::new(schema.clone()), vec![vec![batch]])?; ctx.register_table("t", Arc::new(provider))?; - // let expected_result = ctx - // .sql("SELECT FIRST_VALUE(a order by a desc) FROM t group by b order by b") - // .await? - // .collect() - // .await?; - fn create_accumulator( data_type: &DataType, order_by: Vec>, @@ -299,8 +293,7 @@ async fn simple_udaf_order() -> Result<()> { vec![DataType::Int32], Arc::new(DataType::Int32), Volatility::Immutable, - // Arc::new(|d| create_accumulator(d, None, &dfs, &p, &schema)), - Arc::new(|d, order_by, schema| create_accumulator(d, order_by, schema)), + Arc::new(create_accumulator), Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]), order_by, schema, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 201d4fca2814..4a07bf7bfb43 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1019,9 +1019,9 @@ pub fn create_udaf( )) } -// TODO: Merge with ordering /// Creates a new UDAF with a specific signature, state type and return type. /// The signature and state type must match the `Accumulator's implementation`. +#[allow(clippy::too_many_arguments)] pub fn create_udaf_with_ordering( name: &str, input_type: Vec, @@ -1158,6 +1158,7 @@ impl Debug for SimpleOrderedAggregateUDF { impl SimpleOrderedAggregateUDF { /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility + #[allow(clippy::too_many_arguments)] pub fn new( name: impl Into, input_type: Vec, diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7df22e01469b..243e3777db9d 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -425,6 +425,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let aggregates: Vec> = vec![udaf::create_aggregate_expr( &udaf, &[col("b", &schema)?], + &[], &schema, "example_agg", )?]; From 3a7e9652af4ac039b4414963525bd87aff1cb90d Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 18 Feb 2024 17:04:25 +0800 Subject: [PATCH 03/46] cleanup Signed-off-by: jayzhan211 --- .../user_defined/user_defined_aggregates.rs | 1 - datafusion/expr/src/expr_fn.rs | 21 ------------------- datafusion/physical-plan/src/windows/mod.rs | 2 +- datafusion/proto/src/physical_plan/mod.rs | 1 - 4 files changed, 1 insertion(+), 24 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 10f1494799a2..711f27443712 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -276,7 +276,6 @@ async fn simple_udaf_order() -> Result<()> { ordering_types.as_slice(), ordering_req, )?; - // let acc = FirstValueAccumulator::try_new(data_type, &[], vec![])?; Ok(Box::new(acc)) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4a07bf7bfb43..dae7269fa285 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1181,27 +1181,6 @@ impl SimpleOrderedAggregateUDF { schema, } } - - pub fn new_with_signature( - name: impl Into, - signature: Signature, - return_type: DataType, - accumulator: AccumulatorFactoryFunctionWithOrdering, - state_type: Vec, - ordering_req: Vec>, - schema: Schema, - ) -> Self { - let name = name.into(); - Self { - name, - signature, - return_type, - accumulator, - state_type, - ordering_req, - schema, - } - } } impl AggregateUDFImpl for SimpleOrderedAggregateUDF { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 4ddcfd38f1db..5143ccc49944 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -92,7 +92,7 @@ pub fn create_window_expr( )) } WindowFunctionDefinition::AggregateUDF(fun) => { - // TODO: Ordering not supported for Window UDFs + // TODO: Ordering not supported for Window UDFs yet let ordering_req = &[]; let aggregate = udaf::create_aggregate_expr( fun.as_ref(), diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index f6ff2d3efa40..43fb7e9e87d0 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -473,7 +473,6 @@ impl AsExecutionPlan for PhysicalPlanNode { } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; - // TODO: Ordering not supported for UDAF here let ordering_req = &[]; udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, ordering_req, &physical_schema, name) } From 4917f56d33d8b6ff02c9797cacd99e17c3e6419b Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 21 Feb 2024 21:55:11 +0800 Subject: [PATCH 04/46] use one vector for ordering req Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 2 +- .../user_defined/user_defined_aggregates.rs | 49 +++++++++---------- datafusion/expr/src/expr_fn.rs | 8 +-- datafusion/expr/src/function.rs | 2 +- datafusion/expr/src/udaf.rs | 2 +- 5 files changed, 31 insertions(+), 32 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index c32443ce0098..528f746ca9e6 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -251,7 +251,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { create_function_physical_name(func_def.name(), *distinct, args) } AggregateFunctionDefinition::UDF(fun) => { - // TODO: Add support for filter and order by in AggregateUDF + // TODO: Add support for filter by in AggregateUDF if filter.is_some() { return exec_err!( "aggregate expression with filter is not supported" diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 711f27443712..5f20a8bc0fc9 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -234,41 +234,40 @@ async fn simple_udaf_order() -> Result<()> { fn create_accumulator( data_type: &DataType, - order_by: Vec>, - schema: &Schema, + order_by: Vec, + schema: Option, ) -> Result> { + // test with ordering so schema is required + let schema = schema.unwrap(); + let mut all_sort_orders = vec![]; - assert_eq!(order_by.len(), 1); - - for exprs in order_by { - // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = vec![]; - for expr in exprs { - if let Expr::Sort(sort) = expr { - if let Expr::Column(col) = sort.expr.as_ref() { - let name = &col.name; - let e = expressions::col(name, schema)?; - sort_exprs.push(PhysicalSortExpr { - expr: e, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); - } + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in order_by { + if let Expr::Sort(sort) = expr { + if let Expr::Column(col) = sort.expr.as_ref() { + let name = &col.name; + let e = expressions::col(name, &schema)?; + sort_exprs.push(PhysicalSortExpr { + expr: e, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); } } - if !sort_exprs.is_empty() { - all_sort_orders.extend(sort_exprs); - } + } + if !sort_exprs.is_empty() { + all_sort_orders.extend(sort_exprs); } let ordering_req = all_sort_orders; let ordering_types = ordering_req .iter() - .map(|e| e.expr.data_type(schema)) + .map(|e| e.expr.data_type(&schema)) .collect::>>()?; let acc = FirstValueAccumulator::try_new( @@ -284,7 +283,7 @@ async fn simple_udaf_order() -> Result<()> { asc: false, nulls_first: false, }); - let order_by = vec![vec![order_by]]; + let order_by = vec![order_by]; // define a udaf, using a DataFusion's accumulator let my_first = create_udaf_with_ordering( diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index dae7269fa285..fb8dd9cdf234 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1029,7 +1029,7 @@ pub fn create_udaf_with_ordering( volatility: Volatility, accumulator: AccumulatorFactoryFunctionWithOrdering, state_type: Arc>, - ordering_req: Vec>, + ordering_req: Vec, schema: Schema, ) -> AggregateUDF { let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); @@ -1141,7 +1141,7 @@ pub struct SimpleOrderedAggregateUDF { return_type: DataType, accumulator: AccumulatorFactoryFunctionWithOrdering, state_type: Vec, - ordering_req: Vec>, + ordering_req: Vec, schema: Schema, } @@ -1166,7 +1166,7 @@ impl SimpleOrderedAggregateUDF { volatility: Volatility, accumulator: AccumulatorFactoryFunctionWithOrdering, state_type: Vec, - ordering_req: Vec>, + ordering_req: Vec, schema: Schema, ) -> Self { let name = name.into(); @@ -1201,7 +1201,7 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { } fn accumulator(&self, arg: &DataType) -> Result> { - (self.accumulator)(arg, self.ordering_req.clone(), &self.schema) + (self.accumulator)(arg, self.ordering_req.clone(), Some(self.schema.clone())) } fn state_type(&self, _return_type: &DataType) -> Result> { diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 1c00dcd665f5..f10a82ba6f12 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -48,7 +48,7 @@ pub type AccumulatorFactoryFunction = /// Factory that returns an accumulator for the given aggregate, given /// its return datatype, the ordering of the input arguments and the schema that are needed for ordering. pub type AccumulatorFactoryFunctionWithOrdering = Arc< - dyn Fn(&DataType, Vec>, &Schema) -> Result> + dyn Fn(&DataType, Vec, Option) -> Result> + Send + Sync, >; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index e56723063e41..9b883359b5e8 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -264,7 +264,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// If the aggregate expression has a specialized /// [`GroupsAccumulator`] implementation. If this returns true, - /// `[Self::create_groups_accumulator`] will be called. + /// `[Self::create_groups_accumulator]` will be called. fn groups_accumulator_supported(&self) -> bool { false } From c9e8641bb5b0136ea28ecf024a9bd24e7bd2eee9 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 21 Feb 2024 22:27:59 +0800 Subject: [PATCH 05/46] add sort exprs to accumulator Signed-off-by: jayzhan211 --- .../user_defined/user_defined_aggregates.rs | 22 ++++++------ datafusion/expr/src/expr_fn.rs | 30 ++++++++++++---- datafusion/expr/src/udaf.rs | 35 ++++++++++++++++--- datafusion/physical-plan/src/udaf.rs | 5 +-- 4 files changed, 67 insertions(+), 25 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 5f20a8bc0fc9..0f7416c0a419 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -278,13 +278,6 @@ async fn simple_udaf_order() -> Result<()> { Ok(Box::new(acc)) } - let order_by = Expr::Sort(Sort { - expr: Box::new(Expr::Column(Column::new(Some("t"), "a"))), - asc: false, - nulls_first: false, - }); - let order_by = vec![order_by]; - // define a udaf, using a DataFusion's accumulator let my_first = create_udaf_with_ordering( "my_first", @@ -293,8 +286,12 @@ async fn simple_udaf_order() -> Result<()> { Volatility::Immutable, Arc::new(create_accumulator), Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]), - order_by, - schema, + vec![Expr::Sort(Sort { + expr: Box::new(Expr::Column(Column::new(Some("t"), "a"))), + asc: false, + nulls_first: false, + })], + Some(schema), ); ctx.register_udaf(my_first); @@ -791,7 +788,12 @@ impl AggregateUDFImpl for TestGroupsAccumulator { Ok(DataType::UInt64) } - fn accumulator(&self, _arg: &DataType) -> Result> { + fn accumulator( + &self, + _arg: &DataType, + _sort_exprs: Vec, + _schmea: Option, + ) -> Result> { // should use groups accumulator panic!("accumulator shouldn't invoke"); } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index fb8dd9cdf234..debff818b35f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1030,7 +1030,7 @@ pub fn create_udaf_with_ordering( accumulator: AccumulatorFactoryFunctionWithOrdering, state_type: Arc>, ordering_req: Vec, - schema: Schema, + schema: Option, ) -> AggregateUDF { let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); @@ -1124,7 +1124,12 @@ impl AggregateUDFImpl for SimpleAggregateUDF { Ok(self.return_type.clone()) } - fn accumulator(&self, arg: &DataType) -> Result> { + fn accumulator( + &self, + arg: &DataType, + sort_exprs: Vec, + schema: Option, + ) -> Result> { (self.accumulator)(arg) } @@ -1142,7 +1147,7 @@ pub struct SimpleOrderedAggregateUDF { accumulator: AccumulatorFactoryFunctionWithOrdering, state_type: Vec, ordering_req: Vec, - schema: Schema, + schema: Option, } impl Debug for SimpleOrderedAggregateUDF { @@ -1167,7 +1172,7 @@ impl SimpleOrderedAggregateUDF { accumulator: AccumulatorFactoryFunctionWithOrdering, state_type: Vec, ordering_req: Vec, - schema: Schema, + schema: Option, ) -> Self { let name = name.into(); let signature = Signature::exact(input_type, volatility); @@ -1200,13 +1205,26 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { Ok(self.return_type.clone()) } - fn accumulator(&self, arg: &DataType) -> Result> { - (self.accumulator)(arg, self.ordering_req.clone(), Some(self.schema.clone())) + fn accumulator( + &self, + arg: &DataType, + sort_exprs: Vec, + schema: Option, + ) -> Result> { + (self.accumulator)(arg, sort_exprs, schema) } fn state_type(&self, _return_type: &DataType) -> Result> { Ok(self.state_type.clone()) } + + fn sort_exprs(&self) -> Vec { + self.ordering_req.clone() + } + + fn schema(&self) -> Option { + self.schema.clone() + } } /// Creates a new UDWF with a specific signature, state type and return type. diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 9b883359b5e8..8707c803001b 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -22,8 +22,8 @@ use crate::{Accumulator, Expr}; use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; -use arrow::datatypes::DataType; -use datafusion_common::{not_impl_err, Result}; +use arrow::datatypes::{DataType, Schema}; +use datafusion_common::{not_impl_err, DataFusionError, Result}; use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -155,8 +155,11 @@ impl AggregateUDF { /// Return an accumulator the given aggregate, given /// its return datatype. + // pub fn accumulator(&self, return_type: &DataType, sort_exprs: Vec, schema: Option) -> Result> { pub fn accumulator(&self, return_type: &DataType) -> Result> { - self.inner.accumulator(return_type) + let sort_exprs = self.inner.sort_exprs(); + let schema = self.inner.schema(); + self.inner.accumulator(return_type, sort_exprs, schema) } /// Return the type of the intermediate state used by this aggregator, given @@ -174,6 +177,10 @@ impl AggregateUDF { pub fn create_groups_accumulator(&self) -> Result> { self.inner.create_groups_accumulator() } + + pub fn sort_exprs() -> Vec { + vec![] + } } impl From for AggregateUDF @@ -256,7 +263,12 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Return a new [`Accumulator`] that aggregates values for a specific /// group during query execution. - fn accumulator(&self, arg: &DataType) -> Result>; + fn accumulator( + &self, + arg: &DataType, + sort_exprs: Vec, + schema: Option, + ) -> Result>; /// Return the type used to serialize the [`Accumulator`]'s intermediate state. /// See [`Accumulator::state()`] for more details @@ -277,6 +289,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn create_groups_accumulator(&self) -> Result> { not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") } + + fn sort_exprs(&self) -> Vec { + vec![] + } + + fn schema(&self) -> Option { + None + } } /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers @@ -323,7 +343,12 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { Ok(res.as_ref().clone()) } - fn accumulator(&self, arg: &DataType) -> Result> { + fn accumulator( + &self, + arg: &DataType, + _sort_exprs: Vec, + _schema: Option, + ) -> Result> { (self.accumulator)(arg) } diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index f30f95924a82..3049f6cf642a 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -22,10 +22,7 @@ use fmt::Debug; use std::any::Any; use std::fmt; -use arrow::{ - datatypes::Field, - datatypes::{DataType, Schema}, -}; +use arrow::datatypes::{DataType, Field, Schema}; use super::{expressions::format_state_name, Accumulator, AggregateExpr}; use datafusion_common::{not_impl_err, Result}; From 3a5f0d1fe7bc6b32a5eb27285fe16a8aedfed733 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 21 Feb 2024 22:52:04 +0800 Subject: [PATCH 06/46] clippy Signed-off-by: jayzhan211 --- datafusion-examples/examples/advanced_udaf.rs | 10 ++++++++-- datafusion-examples/examples/simple_udaf.rs | 2 +- .../tests/user_defined/user_defined_aggregates.rs | 8 ++++---- .../user_defined/user_defined_scalar_functions.rs | 2 +- datafusion/expr/src/expr_fn.rs | 12 +++++------- datafusion/expr/src/function.rs | 9 ++------- datafusion/expr/src/udaf.rs | 6 +++--- datafusion/optimizer/src/analyzer/type_coercion.rs | 4 ++-- datafusion/optimizer/src/common_subexpr_eliminate.rs | 3 ++- datafusion/proto/src/bytes/mod.rs | 2 +- .../proto/tests/cases/roundtrip_logical_plan.rs | 4 ++-- .../proto/tests/cases/roundtrip_physical_plan.rs | 3 ++- .../substrait/tests/cases/roundtrip_logical_plan.rs | 2 +- 13 files changed, 34 insertions(+), 33 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 10164a850bfb..9fed6efe70f1 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow_schema::Schema; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; use datafusion_physical_expr::NullState; use std::{any::Any, sync::Arc}; @@ -85,7 +86,12 @@ impl AggregateUDFImpl for GeoMeanUdaf { /// is supported, DataFusion will use this row oriented /// accumulator when the aggregate function is used as a window function /// or when there are only aggregates (no GROUP BY columns) in the plan. - fn accumulator(&self, _arg: &DataType) -> Result> { + fn accumulator( + &self, + _arg: &DataType, + _sort_exprs: Vec, + _schema: Option, + ) -> Result> { Ok(Box::new(GeometricMean::new())) } @@ -191,7 +197,7 @@ impl Accumulator for GeometricMean { // create local session context with an in-memory table fn create_context() -> Result { - use datafusion::arrow::datatypes::{Field, Schema}; + use datafusion::arrow::datatypes::Field; use datafusion::datasource::MemTable; // define a schema. let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 0996a67245a8..3a62e28b0568 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -150,7 +150,7 @@ async fn main() -> Result<()> { Arc::new(DataType::Float64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(GeometricMean::new()))), + Arc::new(|_, _, _| Ok(Box::new(GeometricMean::new()))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 0f7416c0a419..720eef42831c 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -341,7 +341,7 @@ async fn simple_udaf() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); @@ -397,7 +397,7 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); @@ -568,7 +568,7 @@ impl TimeSum { let captured_state = Arc::clone(&test_state); let accumulator: AccumulatorFactoryFunction = - Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); + Arc::new(move |_, _, _| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); let time_sum = AggregateUDF::from(SimpleAggregateUDF::new( name, @@ -667,7 +667,7 @@ impl FirstSelector { let signatures = vec![TypeSignature::Exact(Self::input_datatypes())]; let accumulator: AccumulatorFactoryFunction = - Arc::new(|_| Ok(Box::new(Self::new()))); + Arc::new(|_, _, _| Ok(Box::new(Self::new()))); let volatility = Volatility::Immutable; diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index d9b60134b3d9..4229220e697f 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -294,7 +294,7 @@ async fn udaf_as_window_func() -> Result<()> { vec![DataType::Int32], Arc::new(DataType::Int32), Volatility::Immutable, - Arc::new(|_| Ok(Box::new(MyAccumulator))), + Arc::new(|_, _, _| Ok(Box::new(MyAccumulator))), Arc::new(vec![DataType::Int32]), ); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index debff818b35f..c77c1aa48375 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -21,9 +21,7 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Placeholder, ScalarFunction, TryCast, }; -use crate::function::{ - AccumulatorFactoryFunctionWithOrdering, PartitionEvaluatorFactory, -}; +use crate::function::PartitionEvaluatorFactory; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, @@ -1027,7 +1025,7 @@ pub fn create_udaf_with_ordering( input_type: Vec, return_type: Arc, volatility: Volatility, - accumulator: AccumulatorFactoryFunctionWithOrdering, + accumulator: AccumulatorFactoryFunction, state_type: Arc>, ordering_req: Vec, schema: Option, @@ -1130,7 +1128,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { sort_exprs: Vec, schema: Option, ) -> Result> { - (self.accumulator)(arg) + (self.accumulator)(arg, sort_exprs, schema) } fn state_type(&self, _return_type: &DataType) -> Result> { @@ -1144,7 +1142,7 @@ pub struct SimpleOrderedAggregateUDF { name: String, signature: Signature, return_type: DataType, - accumulator: AccumulatorFactoryFunctionWithOrdering, + accumulator: AccumulatorFactoryFunction, state_type: Vec, ordering_req: Vec, schema: Option, @@ -1169,7 +1167,7 @@ impl SimpleOrderedAggregateUDF { input_type: Vec, return_type: DataType, volatility: Volatility, - accumulator: AccumulatorFactoryFunctionWithOrdering, + accumulator: AccumulatorFactoryFunction, state_type: Vec, ordering_req: Vec, schema: Option, diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index f10a82ba6f12..c1e162286cb6 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -41,13 +41,8 @@ pub type ReturnTypeFunction = Arc Result> + Send + Sync>; /// Factory that returns an accumulator for the given aggregate, given -/// its return datatype. -pub type AccumulatorFactoryFunction = - Arc Result> + Send + Sync>; - -/// Factory that returns an accumulator for the given aggregate, given -/// its return datatype, the ordering of the input arguments and the schema that are needed for ordering. -pub type AccumulatorFactoryFunctionWithOrdering = Arc< +/// its return datatype, the sorting expressions and the schema for ordering. +pub type AccumulatorFactoryFunction = Arc< dyn Fn(&DataType, Vec, Option) -> Result> + Send + Sync, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 8707c803001b..534e45e13521 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -346,10 +346,10 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { fn accumulator( &self, arg: &DataType, - _sort_exprs: Vec, - _schema: Option, + sort_exprs: Vec, + schema: Option, ) -> Result> { - (self.accumulator)(arg) + (self.accumulator)(arg, sort_exprs, schema) } fn state_type(&self, return_type: &DataType) -> Result> { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 496def95e1bc..7281425a5721 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -895,7 +895,7 @@ mod test { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( @@ -916,7 +916,7 @@ mod test { let return_type = DataType::Float64; let state_type = vec![DataType::UInt64, DataType::Float64]; let accumulator: AccumulatorFactoryFunction = - Arc::new(|_| Ok(Box::::default())); + Arc::new(|_, _, _| Ok(Box::::default())); let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( "MY_AVG", Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 30c184a28e33..487e06674e7a 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -975,7 +975,8 @@ mod test { let table_scan = test_table_scan()?; let return_type = DataType::UInt32; - let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); + let accumulator: AccumulatorFactoryFunction = + Arc::new(|_, _, _| unimplemented!()); let state_type = vec![DataType::UInt32]; let udf_agg = |inner: Expr| { Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 610c533d574c..4c570d343574 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -127,7 +127,7 @@ impl Serializeable for Expr { vec![arrow::datatypes::DataType::Null], Arc::new(arrow::datatypes::DataType::Null), Volatility::Immutable, - Arc::new(|_| unimplemented!()), + Arc::new(|_, _, _| unimplemented!()), Arc::new(vec![]), ))) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index fad50d3ecddc..42e8718ff097 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1738,7 +1738,7 @@ fn roundtrip_aggregate_udf() { Arc::new(DataType::Float64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(Dummy {}))), + Arc::new(|_, _, _| Ok(Box::new(Dummy {}))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); @@ -1953,7 +1953,7 @@ fn roundtrip_window() { Arc::new(DataType::Float64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(DummyAggr {}))), + Arc::new(|_, _, _| Ok(Box::new(DummyAggr {}))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 243e3777db9d..c7ff0cf782d0 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -405,7 +405,8 @@ fn roundtrip_aggregate_udaf() -> Result<()> { } let return_type = DataType::Int64; - let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::new(Example))); + let accumulator: AccumulatorFactoryFunction = + Arc::new(|_, _, _| Ok(Box::new(Example))); let state_type = vec![DataType::Int64]; let udaf = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index bc9cc66b7626..a24a84ad76be 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -750,7 +750,7 @@ async fn roundtrip_aggregate_udf() -> Result<()> { Arc::new(DataType::Int64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(Dummy {}))), + Arc::new(|_, _, _| Ok(Box::new(Dummy {}))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); From a3ea00adf0595d6862ee6ce76fefd4564a3ff3b2 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 21 Feb 2024 22:55:08 +0800 Subject: [PATCH 07/46] cleanup Signed-off-by: jayzhan211 --- datafusion/expr/src/udaf.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 534e45e13521..1dc1cb12707c 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -153,9 +153,7 @@ impl AggregateUDF { self.inner.return_type(args) } - /// Return an accumulator the given aggregate, given - /// its return datatype. - // pub fn accumulator(&self, return_type: &DataType, sort_exprs: Vec, schema: Option) -> Result> { + /// Return an accumulator the given aggregate, given its return datatype pub fn accumulator(&self, return_type: &DataType) -> Result> { let sort_exprs = self.inner.sort_exprs(); let schema = self.inner.schema(); From f349f215a1e10b129c7a910378d91c32084168ab Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 21 Feb 2024 23:10:43 +0800 Subject: [PATCH 08/46] fix doc test Signed-off-by: jayzhan211 --- datafusion/expr/src/udaf.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 1dc1cb12707c..3409df62a0c6 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -205,8 +205,9 @@ where /// # use std::any::Any; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator}; +/// # use arrow::datatypes::Schema; /// #[derive(Debug, Clone)] /// struct GeoMeanUdf { /// signature: Signature @@ -232,7 +233,7 @@ where /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. -/// fn accumulator(&self, _arg: &DataType) -> Result> { unimplemented!() } +/// fn accumulator(&self, _arg: &DataType, _sort_exprs: Vec, _schema: Option) -> Result> { unimplemented!() } /// fn state_type(&self, _return_type: &DataType) -> Result> { /// Ok(vec![DataType::Float64, DataType::UInt32]) /// } From 6fcdaac8e484cef16df5d905c0178e17fa35c840 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 27 Feb 2024 22:46:24 +0800 Subject: [PATCH 09/46] change to ref Signed-off-by: jayzhan211 --- datafusion-examples/examples/advanced_udaf.rs | 2 +- .../tests/user_defined/user_defined_aggregates.rs | 10 +++++----- datafusion/expr/src/expr_fn.rs | 14 +++++++------- datafusion/expr/src/function.rs | 2 +- datafusion/expr/src/udaf.rs | 11 +++++++---- 5 files changed, 21 insertions(+), 18 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 9fed6efe70f1..ef43dc47f2ce 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -90,7 +90,7 @@ impl AggregateUDFImpl for GeoMeanUdaf { &self, _arg: &DataType, _sort_exprs: Vec, - _schema: Option, + _schema: Option<&Schema>, ) -> Result> { Ok(Box::new(GeometricMean::new())) } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 720eef42831c..1f637c0d97a7 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -235,7 +235,7 @@ async fn simple_udaf_order() -> Result<()> { fn create_accumulator( data_type: &DataType, order_by: Vec, - schema: Option, + schema: Option<&Schema>, ) -> Result> { // test with ordering so schema is required let schema = schema.unwrap(); @@ -248,7 +248,7 @@ async fn simple_udaf_order() -> Result<()> { if let Expr::Sort(sort) = expr { if let Expr::Column(col) = sort.expr.as_ref() { let name = &col.name; - let e = expressions::col(name, &schema)?; + let e = expressions::col(name, schema)?; sort_exprs.push(PhysicalSortExpr { expr: e, options: SortOptions { @@ -267,7 +267,7 @@ async fn simple_udaf_order() -> Result<()> { let ordering_types = ordering_req .iter() - .map(|e| e.expr.data_type(&schema)) + .map(|e| e.expr.data_type(schema)) .collect::>>()?; let acc = FirstValueAccumulator::try_new( @@ -291,7 +291,7 @@ async fn simple_udaf_order() -> Result<()> { asc: false, nulls_first: false, })], - Some(schema), + Some(&schema), ); ctx.register_udaf(my_first); @@ -792,7 +792,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator { &self, _arg: &DataType, _sort_exprs: Vec, - _schmea: Option, + _schmea: Option<&Schema>, ) -> Result> { // should use groups accumulator panic!("accumulator shouldn't invoke"); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c77c1aa48375..62856873c833 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1028,7 +1028,7 @@ pub fn create_udaf_with_ordering( accumulator: AccumulatorFactoryFunction, state_type: Arc>, ordering_req: Vec, - schema: Option, + schema: Option<&Schema>, ) -> AggregateUDF { let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); @@ -1126,7 +1126,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { &self, arg: &DataType, sort_exprs: Vec, - schema: Option, + schema: Option<&Schema>, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } @@ -1170,7 +1170,7 @@ impl SimpleOrderedAggregateUDF { accumulator: AccumulatorFactoryFunction, state_type: Vec, ordering_req: Vec, - schema: Option, + schema: Option<&Schema>, ) -> Self { let name = name.into(); let signature = Signature::exact(input_type, volatility); @@ -1181,7 +1181,7 @@ impl SimpleOrderedAggregateUDF { accumulator, state_type, ordering_req, - schema, + schema: schema.cloned(), } } } @@ -1207,7 +1207,7 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { &self, arg: &DataType, sort_exprs: Vec, - schema: Option, + schema: Option<&Schema>, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } @@ -1220,8 +1220,8 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { self.ordering_req.clone() } - fn schema(&self) -> Option { - self.schema.clone() + fn schema(&self) -> Option<&Schema> { + self.schema.as_ref() } } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index c1e162286cb6..bbb631af8fd5 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -43,7 +43,7 @@ pub type ReturnTypeFunction = /// Factory that returns an accumulator for the given aggregate, given /// its return datatype, the sorting expressions and the schema for ordering. pub type AccumulatorFactoryFunction = Arc< - dyn Fn(&DataType, Vec, Option) -> Result> + dyn Fn(&DataType, Vec, Option<&Schema>) -> Result> + Send + Sync, >; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 3409df62a0c6..0e4785b25673 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -261,12 +261,13 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn return_type(&self, arg_types: &[DataType]) -> Result; /// Return a new [`Accumulator`] that aggregates values for a specific - /// group during query execution. + /// group during query execution. sort_exprs is a list of ordering expressions, + /// and schema is used while ordering. fn accumulator( &self, arg: &DataType, sort_exprs: Vec, - schema: Option, + schema: Option<&Schema>, ) -> Result>; /// Return the type used to serialize the [`Accumulator`]'s intermediate state. @@ -289,11 +290,13 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") } + /// Return the ordering expressions for the accumulator fn sort_exprs(&self) -> Vec { vec![] } - fn schema(&self) -> Option { + /// Return the schema for the accumulator + fn schema(&self) -> Option<&Schema> { None } } @@ -346,7 +349,7 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { &self, arg: &DataType, sort_exprs: Vec, - schema: Option, + schema: Option<&Schema>, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } From c3512a6c150914f288f7dc19b4dea3274f9cc295 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 27 Feb 2024 22:47:09 +0800 Subject: [PATCH 10/46] fix typo Signed-off-by: jayzhan211 --- datafusion/core/tests/user_defined/user_defined_aggregates.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 1f637c0d97a7..41c19c3ebb7e 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -792,7 +792,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator { &self, _arg: &DataType, _sort_exprs: Vec, - _schmea: Option<&Schema>, + _schema: Option<&Schema>, ) -> Result> { // should use groups accumulator panic!("accumulator shouldn't invoke"); From 092d46e50666a422748dbf4755fc8ef260ba3e7f Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 27 Feb 2024 23:03:35 +0800 Subject: [PATCH 11/46] fix doc Signed-off-by: jayzhan211 --- datafusion/expr/src/udaf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 0e4785b25673..2a55904e01b7 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -233,7 +233,7 @@ where /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. -/// fn accumulator(&self, _arg: &DataType, _sort_exprs: Vec, _schema: Option) -> Result> { unimplemented!() } +/// fn accumulator(&self, _arg: &DataType, _sort_exprs: Vec, _schema: Option<&Schema>) -> Result> { unimplemented!() } /// fn state_type(&self, _return_type: &DataType) -> Result> { /// Ok(vec![DataType::Float64, DataType::UInt32]) /// } From 8592e6bccc64c45a7c120b94c02b1665bd81b170 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 1 Mar 2024 19:26:06 +0800 Subject: [PATCH 12/46] fmt Signed-off-by: jayzhan211 --- datafusion/expr/src/udaf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 2a55904e01b7..9b038b7c38a9 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -23,7 +23,7 @@ use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_common::{not_impl_err, Result}; use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; From 0f8fc2427abc405feff8b2c8781a3ad1fb043307 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 1 Mar 2024 20:34:23 +0800 Subject: [PATCH 13/46] move schema and logical ordering exprs Signed-off-by: jayzhan211 --- datafusion-examples/examples/advanced_udaf.rs | 4 +-- datafusion/core/src/physical_planner.rs | 16 +++++++----- .../user_defined/user_defined_aggregates.rs | 26 +++++++------------ datafusion/expr/src/expr_fn.rs | 8 +++--- datafusion/expr/src/function.rs | 4 +-- datafusion/expr/src/udaf.rs | 19 +++++++++----- datafusion/physical-plan/src/udaf.rs | 22 ++++++++++++---- datafusion/physical-plan/src/windows/mod.rs | 3 +++ datafusion/proto/src/physical_plan/mod.rs | 2 +- .../tests/cases/roundtrip_physical_plan.rs | 1 + 10 files changed, 60 insertions(+), 45 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index ef43dc47f2ce..06995863a245 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -89,8 +89,8 @@ impl AggregateUDFImpl for GeoMeanUdaf { fn accumulator( &self, _arg: &DataType, - _sort_exprs: Vec, - _schema: Option<&Schema>, + _sort_exprs: &[Expr], + _schema: &Schema, ) -> Result> { Ok(Box::new(GeometricMean::new())) } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 528f746ca9e6..4d1cf1b28d8f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1672,7 +1672,9 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( )?), None => None, }; - let order_by = match order_by { + + let sort_exprs = order_by.clone().unwrap_or(vec![]); + let phy_order_by = match order_by { Some(e) => Some( e.iter() .map(|expr| { @@ -1691,7 +1693,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { - let ordering_reqs = order_by.clone().unwrap_or(vec![]); + let ordering_reqs = phy_order_by.clone().unwrap_or(vec![]); let agg_expr = aggregates::create_aggregate_expr( fun, *distinct, @@ -1701,19 +1703,21 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( name, ignore_nulls, )?; - (agg_expr, filter, order_by) + (agg_expr, filter, phy_order_by) } AggregateFunctionDefinition::UDF(fun) => { let ordering_reqs: Vec = - order_by.clone().unwrap_or(vec![]); + phy_order_by.clone().unwrap_or(vec![]); + let agg_expr = udaf::create_aggregate_expr( fun, &args, + &sort_exprs, &ordering_reqs, physical_input_schema, name, )?; - (agg_expr, filter, order_by) + (agg_expr, filter, phy_order_by) } AggregateFunctionDefinition::Name(_) => { return internal_err!( @@ -1721,7 +1725,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ) } }; - Ok((agg_expr, filter, order_by)) + Ok((agg_expr, filter, phy_order_by)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 41c19c3ebb7e..85441838c19d 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -42,9 +42,7 @@ use datafusion::{ prelude::SessionContext, scalar::ScalarValue, }; -use datafusion_common::{ - assert_contains, cast::as_primitive_array, exec_err, Column, DataFusionError, -}; +use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err, Column}; use datafusion_expr::{ create_udaf, create_udaf_with_ordering, expr::Sort, AggregateUDFImpl, Expr, GroupsAccumulator, SimpleAggregateUDF, @@ -234,12 +232,9 @@ async fn simple_udaf_order() -> Result<()> { fn create_accumulator( data_type: &DataType, - order_by: Vec, - schema: Option<&Schema>, + order_by: &[Expr], + schema: &Schema, ) -> Result> { - // test with ordering so schema is required - let schema = schema.unwrap(); - let mut all_sort_orders = vec![]; // Construct PhysicalSortExpr objects from Expr objects: @@ -265,16 +260,13 @@ async fn simple_udaf_order() -> Result<()> { let ordering_req = all_sort_orders; - let ordering_types = ordering_req + let ordering_dtypes = ordering_req .iter() .map(|e| e.expr.data_type(schema)) .collect::>>()?; - let acc = FirstValueAccumulator::try_new( - data_type, - ordering_types.as_slice(), - ordering_req, - )?; + let acc = + FirstValueAccumulator::try_new(data_type, &ordering_dtypes, ordering_req)?; Ok(Box::new(acc)) } @@ -369,7 +361,7 @@ async fn deregister_udaf() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); @@ -791,8 +783,8 @@ impl AggregateUDFImpl for TestGroupsAccumulator { fn accumulator( &self, _arg: &DataType, - _sort_exprs: Vec, - _schema: Option<&Schema>, + _sort_exprs: &[Expr], + _schema: &Schema, ) -> Result> { // should use groups accumulator panic!("accumulator shouldn't invoke"); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 62856873c833..3ba4808b9694 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1125,8 +1125,8 @@ impl AggregateUDFImpl for SimpleAggregateUDF { fn accumulator( &self, arg: &DataType, - sort_exprs: Vec, - schema: Option<&Schema>, + sort_exprs: &[Expr], + schema: &Schema, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } @@ -1206,8 +1206,8 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { fn accumulator( &self, arg: &DataType, - sort_exprs: Vec, - schema: Option<&Schema>, + sort_exprs: &[Expr], + schema: &Schema, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index bbb631af8fd5..7a92724227e9 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -43,9 +43,7 @@ pub type ReturnTypeFunction = /// Factory that returns an accumulator for the given aggregate, given /// its return datatype, the sorting expressions and the schema for ordering. pub type AccumulatorFactoryFunction = Arc< - dyn Fn(&DataType, Vec, Option<&Schema>) -> Result> - + Send - + Sync, + dyn Fn(&DataType, &[Expr], &Schema) -> Result> + Send + Sync, >; /// Factory that creates a PartitionEvaluator for the given window diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 9b038b7c38a9..121cd6834306 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -154,9 +154,14 @@ impl AggregateUDF { } /// Return an accumulator the given aggregate, given its return datatype - pub fn accumulator(&self, return_type: &DataType) -> Result> { - let sort_exprs = self.inner.sort_exprs(); - let schema = self.inner.schema(); + pub fn accumulator( + &self, + return_type: &DataType, + sort_exprs: &[Expr], + schema: &Schema, + ) -> Result> { + // let sort_exprs = self.inner.sort_exprs(); + // let schema = self.inner.schema(); self.inner.accumulator(return_type, sort_exprs, schema) } @@ -266,8 +271,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn accumulator( &self, arg: &DataType, - sort_exprs: Vec, - schema: Option<&Schema>, + sort_exprs: &[Expr], + schema: &Schema, ) -> Result>; /// Return the type used to serialize the [`Accumulator`]'s intermediate state. @@ -348,8 +353,8 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { fn accumulator( &self, arg: &DataType, - sort_exprs: Vec, - schema: Option<&Schema>, + sort_exprs: &[Expr], + schema: &Schema, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 3049f6cf642a..b74a2d971d36 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -17,7 +17,7 @@ //! This module contains functions and structs supporting user-defined aggregate functions. -use datafusion_expr::GroupsAccumulator; +use datafusion_expr::{Expr, GroupsAccumulator}; use fmt::Debug; use std::any::Any; use std::fmt; @@ -37,13 +37,14 @@ use std::sync::Arc; pub fn create_aggregate_expr( fun: &AggregateUDF, input_phy_exprs: &[Arc], + sort_exprs: &[Expr], ordering_req: &[PhysicalSortExpr], - input_schema: &Schema, + schema: &Schema, name: impl Into, ) -> Result> { let input_exprs_types = input_phy_exprs .iter() - .map(|arg| arg.data_type(input_schema)) + .map(|arg| arg.data_type(schema)) .collect::>>()?; Ok(Arc::new(AggregateFunctionExpr { @@ -51,6 +52,8 @@ pub fn create_aggregate_expr( args: input_phy_exprs.to_vec(), data_type: fun.return_type(&input_exprs_types)?, name: name.into(), + schema: schema.clone(), + sort_exprs: sort_exprs.to_vec(), ordering_req: ordering_req.to_vec(), })) } @@ -63,6 +66,10 @@ pub struct AggregateFunctionExpr { /// Output / return type of this aggregate data_type: DataType, name: String, + schema: Schema, + // The logical order by expressions + sort_exprs: Vec, + // The physical order by expressions ordering_req: LexOrdering, } @@ -106,11 +113,16 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_accumulator(&self) -> Result> { - self.fun.accumulator(&self.data_type) + self.fun + .accumulator(&self.data_type, self.sort_exprs.as_slice(), &self.schema) } fn create_sliding_accumulator(&self) -> Result> { - let accumulator = self.fun.accumulator(&self.data_type)?; + let accumulator = self.fun.accumulator( + &self.data_type, + self.sort_exprs.as_slice(), + &self.schema, + )?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 5143ccc49944..3082bec9134d 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -93,10 +93,13 @@ pub fn create_window_expr( } WindowFunctionDefinition::AggregateUDF(fun) => { // TODO: Ordering not supported for Window UDFs yet + let sort_exprs = &[]; let ordering_req = &[]; + let aggregate = udaf::create_aggregate_expr( fun.as_ref(), args, + sort_exprs, ordering_req, input_schema, name, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 43fb7e9e87d0..b8da5ea7a092 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -474,7 +474,7 @@ impl AsExecutionPlan for PhysicalPlanNode { AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; let ordering_req = &[]; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, ordering_req, &physical_schema, name) + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &[], ordering_req, &physical_schema, name) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index c7ff0cf782d0..a51da2ead738 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -427,6 +427,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { &udaf, &[col("b", &schema)?], &[], + &[], &schema, "example_agg", )?]; From 3185f9f77280c8a3c2fc2d65016b0a7fce045f63 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 1 Mar 2024 20:40:59 +0800 Subject: [PATCH 14/46] remove redudant info Signed-off-by: jayzhan211 --- .../user_defined/user_defined_aggregates.rs | 12 +++--------- datafusion/expr/src/expr_fn.rs | 18 ------------------ datafusion/expr/src/udaf.rs | 10 ---------- 3 files changed, 3 insertions(+), 37 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 85441838c19d..e9583fff35c3 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -42,10 +42,10 @@ use datafusion::{ prelude::SessionContext, scalar::ScalarValue, }; -use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err, Column}; +use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; use datafusion_expr::{ - create_udaf, create_udaf_with_ordering, expr::Sort, AggregateUDFImpl, Expr, - GroupsAccumulator, SimpleAggregateUDF, + create_udaf, create_udaf_with_ordering, AggregateUDFImpl, Expr, GroupsAccumulator, + SimpleAggregateUDF, }; use datafusion_physical_expr::expressions::{self, FirstValueAccumulator}; use datafusion_physical_expr::{expressions::AvgAccumulator, PhysicalSortExpr}; @@ -278,12 +278,6 @@ async fn simple_udaf_order() -> Result<()> { Volatility::Immutable, Arc::new(create_accumulator), Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]), - vec![Expr::Sort(Sort { - expr: Box::new(Expr::Column(Column::new(Some("t"), "a"))), - asc: false, - nulls_first: false, - })], - Some(&schema), ); ctx.register_udaf(my_first); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 3ba4808b9694..1a84a49d0949 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1027,8 +1027,6 @@ pub fn create_udaf_with_ordering( volatility: Volatility, accumulator: AccumulatorFactoryFunction, state_type: Arc>, - ordering_req: Vec, - schema: Option<&Schema>, ) -> AggregateUDF { let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); @@ -1040,8 +1038,6 @@ pub fn create_udaf_with_ordering( volatility, accumulator, state_type, - ordering_req, - schema, )) } @@ -1144,8 +1140,6 @@ pub struct SimpleOrderedAggregateUDF { return_type: DataType, accumulator: AccumulatorFactoryFunction, state_type: Vec, - ordering_req: Vec, - schema: Option, } impl Debug for SimpleOrderedAggregateUDF { @@ -1169,8 +1163,6 @@ impl SimpleOrderedAggregateUDF { volatility: Volatility, accumulator: AccumulatorFactoryFunction, state_type: Vec, - ordering_req: Vec, - schema: Option<&Schema>, ) -> Self { let name = name.into(); let signature = Signature::exact(input_type, volatility); @@ -1180,8 +1172,6 @@ impl SimpleOrderedAggregateUDF { return_type, accumulator, state_type, - ordering_req, - schema: schema.cloned(), } } } @@ -1215,14 +1205,6 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { fn state_type(&self, _return_type: &DataType) -> Result> { Ok(self.state_type.clone()) } - - fn sort_exprs(&self) -> Vec { - self.ordering_req.clone() - } - - fn schema(&self) -> Option<&Schema> { - self.schema.as_ref() - } } /// Creates a new UDWF with a specific signature, state type and return type. diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 121cd6834306..b5caf860163d 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -294,16 +294,6 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn create_groups_accumulator(&self) -> Result> { not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") } - - /// Return the ordering expressions for the accumulator - fn sort_exprs(&self) -> Vec { - vec![] - } - - /// Return the schema for the accumulator - fn schema(&self) -> Option<&Schema> { - None - } } /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers From 3ecc772c2a99b185741868d562257a2e84d15dcb Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 1 Mar 2024 20:44:24 +0800 Subject: [PATCH 15/46] rename Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 4d1cf1b28d8f..e065a0d791e4 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -246,6 +246,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { args, filter, order_by, + null_treatment: _, }) => match func_def { AggregateFunctionDefinition::BuiltIn(..) => { create_function_physical_name(func_def.name(), *distinct, args) @@ -1674,7 +1675,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( }; let sort_exprs = order_by.clone().unwrap_or(vec![]); - let phy_order_by = match order_by { + let order_by = match order_by { Some(e) => Some( e.iter() .map(|expr| { @@ -1693,7 +1694,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { - let ordering_reqs = phy_order_by.clone().unwrap_or(vec![]); + let ordering_reqs = order_by.clone().unwrap_or(vec![]); let agg_expr = aggregates::create_aggregate_expr( fun, *distinct, @@ -1703,11 +1704,11 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( name, ignore_nulls, )?; - (agg_expr, filter, phy_order_by) + (agg_expr, filter, order_by) } AggregateFunctionDefinition::UDF(fun) => { let ordering_reqs: Vec = - phy_order_by.clone().unwrap_or(vec![]); + order_by.clone().unwrap_or(vec![]); let agg_expr = udaf::create_aggregate_expr( fun, @@ -1717,7 +1718,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( physical_input_schema, name, )?; - (agg_expr, filter, phy_order_by) + (agg_expr, filter, order_by) } AggregateFunctionDefinition::Name(_) => { return internal_err!( @@ -1725,7 +1726,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ) } }; - Ok((agg_expr, filter, phy_order_by)) + Ok((agg_expr, filter, order_by)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), } From faadc63ce4607090cdc7b3cfa8860bc18fc28ac1 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 1 Mar 2024 20:55:15 +0800 Subject: [PATCH 16/46] cleanup Signed-off-by: jayzhan211 --- datafusion/expr/src/expr_fn.rs | 2 -- datafusion/expr/src/udaf.rs | 19 ++++++++++--------- datafusion/physical-plan/src/udaf.rs | 8 +++----- datafusion/proto/src/physical_plan/mod.rs | 4 +++- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1a84a49d0949..4808c7197cfa 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1019,7 +1019,6 @@ pub fn create_udaf( /// Creates a new UDAF with a specific signature, state type and return type. /// The signature and state type must match the `Accumulator's implementation`. -#[allow(clippy::too_many_arguments)] pub fn create_udaf_with_ordering( name: &str, input_type: Vec, @@ -1155,7 +1154,6 @@ impl Debug for SimpleOrderedAggregateUDF { impl SimpleOrderedAggregateUDF { /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility - #[allow(clippy::too_many_arguments)] pub fn new( name: impl Into, input_type: Vec, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b5caf860163d..8119a45e9f0c 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -160,8 +160,6 @@ impl AggregateUDF { sort_exprs: &[Expr], schema: &Schema, ) -> Result> { - // let sort_exprs = self.inner.sort_exprs(); - // let schema = self.inner.schema(); self.inner.accumulator(return_type, sort_exprs, schema) } @@ -180,10 +178,6 @@ impl AggregateUDF { pub fn create_groups_accumulator(&self) -> Result> { self.inner.create_groups_accumulator() } - - pub fn sort_exprs() -> Vec { - vec![] - } } impl From for AggregateUDF @@ -238,7 +232,7 @@ where /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. -/// fn accumulator(&self, _arg: &DataType, _sort_exprs: Vec, _schema: Option<&Schema>) -> Result> { unimplemented!() } +/// fn accumulator(&self, _arg: &DataType, _sort_exprs: &[Expr], _schema: &Schema) -> Result> { unimplemented!() } /// fn state_type(&self, _return_type: &DataType) -> Result> { /// Ok(vec![DataType::Float64, DataType::UInt32]) /// } @@ -266,8 +260,15 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn return_type(&self, arg_types: &[DataType]) -> Result; /// Return a new [`Accumulator`] that aggregates values for a specific - /// group during query execution. sort_exprs is a list of ordering expressions, - /// and schema is used while ordering. + /// group during query execution. + /// + /// `arg`: the type of the argument to this accumulator + /// + /// `sort_exprs`: contains a list of `Expr::SortExpr`s if the + /// aggregate is called with an explicit `ORDER BY`. For example, + /// `ARRAY_AGG(x ORDER BY y ASC)`. In this case, `sort_exprs` would contain `[y ASC]` + /// + /// `schema` is the input schema to the udaf fn accumulator( &self, arg: &DataType, diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index b74a2d971d36..3d15e3d012c5 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -118,11 +118,9 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_sliding_accumulator(&self) -> Result> { - let accumulator = self.fun.accumulator( - &self.data_type, - self.sort_exprs.as_slice(), - &self.schema, - )?; + let accumulator = + self.fun + .accumulator(&self.data_type, &self.sort_exprs, &self.schema)?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index b8da5ea7a092..4c098a10737d 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -473,8 +473,10 @@ impl AsExecutionPlan for PhysicalPlanNode { } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; + // TODO: `order by` is not supported for UDAF yet + let sort_exprs = &[]; let ordering_req = &[]; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &[], ordering_req, &physical_schema, name) + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name) } } }).transpose()?.ok_or_else(|| { From 7e339101f76acb2fa19ee4d689ad9535b43e9e45 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 8 Mar 2024 06:22:54 +0800 Subject: [PATCH 17/46] add ignore nulls Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 2 +- .../core/tests/user_defined/user_defined_aggregates.rs | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index e065a0d791e4..cf56e09cb32e 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -245,7 +245,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { distinct, args, filter, - order_by, + order_by: _, null_treatment: _, }) => match func_def { AggregateFunctionDefinition::BuiltIn(..) => { diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index e9583fff35c3..523fb697af9f 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -265,8 +265,12 @@ async fn simple_udaf_order() -> Result<()> { .map(|e| e.expr.data_type(schema)) .collect::>>()?; - let acc = - FirstValueAccumulator::try_new(data_type, &ordering_dtypes, ordering_req)?; + let acc = FirstValueAccumulator::try_new( + data_type, + &ordering_dtypes, + ordering_req, + false, + )?; Ok(Box::new(acc)) } From 6aaa15ca70f14f088d548e28eb762ccf87eee254 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 25 Mar 2024 20:18:10 +0800 Subject: [PATCH 18/46] fix conflict Signed-off-by: jayzhan211 --- .../core/tests/user_defined/user_defined_aggregates.rs | 2 +- datafusion/expr/src/function.rs | 5 ++--- datafusion/expr/src/udaf.rs | 9 +++++++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 3b4c4a69f3b8..dfe17430147b 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -431,7 +431,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ) .with_aliases(vec!["dummy_alias"]); diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 2cf89a4fd39c..f08411550124 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,10 +17,9 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::{Accumulator, BuiltinScalarFunction, Expr, PartitionEvaluator, Signature}; -use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue}; +use crate::ColumnarValue; +use crate::{Accumulator, Expr, PartitionEvaluator}; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::utils::datafusion_strsim; use datafusion_common::Result; use std::sync::Arc; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index d54152485890..d232a2b09c7e 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -355,8 +355,13 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.inner.return_type(arg_types) } - fn accumulator(&self, arg: &DataType) -> Result> { - self.inner.accumulator(arg) + fn accumulator( + &self, + arg: &DataType, + sort_exprs: &[Expr], + schema: &Schema, + ) -> Result> { + self.inner.accumulator(arg, sort_exprs, schema) } fn state_type(&self, return_type: &DataType) -> Result> { From b74b7d2e8fc19efe61e504c3c5d8fac6b67b95f1 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 26 Mar 2024 09:02:24 +0800 Subject: [PATCH 19/46] backup Signed-off-by: jayzhan211 --- .../physical-expr/src/aggregate/first_last.rs | 41 ++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 264f52f93a3a..20c3fc359eb9 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -33,7 +33,7 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateUDFImpl}; /// FIRST_VALUE aggregate expression #[derive(Debug, Clone)] @@ -47,6 +47,45 @@ pub struct FirstValue { ignore_nulls: bool, } +impl AggregateUDFImpl for FirstValue { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name.as_str() + } + + fn signature(&self) -> &datafusion_expr::Signature { + todo!() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + todo!() + } + + fn accumulator( + &self, + _arg: &DataType, + _sort_exprs: &[datafusion_expr::Expr], + _schema: &arrow_schema::Schema, + ) -> Result> { + FirstValueAccumulator::try_new( + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + self.ignore_nulls, + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) + } + + fn state_type(&self, return_type: &DataType) -> Result> { + todo!() + } +} + impl FirstValue { /// Creates a new FIRST_VALUE aggregation function. pub fn new( From 263e6cb484c5a15c8506c2a59e93106a3c042a3b Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 26 Mar 2024 09:12:29 +0800 Subject: [PATCH 20/46] complete return_type Signed-off-by: jayzhan211 --- datafusion/physical-expr/src/aggregate/first_last.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 20c3fc359eb9..2c713d9bcca8 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -61,7 +61,7 @@ impl AggregateUDFImpl for FirstValue { } fn return_type(&self, arg_types: &[DataType]) -> Result { - todo!() + Ok(arg_types[0].clone()) } fn accumulator( From 0a77e4f269b82583c4c712a6626b11a6410ef1cd Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 13:53:46 +0800 Subject: [PATCH 21/46] complete replace Signed-off-by: jayzhan211 --- Cargo.toml | 2 + datafusion/aggregate-functions/Cargo.toml | 49 ++++++ datafusion/aggregate-functions/src/core.rs | 16 ++ datafusion/aggregate-functions/src/lib.rs | 53 ++++++ datafusion/core/src/execution/context/mod.rs | 15 ++ datafusion/core/src/physical_planner.rs | 10 +- .../user_defined/user_defined_aggregates.rs | 4 + datafusion/expr/src/expr.rs | 3 +- datafusion/expr/src/expr_fn.rs | 124 +++++++++++++- datafusion/expr/src/tree_node/expr.rs | 1 + datafusion/expr/src/udaf.rs | 23 ++- .../optimizer/src/analyzer/type_coercion.rs | 9 +- .../optimizer/src/common_subexpr_eliminate.rs | 1 + .../physical-expr/src/aggregate/build_in.rs | 1 + .../physical-expr/src/aggregate/first_last.rs | 157 +++++++++++++----- .../physical-expr/src/aggregate/utils.rs | 1 + datafusion/physical-expr/src/lib.rs | 2 + .../physical-plan/src/aggregates/mod.rs | 3 + datafusion/physical-plan/src/udaf.rs | 69 +++++++- .../proto/src/logical_plan/from_proto.rs | 1 + .../tests/cases/roundtrip_logical_plan.rs | 1 + datafusion/sql/src/expr/function.rs | 8 +- .../substrait/src/logical_plan/consumer.rs | 2 +- 23 files changed, 507 insertions(+), 48 deletions(-) create mode 100644 datafusion/aggregate-functions/Cargo.toml create mode 100644 datafusion/aggregate-functions/src/core.rs create mode 100644 datafusion/aggregate-functions/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index c3dade8bc6c5..8073d056d712 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ members = [ "datafusion/execution", "datafusion/functions", "datafusion/functions-array", + "datafusion/aggregate-functions", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", @@ -78,6 +79,7 @@ datafusion-execution = { path = "datafusion/execution", version = "36.0.0" } datafusion-expr = { path = "datafusion/expr", version = "36.0.0" } datafusion-functions = { path = "datafusion/functions", version = "36.0.0" } datafusion-functions-array = { path = "datafusion/functions-array", version = "36.0.0" } +datafusion-aggrefate-functions = { path = "datafusion/aggreagate-functions", version = "36.0.0" } datafusion-optimizer = { path = "datafusion/optimizer", version = "36.0.0", default-features = false } datafusion-physical-expr = { path = "datafusion/physical-expr", version = "36.0.0", default-features = false } datafusion-physical-plan = { path = "datafusion/physical-plan", version = "36.0.0" } diff --git a/datafusion/aggregate-functions/Cargo.toml b/datafusion/aggregate-functions/Cargo.toml new file mode 100644 index 000000000000..b3be7410934e --- /dev/null +++ b/datafusion/aggregate-functions/Cargo.toml @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-aggregate-functions" +description = "Aggregate Function packages for the DataFusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "datafusion_aggregate_functions" +path = "src/lib.rs" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arrow = { workspace = true } +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-ord = { workspace = true } +arrow-schema = { workspace = true } +datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } +itertools = { version = "0.12", features = ["use_std"] } +log = { workspace = true } +paste = "1.0.14" diff --git a/datafusion/aggregate-functions/src/core.rs b/datafusion/aggregate-functions/src/core.rs new file mode 100644 index 000000000000..b248758bc120 --- /dev/null +++ b/datafusion/aggregate-functions/src/core.rs @@ -0,0 +1,16 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. diff --git a/datafusion/aggregate-functions/src/lib.rs b/datafusion/aggregate-functions/src/lib.rs new file mode 100644 index 000000000000..02b0007ad121 --- /dev/null +++ b/datafusion/aggregate-functions/src/lib.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Array Functions for [DataFusion]. +//! +//! This crate contains a collection of array functions implemented using the +//! extension API. +//! +//! [DataFusion]: https://crates.io/crates/datafusion +//! +//! You can register the functions in this crate using the [`register_all`] function. +//! +//! +//! + +mod core; + +use datafusion_common::Result; +use datafusion_execution::FunctionRegistry; +use datafusion_expr::AggregateUDF; +use log::debug; +use std::sync::Arc; + +/// Fluent-style API for creating `Expr`s +pub mod expr_fn {} + +/// Registers all enabled packages with a [`FunctionRegistry`] +pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { + let functions: Vec> = vec![]; + // functions.into_iter().try_for_each(|udf| { + // let existing_udf = registry.register_udf(udf)?; + // if let Some(existing_udf) = existing_udf { + // debug!("Overwrite existing UDF: {}", existing_udf.name()); + // } + // Ok(()) as Result<()> + // })?; + + Ok(()) +} diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 116e45c8c130..11e975a95c8a 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -69,11 +69,14 @@ use datafusion_common::{ OwnedTableReference, SchemaReference, }; use datafusion_execution::registry::SerializerRegistry; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::{create_first_value, Signature, Volatility}; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, var_provider::is_system_variables, Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; +use datafusion_physical_expr::create_first_value_accumulator; use datafusion_sql::{ parser::{CopyToSource, CopyToStatement, DFParser}, planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel}, @@ -1457,6 +1460,18 @@ impl SessionState { datafusion_functions_array::register_all(&mut new_self) .expect("can not register array expressions"); + // TODO: FIX the name + let my_first = create_first_value( + "my_first", + // vec![DataType::Int32], + // Arc::new(DataType::Int32), + // Volatility::Immutable, + Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), + Arc::new(create_first_value_accumulator), + // Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]), + ); + let _ = new_self.register_udaf(Arc::new(my_first)); + new_self } /// Returns new [`SessionState`] using the provided diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 15764d84bdd5..d0e15a511d66 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -90,7 +90,7 @@ use datafusion_expr::{ DescribeTable, DmlStatement, RecursiveQuery, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; -use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr::expressions::{self, FirstValue, Literal}; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_sql::utils::window_expr_common_partition_keys; @@ -1715,6 +1715,14 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( let ordering_reqs: Vec = order_by.clone().unwrap_or(vec![]); + // TODO: fix the name + if fun.name() == "my_first" { + let agg_expr = udaf::create_aggregate_expr_first_value( + fun, &args, &sort_exprs, &ordering_reqs, + physical_input_schema, name, ignore_nulls)?; + return Ok((agg_expr, filter, order_by)); + } + let agg_expr = udaf::create_aggregate_expr( fun, &args, diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index dfe17430147b..16116d2fcf58 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -836,6 +836,10 @@ impl AggregateUDFImpl for TestGroupsAccumulator { fn create_groups_accumulator(&self) -> Result> { Ok(Box::new(self.clone())) } + + // fn state_fields(&self) -> Result> { + // Ok(vec![Field::new("item", DataType::UInt64, true)]) + // } } impl Accumulator for TestGroupsAccumulator { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7ede4cd8ffc9..427c3fde7c0d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -577,6 +577,7 @@ impl AggregateFunction { distinct: bool, filter: Option>, order_by: Option>, + null_treatment: Option, ) -> Self { Self { func_def: AggregateFunctionDefinition::UDF(udf), @@ -584,7 +585,7 @@ impl AggregateFunction { distinct, filter, order_by, - null_treatment: None, + null_treatment, } } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1a021cedbabd..c13a9b002952 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -28,8 +28,8 @@ use crate::{ BuiltinScalarFunction, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; -use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; -use arrow::datatypes::{DataType, Schema}; +use crate::{signature, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Column, Result}; use std::any::Any; use std::fmt::Debug; @@ -769,6 +769,31 @@ pub fn create_udaf_with_ordering( )) } +/// Creates a new UDAF with a specific signature, state type and return type. +/// The signature and state type must match the `Accumulator's implementation`. +pub fn create_first_value( + name: &str, + // input_type: Vec, + // return_type: Arc, + // volatility: Volatility, + signature: Signature, + accumulator: AccumulatorFactoryFunction, + // state_type: Arc>, +) -> AggregateUDF { + // let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + // let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); + + AggregateUDF::from(FirstValue::new( + name, + signature, + // input_type, + // return_type, + // volatility, + accumulator, + // state_type, + )) +} + /// Implements [`AggregateUDFImpl`] for functions that have a single signature and /// return type. pub struct SimpleAggregateUDF { @@ -858,6 +883,92 @@ impl AggregateUDFImpl for SimpleAggregateUDF { fn state_type(&self, _return_type: &DataType) -> Result> { Ok(self.state_type.clone()) } + + // fn state_fields(&self) -> Result> { + // Ok(self + // .state_type + // .iter() + // .enumerate() + // .map(|(i, data_type)| Field::new(format!("{i}"), data_type.clone(), true)) + // .collect()) + // } +} + +pub struct FirstValue { + name: String, + signature: Signature, + // return_type: DataType, + accumulator: AccumulatorFactoryFunction, + // state_type: Vec, +} + +impl Debug for FirstValue { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("AggregateUDF") + .field("name", &self.name) + .field("fun", &"") + .finish() + } +} + +impl FirstValue { + pub fn new( + name: impl Into, + // input_type: Vec, + // return_type: DataType, + // volatility: Volatility, + signature: Signature, + accumulator: AccumulatorFactoryFunction, + // state_type: Vec, + ) -> Self { + let name = name.into(); + // let signature = Signature::exact(input_type, volatility); + Self { + name, + signature, + // return_type, + accumulator, + // state_type, + } + } +} + +impl AggregateUDFImpl for FirstValue { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn accumulator( + &self, + arg: &DataType, + sort_exprs: &[Expr], + schema: &Schema, + ) -> Result> { + (self.accumulator)(arg, sort_exprs, schema) + } + + fn state_type(&self, _return_type: &DataType) -> Result> { + unreachable!() + } + + fn state_fields(&self, value_field: Field, ordering_field: Vec) -> Result> { + let mut fields = vec![value_field]; + fields.extend(ordering_field); + fields.push(Field::new("is_set", DataType::Boolean, true)); + Ok(fields) + } } /// Implements [`AggregateUDFImpl`] for functions that have a single signature and @@ -932,6 +1043,15 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { fn state_type(&self, _return_type: &DataType) -> Result> { Ok(self.state_type.clone()) } + + // fn state_fields(&self) -> Result> { + // Ok(self + // .state_type + // .iter() + // .enumerate() + // .map(|(i, data_type)| Field::new(format!("{i}"), data_type.clone(), true)) + // .collect()) + // } } /// Creates a new UDWF with a specific signature, state type and return type. diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 1c672851e9b5..0909d8f662f6 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -379,6 +379,7 @@ impl TreeNode for Expr { false, new_filter, new_order_by, + null_treatment, ))) } AggregateFunctionDefinition::Name(_) => { diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index d232a2b09c7e..660862e33bf5 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -22,7 +22,7 @@ use crate::{Accumulator, Expr}; use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{not_impl_err, Result}; use std::any::Any; use std::fmt::{self, Debug, Formatter}; @@ -131,12 +131,14 @@ impl AggregateUDF { /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. pub fn call(&self, args: Vec) -> Expr { + // TODO: Support dictinct, filter, order by and null_treatment Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf( Arc::new(self.clone()), args, false, None, None, + None, )) } @@ -182,6 +184,10 @@ impl AggregateUDF { self.inner.state_type(return_type) } + pub fn state_fields(&self, value_field: Field, ordering_field: Vec) -> Result> { + self.inner.state_fields(value_field, ordering_field) + } + /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details. pub fn groups_accumulator_supported(&self) -> bool { self.inner.groups_accumulator_supported() @@ -293,6 +299,13 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// See [`Accumulator::state()`] for more details fn state_type(&self, return_type: &DataType) -> Result>; + /// Default fields including the value field and ordering fields + fn state_fields(&self, value_field: Field, ordering_field: Vec) -> Result> { + let mut fields = vec![value_field]; + fields.extend(ordering_field); + Ok(fields) + } + /// If the aggregate expression has a specialized /// [`GroupsAccumulator`] implementation. If this returns true, /// `[Self::create_groups_accumulator]` will be called. @@ -368,6 +381,10 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.inner.state_type(return_type) } + // fn state_fields(&self, value_field: Field, ordering_field: Vec) -> Result> { + // self.inner.state_fields(value_field, ordering_field) + // } + fn aliases(&self) -> &[String] { &self.aliases } @@ -430,4 +447,8 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { let res = (self.state_type)(return_type)?; Ok(res.as_ref().clone()) } + + // fn state_fields(&self, value_field: Field, ordering_field: Vec) -> Result> { + // not_impl_err!("state_fields not implemented for legacy AggregateUDF") + // } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4f861ffe9967..147706b636f6 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -366,7 +366,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter { )?; Ok(Transformed::yes(Expr::AggregateFunction( expr::AggregateFunction::new_udf( - fun, new_expr, false, filter, order_by, + fun, + new_expr, + false, + filter, + order_by, + null_treatment, ), ))) } @@ -896,6 +901,7 @@ mod test { false, None, None, + None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation"; @@ -922,6 +928,7 @@ mod test { false, None, None, + None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "") diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 8b4d60aafd19..75c71d0aa298 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -979,6 +979,7 @@ mod test { false, None, None, + None, )) }; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 846431034c96..2da84125b7ff 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -367,6 +367,7 @@ pub fn create_aggregate_expr( input_phy_types[0].clone(), ordering_req.to_vec(), ordering_types, + vec![], ) .with_ignore_nulls(ignore_nulls), ), diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 2c713d9bcca8..54789ec1ea42 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use crate::aggregate::utils::{down_cast_any_ref, get_sort_options, ordering_fields}; -use crate::expressions::format_state_name; +use crate::expressions::{self, format_state_name}; use crate::{ reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, }; @@ -29,11 +29,12 @@ use crate::{ use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; use arrow::compute::{self, lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field}; +use arrow_schema::{Schema, SortOptions}; use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::{Accumulator, AggregateUDFImpl}; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Expr}; /// FIRST_VALUE aggregate expression #[derive(Debug, Clone)] @@ -45,47 +46,78 @@ pub struct FirstValue { ordering_req: LexOrdering, requirement_satisfied: bool, ignore_nulls: bool, + state_fields: Vec, } -impl AggregateUDFImpl for FirstValue { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - self.name.as_str() - } - - fn signature(&self) -> &datafusion_expr::Signature { - todo!() - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) - } - - fn accumulator( - &self, - _arg: &DataType, - _sort_exprs: &[datafusion_expr::Expr], - _schema: &arrow_schema::Schema, - ) -> Result> { - FirstValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, - ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) - } - - fn state_type(&self, return_type: &DataType) -> Result> { - todo!() - } +#[derive(Debug, Clone)] +pub struct FirstValueUDF { + name: String, + input_data_type: DataType, + order_by_data_types: Vec, + expr: Arc, + ordering_req: LexOrdering, + requirement_satisfied: bool, + ignore_nulls: bool, } +// impl AggregateUDFImpl for FirstValue { +// fn as_any(&self) -> &dyn Any { +// self +// } + +// fn name(&self) -> &str { +// self.name.as_str() +// } + +// fn signature(&self) -> &datafusion_expr::Signature { +// todo!() +// } + +// fn return_type(&self, arg_types: &[DataType]) -> Result { +// Ok(arg_types[0].clone()) +// } + +// fn accumulator( +// &self, +// _arg: &DataType, +// _sort_exprs: &[datafusion_expr::Expr], +// _schema: &arrow_schema::Schema, +// ) -> Result> { +// FirstValueAccumulator::try_new( +// &self.input_data_type, +// &self.order_by_data_types, +// self.ordering_req.clone(), +// self.ignore_nulls, +// ) +// .map(|acc| { +// Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ +// }) +// } + +// fn state_type(&self, return_type: &DataType) -> Result> { +// todo!() +// } + +// fn state_fields(&self) -> Result> { +// let mut fields = vec![Field::new( +// format_state_name(&self.name, "first_value"), +// self.input_data_type.clone(), +// true, +// )]; +// fields.extend(ordering_fields( +// &self.ordering_req, +// &self.order_by_data_types, +// )); +// fields.push(Field::new( +// format_state_name(&self.name, "is_set"), +// DataType::Boolean, +// true, +// )); + +// Ok(fields) +// } +// } + impl FirstValue { /// Creates a new FIRST_VALUE aggregation function. pub fn new( @@ -94,6 +126,7 @@ impl FirstValue { input_data_type: DataType, ordering_req: LexOrdering, order_by_data_types: Vec, + state_fields: Vec, ) -> Self { let requirement_satisfied = ordering_req.is_empty(); Self { @@ -104,6 +137,7 @@ impl FirstValue { ordering_req, requirement_satisfied, ignore_nulls: false, + state_fields, } } @@ -188,6 +222,10 @@ impl AggregateExpr for FirstValue { } fn state_fields(&self) -> Result> { + if !self.state_fields.is_empty() { + return Ok(self.state_fields.clone()); + } + let mut fields = vec![Field::new( format_state_name(&self.name, "first_value"), self.input_data_type.clone(), @@ -423,6 +461,46 @@ impl Accumulator for FirstValueAccumulator { } } +pub fn create_first_value_accumulator( + data_type: &DataType, + order_by: &[Expr], + schema: &Schema, +) -> Result> { + let mut all_sort_orders = vec![]; + + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in order_by { + if let Expr::Sort(sort) = expr { + if let Expr::Column(col) = sort.expr.as_ref() { + let name = &col.name; + let e = expressions::col(name, schema)?; + sort_exprs.push(PhysicalSortExpr { + expr: e, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); + } + } + } + if !sort_exprs.is_empty() { + all_sort_orders.extend(sort_exprs); + } + + let ordering_req = all_sort_orders; + + let ordering_dtypes = ordering_req + .iter() + .map(|e| e.expr.data_type(schema)) + .collect::>>()?; + + let acc = + FirstValueAccumulator::try_new(data_type, &ordering_dtypes, ordering_req, false)?; + Ok(Box::new(acc)) +} + /// LAST_VALUE aggregate expression #[derive(Debug, Clone)] pub struct LastValue { @@ -503,6 +581,7 @@ impl LastValue { input_data_type, reverse_order_bys(&ordering_req), order_by_data_types, + vec![] ) } } diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index 60d59c16be5f..d3f494fe6356 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -187,6 +187,7 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { } } +// TOOD: Remove /// Construct corresponding fields for lexicographical ordering requirement expression pub(crate) fn ordering_fields( ordering_req: &[PhysicalSortExpr], diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 1791a6ed60b2..6e308f20064b 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -60,3 +60,5 @@ pub use sort_expr::{ PhysicalSortRequirement, }; pub use utils::{reverse_order_bys, split_conjunction}; + +pub use aggregate::first_last::create_first_value_accumulator; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 65987e01553d..e0cabde59775 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -769,6 +769,7 @@ fn create_schema( AggregateMode::Partial => { // in partial mode, the fields of the accumulator's state for expr in aggr_expr { + println!("expr: {:?}", expr); fields.extend(expr.state_fields()?.iter().cloned()) } } @@ -2018,6 +2019,7 @@ mod tests { DataType::Float64, ordering_req.clone(), vec![DataType::Float64], + vec![], ))] } else { vec![Arc::new(LastValue::new( @@ -2201,6 +2203,7 @@ mod tests { DataType::Float64, sort_expr_reverse.clone(), vec![DataType::Float64], + vec![] )), Arc::new(LastValue::new( col_b.clone(), diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 3d15e3d012c5..17414ede4876 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -27,7 +27,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use super::{expressions::format_state_name, Accumulator, AggregateExpr}; use datafusion_common::{not_impl_err, Result}; pub use datafusion_expr::AggregateUDF; -use datafusion_physical_expr::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; +use datafusion_physical_expr::{expressions, LexOrdering, PhysicalExpr, PhysicalSortExpr}; use datafusion_physical_expr::aggregate::utils::down_cast_any_ref; use std::sync::Arc; @@ -42,6 +42,10 @@ pub fn create_aggregate_expr( schema: &Schema, name: impl Into, ) -> Result> { + let name: String = name.into(); + println!("name: {}", name); + + let input_exprs_types = input_phy_exprs .iter() .map(|arg| arg.data_type(schema)) @@ -58,6 +62,69 @@ pub fn create_aggregate_expr( })) } +pub fn create_aggregate_expr_first_value( + fun: &AggregateUDF, + input_phy_exprs: &[Arc], + sort_exprs: &[Expr], + ordering_req: &[PhysicalSortExpr], + schema: &Schema, + name: impl Into, + ignore_nulls: bool, +) -> Result> { + let input_exprs_types = input_phy_exprs + .iter() + .map(|arg| arg.data_type(schema)) + .collect::>>()?; + + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(schema)) + .collect::>>()?; + + let name: String = name.into(); + + let input_data_type = fun.return_type(&input_exprs_types)?; + + let value_field = Field::new( + format_state_name(&name, "first_value"), + input_data_type.clone(), + true, + ); + let ordering_fields = ordering_fields(&ordering_req, &ordering_types); + + let state_fields = fun.state_fields(value_field, ordering_fields)?; + + let first_value = expressions::FirstValue::new( + input_phy_exprs[0].clone(), + name, + input_data_type, + ordering_req.to_vec(), + ordering_types, + state_fields, + ) + .with_ignore_nulls(ignore_nulls); + return Ok(Arc::new(first_value)); +} + +fn ordering_fields( + ordering_req: &[PhysicalSortExpr], + // Data type of each expression in the ordering requirement + data_types: &[DataType], +) -> Vec { + ordering_req + .iter() + .zip(data_types.iter()) + .map(|(sort_expr, dtype)| { + Field::new( + sort_expr.expr.to_string().as_str(), + dtype.clone(), + // Multi partitions may be empty hence field should be nullable. + true, + ) + }) + .collect() +} + /// Physical aggregate expression of a UDAF. #[derive(Debug)] pub struct AggregateFunctionExpr { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d5eebcb69841..cafdeee5e945 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1560,6 +1560,7 @@ pub fn parse_expr( false, parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), parse_vec_expr(&pb.order_by, registry, codec)?, + None, ))) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index c40bfc97677c..fa4fa2daba4d 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1773,6 +1773,7 @@ fn roundtrip_aggregate_udf() { false, Some(Box::new(lit(true))), None, + None, )); let ctx = SessionContext::new(); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 9b92067b9ec6..e97eb1a32b12 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -225,8 +225,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.order_by_to_sort_expr(&order_by, schema, planner_context, true)?; let order_by = (!order_by.is_empty()).then_some(order_by); let args = self.function_args_to_expr(args, schema, planner_context)?; + // TODO: Support filter and distinct for UDAFs return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fm, args, false, None, order_by, + fm, + args, + false, + None, + order_by, + null_treatment, ))); } diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index ed1e48ca71a6..b398c5df86dd 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -749,7 +749,7 @@ pub async fn from_substrait_agg_func( // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { Ok(Arc::new(Expr::AggregateFunction( - expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by), + expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), ))) } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) { From 7b26377910417faa7514da825c2829da5a4abe0b Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 15:05:11 +0800 Subject: [PATCH 22/46] split to first value udf Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 83 ++++-- datafusion/expr/src/expr_fn.rs | 8 +- datafusion/expr/src/udaf.rs | 12 +- .../physical-expr/src/aggregate/first_last.rs | 239 +++++++++++++----- .../physical-expr/src/expressions/mod.rs | 2 +- .../physical-plan/src/aggregates/mod.rs | 3 +- datafusion/physical-plan/src/udaf.rs | 73 ++++-- datafusion/sqllogictest/test_files/test1.slt | 12 + 8 files changed, 320 insertions(+), 112 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/test1.slt diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d0e15a511d66..dced2d3410f4 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -90,7 +90,7 @@ use datafusion_expr::{ DescribeTable, DmlStatement, RecursiveQuery, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; -use datafusion_physical_expr::expressions::{self, FirstValue, Literal}; +use datafusion_physical_expr::expressions::Literal; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_sql::utils::window_expr_common_partition_keys; @@ -1661,15 +1661,11 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( Expr::AggregateFunction(AggregateFunction { func_def, distinct, - args, + args: origin_args, filter, order_by, null_treatment, }) => { - let args = args - .iter() - .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) - .collect::>>()?; let filter = match filter { Some(e) => Some(create_physical_expr( e, @@ -1679,26 +1675,31 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( None => None, }; - let sort_exprs = order_by.clone().unwrap_or(vec![]); - let order_by = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), - None => None, - }; let ignore_nulls = null_treatment .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { + let args = origin_args + .iter() + .map(|e| { + create_physical_expr(e, logical_input_schema, execution_props) + }) + .collect::>>()?; + let order_by = match order_by { + Some(e) => Some( + e.iter() + .map(|expr| { + create_physical_sort_expr( + expr, + logical_input_schema, + execution_props, + ) + }) + .collect::>>()?, + ), + None => None, + }; let ordering_reqs = order_by.clone().unwrap_or(vec![]); let agg_expr = aggregates::create_aggregate_expr( fun, @@ -1712,16 +1713,48 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( (agg_expr, filter, order_by) } AggregateFunctionDefinition::UDF(fun) => { - let ordering_reqs: Vec = - order_by.clone().unwrap_or(vec![]); + let sort_exprs = order_by.clone().unwrap_or(vec![]); + let order_by = match order_by { + Some(e) => Some( + e.iter() + .map(|expr| { + create_physical_sort_expr( + expr, + logical_input_schema, + execution_props, + ) + }) + .collect::>>()?, + ), + None => None, + }; // TODO: fix the name if fun.name() == "my_first" { let agg_expr = udaf::create_aggregate_expr_first_value( - fun, &args, &sort_exprs, &ordering_reqs, - physical_input_schema, name, ignore_nulls)?; + fun, + origin_args, + // &args, + &sort_exprs, + logical_input_schema, + execution_props, + // &ordering_reqs, + physical_input_schema, + name, + ignore_nulls, + )?; return Ok((agg_expr, filter, order_by)); } + + let ordering_reqs: Vec = + order_by.clone().unwrap_or(vec![]); + + let args = origin_args + .iter() + .map(|e| { + create_physical_expr(e, logical_input_schema, execution_props) + }) + .collect::>>()?; let agg_expr = udaf::create_aggregate_expr( fun, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c13a9b002952..6354ab57d9be 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -28,7 +28,7 @@ use crate::{ BuiltinScalarFunction, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; -use crate::{signature, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; +use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Column, Result}; use std::any::Any; @@ -963,7 +963,11 @@ impl AggregateUDFImpl for FirstValue { unreachable!() } - fn state_fields(&self, value_field: Field, ordering_field: Vec) -> Result> { + fn state_fields( + &self, + value_field: Field, + ordering_field: Vec, + ) -> Result> { let mut fields = vec![value_field]; fields.extend(ordering_field); fields.push(Field::new("is_set", DataType::Boolean, true)); diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 660862e33bf5..ceabd74d0fa0 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -184,7 +184,11 @@ impl AggregateUDF { self.inner.state_type(return_type) } - pub fn state_fields(&self, value_field: Field, ordering_field: Vec) -> Result> { + pub fn state_fields( + &self, + value_field: Field, + ordering_field: Vec, + ) -> Result> { self.inner.state_fields(value_field, ordering_field) } @@ -300,7 +304,11 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn state_type(&self, return_type: &DataType) -> Result>; /// Default fields including the value field and ordering fields - fn state_fields(&self, value_field: Field, ordering_field: Vec) -> Result> { + fn state_fields( + &self, + value_field: Field, + ordering_field: Vec, + ) -> Result> { let mut fields = vec![value_field]; fields.extend(ordering_field); Ok(fields) diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 54789ec1ea42..07abf89c1b84 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -34,11 +34,11 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Expr}; +use datafusion_expr::{Accumulator, Expr}; + -/// FIRST_VALUE aggregate expression #[derive(Debug, Clone)] -pub struct FirstValue { +pub struct FirstValueUDF { name: String, input_data_type: DataType, order_by_data_types: Vec, @@ -49,8 +49,178 @@ pub struct FirstValue { state_fields: Vec, } +impl FirstValueUDF { + /// Creates a new FIRST_VALUE aggregation function. + pub fn new( + expr: Arc, + name: impl Into, + input_data_type: DataType, + ordering_req: LexOrdering, + order_by_data_types: Vec, + state_fields: Vec, + ) -> Self { + let requirement_satisfied = ordering_req.is_empty(); + Self { + name: name.into(), + input_data_type, + order_by_data_types, + expr, + ordering_req, + requirement_satisfied, + ignore_nulls: false, + state_fields, + } + } + + pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { + self.ignore_nulls = ignore_nulls; + self + } + + /// Returns the name of the aggregate expression. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the input data type of the aggregate expression. + pub fn input_data_type(&self) -> &DataType { + &self.input_data_type + } + + /// Returns the data types of the order-by columns. + pub fn order_by_data_types(&self) -> &Vec { + &self.order_by_data_types + } + + /// Returns the expression associated with the aggregate function. + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Returns the lexical ordering requirements of the aggregate expression. + pub fn ordering_req(&self) -> &LexOrdering { + &self.ordering_req + } + + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + + pub fn convert_to_last(self) -> LastValue { + let name = if self.name.starts_with("FIRST") { + format!("LAST{}", &self.name[5..]) + } else { + format!("LAST_VALUE({})", self.expr) + }; + let FirstValueUDF { + expr, + input_data_type, + ordering_req, + order_by_data_types, + .. + } = self; + LastValue::new( + expr, + name, + input_data_type, + reverse_order_bys(&ordering_req), + order_by_data_types, + ) + } +} + +impl AggregateExpr for FirstValueUDF { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.input_data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + FirstValueAccumulator::try_new( + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + self.ignore_nulls, + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) + } + + fn state_fields(&self) -> Result> { + if !self.state_fields.is_empty() { + return Ok(self.state_fields.clone()); + } + + let mut fields = vec![Field::new( + format_state_name(&self.name, "first_value"), + self.input_data_type.clone(), + true, + )]; + fields.extend(ordering_fields( + &self.ordering_req, + &self.order_by_data_types, + )); + fields.push(Field::new( + format_state_name(&self.name, "is_set"), + DataType::Boolean, + true, + )); + Ok(fields) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + } + + fn name(&self) -> &str { + &self.name + } + + fn reverse_expr(&self) -> Option> { + Some(Arc::new(self.clone().convert_to_last())) + } + + fn create_sliding_accumulator(&self) -> Result> { + FirstValueAccumulator::try_new( + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + self.ignore_nulls, + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) + } +} + +impl PartialEq for FirstValueUDF { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.input_data_type == x.input_data_type + && self.order_by_data_types == x.order_by_data_types + && self.expr.eq(&x.expr) + }) + .unwrap_or(false) + } +} + + +/// FIRST_VALUE aggregate expression #[derive(Debug, Clone)] -pub struct FirstValueUDF { +pub struct FirstValue { name: String, input_data_type: DataType, order_by_data_types: Vec, @@ -58,66 +228,9 @@ pub struct FirstValueUDF { ordering_req: LexOrdering, requirement_satisfied: bool, ignore_nulls: bool, + state_fields: Vec, } -// impl AggregateUDFImpl for FirstValue { -// fn as_any(&self) -> &dyn Any { -// self -// } - -// fn name(&self) -> &str { -// self.name.as_str() -// } - -// fn signature(&self) -> &datafusion_expr::Signature { -// todo!() -// } - -// fn return_type(&self, arg_types: &[DataType]) -> Result { -// Ok(arg_types[0].clone()) -// } - -// fn accumulator( -// &self, -// _arg: &DataType, -// _sort_exprs: &[datafusion_expr::Expr], -// _schema: &arrow_schema::Schema, -// ) -> Result> { -// FirstValueAccumulator::try_new( -// &self.input_data_type, -// &self.order_by_data_types, -// self.ordering_req.clone(), -// self.ignore_nulls, -// ) -// .map(|acc| { -// Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ -// }) -// } - -// fn state_type(&self, return_type: &DataType) -> Result> { -// todo!() -// } - -// fn state_fields(&self) -> Result> { -// let mut fields = vec![Field::new( -// format_state_name(&self.name, "first_value"), -// self.input_data_type.clone(), -// true, -// )]; -// fields.extend(ordering_fields( -// &self.ordering_req, -// &self.order_by_data_types, -// )); -// fields.push(Field::new( -// format_state_name(&self.name, "is_set"), -// DataType::Boolean, -// true, -// )); - -// Ok(fields) -// } -// } - impl FirstValue { /// Creates a new FIRST_VALUE aggregation function. pub fn new( @@ -581,7 +694,7 @@ impl LastValue { input_data_type, reverse_order_bys(&ordering_req), order_by_data_types, - vec![] + vec![], ) } } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index fcd656173355..8a4b1bda67f8 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -53,7 +53,7 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; -pub use crate::aggregate::first_last::{FirstValue, FirstValueAccumulator, LastValue}; +pub use crate::aggregate::first_last::{FirstValue, FirstValueUDF, FirstValueAccumulator, LastValue}; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index e0cabde59775..400ee9db7ee7 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -769,7 +769,6 @@ fn create_schema( AggregateMode::Partial => { // in partial mode, the fields of the accumulator's state for expr in aggr_expr { - println!("expr: {:?}", expr); fields.extend(expr.state_fields()?.iter().cloned()) } } @@ -2203,7 +2202,7 @@ mod tests { DataType::Float64, sort_expr_reverse.clone(), vec![DataType::Float64], - vec![] + vec![], )), Arc::new(LastValue::new( col_b.clone(), diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 17414ede4876..453fe541d8e4 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -17,7 +17,9 @@ //! This module contains functions and structs supporting user-defined aggregate functions. -use datafusion_expr::{Expr, GroupsAccumulator}; +use arrow_schema::SortOptions; +use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::{expr, Expr, GroupsAccumulator}; use fmt::Debug; use std::any::Any; use std::fmt; @@ -25,9 +27,11 @@ use std::fmt; use arrow::datatypes::{DataType, Field, Schema}; use super::{expressions::format_state_name, Accumulator, AggregateExpr}; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{internal_err, not_impl_err, DFSchema, Result}; pub use datafusion_expr::AggregateUDF; -use datafusion_physical_expr::{expressions, LexOrdering, PhysicalExpr, PhysicalSortExpr}; +use datafusion_physical_expr::{ + create_physical_expr, expressions, LexOrdering, PhysicalExpr, PhysicalSortExpr, +}; use datafusion_physical_expr::aggregate::utils::down_cast_any_ref; use std::sync::Arc; @@ -42,10 +46,6 @@ pub fn create_aggregate_expr( schema: &Schema, name: impl Into, ) -> Result> { - let name: String = name.into(); - println!("name: {}", name); - - let input_exprs_types = input_phy_exprs .iter() .map(|arg| arg.data_type(schema)) @@ -62,20 +62,57 @@ pub fn create_aggregate_expr( })) } +// TODO: Duplicated functoin from `datafusion/core/src/physical_planner.rs`, remove one of them +// TODO: Maybe move to physical-expr +/// Create a physical sort expression from a logical expression +pub fn create_physical_sort_expr( + e: &Expr, + input_dfschema: &DFSchema, + execution_props: &ExecutionProps, +) -> Result { + if let Expr::Sort(expr::Sort { + expr, + asc, + nulls_first, + }) = e + { + Ok(PhysicalSortExpr { + expr: create_physical_expr(expr, input_dfschema, execution_props)?, + options: SortOptions { + descending: !asc, + nulls_first: *nulls_first, + }, + }) + } else { + internal_err!("Expects a sort expression") + } +} + pub fn create_aggregate_expr_first_value( fun: &AggregateUDF, - input_phy_exprs: &[Arc], + args: &[Expr], + // input_phy_exprs: &[Arc], sort_exprs: &[Expr], - ordering_req: &[PhysicalSortExpr], + dfschema: &DFSchema, + execution_props: &ExecutionProps, + // ordering_req: &[PhysicalSortExpr], schema: &Schema, name: impl Into, ignore_nulls: bool, ) -> Result> { - let input_exprs_types = input_phy_exprs + let args = args + .iter() + .map(|e| create_physical_expr(e, dfschema, execution_props)) + .collect::>>()?; + let input_exprs_types = args .iter() .map(|arg| arg.data_type(schema)) .collect::>>()?; + let ordering_req = sort_exprs + .iter() + .map(|e| create_physical_sort_expr(e, dfschema, execution_props)) + .collect::>>()?; let ordering_types = ordering_req .iter() .map(|e| e.expr.data_type(schema)) @@ -85,17 +122,18 @@ pub fn create_aggregate_expr_first_value( let input_data_type = fun.return_type(&input_exprs_types)?; - let value_field = Field::new( - format_state_name(&name, "first_value"), - input_data_type.clone(), - true, - ); + let value_field = Field::new( + format_state_name(&name, "first_value"), + input_data_type.clone(), + true, + ); let ordering_fields = ordering_fields(&ordering_req, &ordering_types); let state_fields = fun.state_fields(value_field, ordering_fields)?; - let first_value = expressions::FirstValue::new( - input_phy_exprs[0].clone(), + let first_value = expressions::FirstValueUDF::new( + args[0].clone(), + // input_phy_exprs[0].clone(), name, input_data_type, ordering_req.to_vec(), @@ -106,6 +144,7 @@ pub fn create_aggregate_expr_first_value( return Ok(Arc::new(first_value)); } +// TODO: Duplicated functoin. fn ordering_fields( ordering_req: &[PhysicalSortExpr], // Data type of each expression in the ordering requirement diff --git a/datafusion/sqllogictest/test_files/test1.slt b/datafusion/sqllogictest/test_files/test1.slt new file mode 100644 index 000000000000..21099576309b --- /dev/null +++ b/datafusion/sqllogictest/test_files/test1.slt @@ -0,0 +1,12 @@ +statement ok +CREATE TABLE t AS VALUES (null::bigint), (3), (4); + +query I +SELECT first_value(column1) FROM t; +---- +NULL + +query I +SELECT my_first(column1) FROM t; +---- +NULL From 4bfd91d37b0bdead75acadf5069ed50124ffce39 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 17:07:14 +0800 Subject: [PATCH 23/46] replace accumulator Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 3 +- datafusion/expr/src/expr_fn.rs | 16 ++++-- datafusion/expr/src/function.rs | 6 +++ datafusion/expr/src/udaf.rs | 12 ++++- .../physical-expr/src/aggregate/first_last.rs | 54 ++++++++++--------- datafusion/physical-plan/src/udaf.rs | 21 +++++--- datafusion/physical-plan/src/windows/mod.rs | 1 + 7 files changed, 75 insertions(+), 38 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index dced2d3410f4..f5aa9cf93fe8 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1745,7 +1745,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( )?; return Ok((agg_expr, filter, order_by)); } - + let ordering_reqs: Vec = order_by.clone().unwrap_or(vec![]); @@ -1763,6 +1763,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( &ordering_reqs, physical_input_schema, name, + ignore_nulls, )?; (agg_expr, filter, order_by) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 6354ab57d9be..25f722076aa6 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -21,7 +21,7 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Placeholder, ScalarFunction, TryCast, }; -use crate::function::PartitionEvaluatorFactory; +use crate::function::{AccumulatorFactoryFunctionForFirstValue, PartitionEvaluatorFactory}; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, @@ -777,7 +777,7 @@ pub fn create_first_value( // return_type: Arc, // volatility: Volatility, signature: Signature, - accumulator: AccumulatorFactoryFunction, + accumulator: AccumulatorFactoryFunctionForFirstValue, // state_type: Arc>, ) -> AggregateUDF { // let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); @@ -876,6 +876,8 @@ impl AggregateUDFImpl for SimpleAggregateUDF { arg: &DataType, sort_exprs: &[Expr], schema: &Schema, + _ignore_nulls: bool, + _requirement_satisfied: bool, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } @@ -898,7 +900,7 @@ pub struct FirstValue { name: String, signature: Signature, // return_type: DataType, - accumulator: AccumulatorFactoryFunction, + accumulator: AccumulatorFactoryFunctionForFirstValue, // state_type: Vec, } @@ -918,7 +920,7 @@ impl FirstValue { // return_type: DataType, // volatility: Volatility, signature: Signature, - accumulator: AccumulatorFactoryFunction, + accumulator: AccumulatorFactoryFunctionForFirstValue, // state_type: Vec, ) -> Self { let name = name.into(); @@ -955,8 +957,10 @@ impl AggregateUDFImpl for FirstValue { arg: &DataType, sort_exprs: &[Expr], schema: &Schema, + ignore_nulls: bool, + requirement_satisfied: bool, ) -> Result> { - (self.accumulator)(arg, sort_exprs, schema) + (self.accumulator)(arg, sort_exprs, schema, ignore_nulls, requirement_satisfied) } fn state_type(&self, _return_type: &DataType) -> Result> { @@ -1040,6 +1044,8 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { arg: &DataType, sort_exprs: &[Expr], schema: &Schema, + _ignore_nulls: bool, + _requirement_satisfied: bool, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index f08411550124..ef577627a5f9 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -44,6 +44,12 @@ pub type AccumulatorFactoryFunction = Arc< dyn Fn(&DataType, &[Expr], &Schema) -> Result> + Send + Sync, >; +/// Factory that returns an accumulator for the given aggregate, given +/// its return datatype, the sorting expressions and the schema for ordering. +pub type AccumulatorFactoryFunctionForFirstValue = Arc< + dyn Fn(&DataType, &[Expr], &Schema, bool, bool) -> Result> + Send + Sync, +>; + /// Factory that creates a PartitionEvaluator for the given window /// function pub type PartitionEvaluatorFactory = diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index ceabd74d0fa0..b76fff7ad005 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -174,8 +174,10 @@ impl AggregateUDF { return_type: &DataType, sort_exprs: &[Expr], schema: &Schema, + ignore_nulls: bool, + requirement_satisfied: bool, ) -> Result> { - self.inner.accumulator(return_type, sort_exprs, schema) + self.inner.accumulator(return_type, sort_exprs, schema, ignore_nulls, requirement_satisfied) } /// Return the type of the intermediate state used by this aggregator, given @@ -297,6 +299,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { arg: &DataType, sort_exprs: &[Expr], schema: &Schema, + ignore_nulls: bool, + requirement_satisfied: bool, ) -> Result>; /// Return the type used to serialize the [`Accumulator`]'s intermediate state. @@ -381,8 +385,10 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { arg: &DataType, sort_exprs: &[Expr], schema: &Schema, + ignore_nulls: bool, + requirement_satisfied: bool, ) -> Result> { - self.inner.accumulator(arg, sort_exprs, schema) + self.inner.accumulator(arg, sort_exprs, schema, ignore_nulls, requirement_satisfied) } fn state_type(&self, return_type: &DataType) -> Result> { @@ -447,6 +453,8 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { arg: &DataType, sort_exprs: &[Expr], schema: &Schema, + _ignore_nulls: bool, + _requirement_satisfied: bool, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 07abf89c1b84..d555628632ef 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -34,41 +34,52 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::{Accumulator, Expr}; +use datafusion_expr::{Accumulator, AggregateUDF, Expr}; + +// Equivalent to AggregateFunctionExpr #[derive(Debug, Clone)] pub struct FirstValueUDF { + fun: AggregateUDF, name: String, - input_data_type: DataType, + return_type: DataType, order_by_data_types: Vec, expr: Arc, ordering_req: LexOrdering, requirement_satisfied: bool, ignore_nulls: bool, state_fields: Vec, + sort_exprs: Vec, + schema: Schema, } impl FirstValueUDF { /// Creates a new FIRST_VALUE aggregation function. pub fn new( + fun: AggregateUDF, expr: Arc, name: impl Into, - input_data_type: DataType, + return_type: DataType, ordering_req: LexOrdering, order_by_data_types: Vec, state_fields: Vec, + sort_exprs: Vec, + schema: Schema, ) -> Self { let requirement_satisfied = ordering_req.is_empty(); Self { + fun, name: name.into(), - input_data_type, + return_type, order_by_data_types, expr, ordering_req, requirement_satisfied, ignore_nulls: false, state_fields, + sort_exprs, + schema, } } @@ -83,8 +94,8 @@ impl FirstValueUDF { } /// Returns the input data type of the aggregate expression. - pub fn input_data_type(&self) -> &DataType { - &self.input_data_type + pub fn return_type(&self) -> &DataType { + &self.return_type } /// Returns the data types of the order-by columns. @@ -115,7 +126,7 @@ impl FirstValueUDF { }; let FirstValueUDF { expr, - input_data_type, + return_type, ordering_req, order_by_data_types, .. @@ -123,7 +134,7 @@ impl FirstValueUDF { LastValue::new( expr, name, - input_data_type, + return_type, reverse_order_bys(&ordering_req), order_by_data_types, ) @@ -137,19 +148,11 @@ impl AggregateExpr for FirstValueUDF { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), true)) + Ok(Field::new(&self.name, self.return_type.clone(), true)) } fn create_accumulator(&self) -> Result> { - FirstValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, - ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) + self.fun.accumulator(&self.return_type, &self.sort_exprs, &self.schema, self.ignore_nulls, self.requirement_satisfied) } fn state_fields(&self) -> Result> { @@ -159,7 +162,7 @@ impl AggregateExpr for FirstValueUDF { let mut fields = vec![Field::new( format_state_name(&self.name, "first_value"), - self.input_data_type.clone(), + self.return_type.clone(), true, )]; fields.extend(ordering_fields( @@ -192,7 +195,7 @@ impl AggregateExpr for FirstValueUDF { fn create_sliding_accumulator(&self) -> Result> { FirstValueAccumulator::try_new( - &self.input_data_type, + &self.return_type, &self.order_by_data_types, self.ordering_req.clone(), self.ignore_nulls, @@ -209,7 +212,7 @@ impl PartialEq for FirstValueUDF { .downcast_ref::() .map(|x| { self.name == x.name - && self.input_data_type == x.input_data_type + && self.return_type == x.return_type && self.order_by_data_types == x.order_by_data_types && self.expr.eq(&x.expr) }) @@ -578,6 +581,8 @@ pub fn create_first_value_accumulator( data_type: &DataType, order_by: &[Expr], schema: &Schema, + ignore_nulls: bool, + requirement_satisfied: bool, ) -> Result> { let mut all_sort_orders = vec![]; @@ -609,9 +614,10 @@ pub fn create_first_value_accumulator( .map(|e| e.expr.data_type(schema)) .collect::>>()?; - let acc = - FirstValueAccumulator::try_new(data_type, &ordering_dtypes, ordering_req, false)?; - Ok(Box::new(acc)) + + FirstValueAccumulator::try_new(data_type, &ordering_dtypes, ordering_req, ignore_nulls) + .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) + // Ok(Box::new(acc)) } /// LAST_VALUE aggregate expression diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 453fe541d8e4..717346157322 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -45,12 +45,15 @@ pub fn create_aggregate_expr( ordering_req: &[PhysicalSortExpr], schema: &Schema, name: impl Into, + ignore_nulls: bool, ) -> Result> { let input_exprs_types = input_phy_exprs .iter() .map(|arg| arg.data_type(schema)) .collect::>>()?; + let requirement_satisfied = ordering_req.is_empty(); + Ok(Arc::new(AggregateFunctionExpr { fun: fun.clone(), args: input_phy_exprs.to_vec(), @@ -59,6 +62,8 @@ pub fn create_aggregate_expr( schema: schema.clone(), sort_exprs: sort_exprs.to_vec(), ordering_req: ordering_req.to_vec(), + ignore_nulls, + requirement_satisfied, })) } @@ -120,11 +125,11 @@ pub fn create_aggregate_expr_first_value( let name: String = name.into(); - let input_data_type = fun.return_type(&input_exprs_types)?; + let return_type = fun.return_type(&input_exprs_types)?; let value_field = Field::new( format_state_name(&name, "first_value"), - input_data_type.clone(), + return_type.clone(), true, ); let ordering_fields = ordering_fields(&ordering_req, &ordering_types); @@ -132,13 +137,15 @@ pub fn create_aggregate_expr_first_value( let state_fields = fun.state_fields(value_field, ordering_fields)?; let first_value = expressions::FirstValueUDF::new( + fun.clone(), args[0].clone(), - // input_phy_exprs[0].clone(), name, - input_data_type, + return_type, ordering_req.to_vec(), ordering_types, state_fields, + sort_exprs.to_vec(), + schema.clone(), ) .with_ignore_nulls(ignore_nulls); return Ok(Arc::new(first_value)); @@ -177,6 +184,8 @@ pub struct AggregateFunctionExpr { sort_exprs: Vec, // The physical order by expressions ordering_req: LexOrdering, + ignore_nulls: bool, + requirement_satisfied: bool, } impl AggregateFunctionExpr { @@ -220,13 +229,13 @@ impl AggregateExpr for AggregateFunctionExpr { fn create_accumulator(&self) -> Result> { self.fun - .accumulator(&self.data_type, self.sort_exprs.as_slice(), &self.schema) + .accumulator(&self.data_type, self.sort_exprs.as_slice(), &self.schema, self.ignore_nulls, self.requirement_satisfied) } fn create_sliding_accumulator(&self) -> Result> { let accumulator = self.fun - .accumulator(&self.data_type, &self.sort_exprs, &self.schema)?; + .accumulator(&self.data_type, &self.sort_exprs, &self.schema, self.ignore_nulls, self.requirement_satisfied)?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index e0ad9363051f..c5c845614c7b 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -103,6 +103,7 @@ pub fn create_window_expr( ordering_req, input_schema, name, + ignore_nulls, )?; window_expr_from_aggregate_expr( partition_by, From 7f54141faf41db49c47e08f8909cf5fd5448289a Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 17:08:09 +0800 Subject: [PATCH 24/46] fmt Signed-off-by: jayzhan211 --- datafusion/expr/src/expr_fn.rs | 4 +++- datafusion/expr/src/function.rs | 4 +++- datafusion/expr/src/udaf.rs | 16 ++++++++++++-- .../physical-expr/src/aggregate/first_last.rs | 21 ++++++++++++------- .../physical-expr/src/expressions/mod.rs | 4 +++- datafusion/physical-plan/src/udaf.rs | 19 ++++++++++++----- datafusion/proto/src/physical_plan/mod.rs | 3 ++- 7 files changed, 53 insertions(+), 18 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 25f722076aa6..d4c9874a0794 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -21,7 +21,9 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Placeholder, ScalarFunction, TryCast, }; -use crate::function::{AccumulatorFactoryFunctionForFirstValue, PartitionEvaluatorFactory}; +use crate::function::{ + AccumulatorFactoryFunctionForFirstValue, PartitionEvaluatorFactory, +}; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index ef577627a5f9..bf5385e04a70 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -47,7 +47,9 @@ pub type AccumulatorFactoryFunction = Arc< /// Factory that returns an accumulator for the given aggregate, given /// its return datatype, the sorting expressions and the schema for ordering. pub type AccumulatorFactoryFunctionForFirstValue = Arc< - dyn Fn(&DataType, &[Expr], &Schema, bool, bool) -> Result> + Send + Sync, + dyn Fn(&DataType, &[Expr], &Schema, bool, bool) -> Result> + + Send + + Sync, >; /// Factory that creates a PartitionEvaluator for the given window diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b76fff7ad005..7d8458fd0570 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -177,7 +177,13 @@ impl AggregateUDF { ignore_nulls: bool, requirement_satisfied: bool, ) -> Result> { - self.inner.accumulator(return_type, sort_exprs, schema, ignore_nulls, requirement_satisfied) + self.inner.accumulator( + return_type, + sort_exprs, + schema, + ignore_nulls, + requirement_satisfied, + ) } /// Return the type of the intermediate state used by this aggregator, given @@ -388,7 +394,13 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { ignore_nulls: bool, requirement_satisfied: bool, ) -> Result> { - self.inner.accumulator(arg, sort_exprs, schema, ignore_nulls, requirement_satisfied) + self.inner.accumulator( + arg, + sort_exprs, + schema, + ignore_nulls, + requirement_satisfied, + ) } fn state_type(&self, return_type: &DataType) -> Result> { diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index d555628632ef..5bc20e27f392 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -36,8 +36,6 @@ use datafusion_common::{ }; use datafusion_expr::{Accumulator, AggregateUDF, Expr}; - - // Equivalent to AggregateFunctionExpr #[derive(Debug, Clone)] pub struct FirstValueUDF { @@ -152,7 +150,13 @@ impl AggregateExpr for FirstValueUDF { } fn create_accumulator(&self) -> Result> { - self.fun.accumulator(&self.return_type, &self.sort_exprs, &self.schema, self.ignore_nulls, self.requirement_satisfied) + self.fun.accumulator( + &self.return_type, + &self.sort_exprs, + &self.schema, + self.ignore_nulls, + self.requirement_satisfied, + ) } fn state_fields(&self) -> Result> { @@ -220,7 +224,6 @@ impl PartialEq for FirstValueUDF { } } - /// FIRST_VALUE aggregate expression #[derive(Debug, Clone)] pub struct FirstValue { @@ -614,9 +617,13 @@ pub fn create_first_value_accumulator( .map(|e| e.expr.data_type(schema)) .collect::>>()?; - - FirstValueAccumulator::try_new(data_type, &ordering_dtypes, ordering_req, ignore_nulls) - .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) + FirstValueAccumulator::try_new( + data_type, + &ordering_dtypes, + ordering_req, + ignore_nulls, + ) + .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) // Ok(Box::new(acc)) } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 8a4b1bda67f8..62b03a727581 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -53,7 +53,9 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; -pub use crate::aggregate::first_last::{FirstValue, FirstValueUDF, FirstValueAccumulator, LastValue}; +pub use crate::aggregate::first_last::{ + FirstValue, FirstValueAccumulator, FirstValueUDF, LastValue, +}; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 717346157322..f97452e4d7b6 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -228,14 +228,23 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_accumulator(&self) -> Result> { - self.fun - .accumulator(&self.data_type, self.sort_exprs.as_slice(), &self.schema, self.ignore_nulls, self.requirement_satisfied) + self.fun.accumulator( + &self.data_type, + self.sort_exprs.as_slice(), + &self.schema, + self.ignore_nulls, + self.requirement_satisfied, + ) } fn create_sliding_accumulator(&self) -> Result> { - let accumulator = - self.fun - .accumulator(&self.data_type, &self.sort_exprs, &self.schema, self.ignore_nulls, self.requirement_satisfied)?; + let accumulator = self.fun.accumulator( + &self.data_type, + &self.sort_exprs, + &self.schema, + self.ignore_nulls, + self.requirement_satisfied, + )?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index a45b08c333da..60707e3e3bf1 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -520,7 +520,8 @@ impl AsExecutionPlan for PhysicalPlanNode { // TODO: `order by` is not supported for UDAF yet let sort_exprs = &[]; let ordering_req = &[]; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name) + let ignore_nulls = false; + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls) } } }).transpose()?.ok_or_else(|| { From 6339535390e2371592968c7daa92fe4a4046d88a Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 17:28:19 +0800 Subject: [PATCH 25/46] cleanup Signed-off-by: jayzhan211 --- .../user_defined/user_defined_aggregates.rs | 2 ++ .../physical-expr/src/aggregate/first_last.rs | 22 ++----------------- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 16116d2fcf58..184ed99c3e02 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -820,6 +820,8 @@ impl AggregateUDFImpl for TestGroupsAccumulator { _arg: &DataType, _sort_exprs: &[Expr], _schema: &Schema, + _ignore_nulls: bool, + _requirement_satisfied: bool, ) -> Result> { // should use groups accumulator panic!("accumulator shouldn't invoke"); diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 5bc20e27f392..5bfd225f3a36 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -65,7 +65,7 @@ impl FirstValueUDF { sort_exprs: Vec, schema: Schema, ) -> Self { - let requirement_satisfied = ordering_req.is_empty(); + let requirement_satisfied = sort_exprs.is_empty(); Self { fun, name: name.into(), @@ -160,25 +160,7 @@ impl AggregateExpr for FirstValueUDF { } fn state_fields(&self) -> Result> { - if !self.state_fields.is_empty() { - return Ok(self.state_fields.clone()); - } - - let mut fields = vec![Field::new( - format_state_name(&self.name, "first_value"), - self.return_type.clone(), - true, - )]; - fields.extend(ordering_fields( - &self.ordering_req, - &self.order_by_data_types, - )); - fields.push(Field::new( - format_state_name(&self.name, "is_set"), - DataType::Boolean, - true, - )); - Ok(fields) + Ok(self.state_fields.clone()) } fn expressions(&self) -> Vec> { From 33ae6eea479caae11bab189570ba8dce7254970d Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 17:34:22 +0800 Subject: [PATCH 26/46] small fix Signed-off-by: jayzhan211 --- datafusion/proto/tests/cases/roundtrip_physical_plan.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 33dbb92d8d3a..a032aaa451f3 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -436,6 +436,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { &[], &schema, "example_agg", + false, )?]; roundtrip_test_with_context( From b4eb865d0ce3d7a7232658908ed149c935d584cb Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 20:32:44 +0800 Subject: [PATCH 27/46] remove ordering types Signed-off-by: jayzhan211 --- Cargo.toml | 2 +- datafusion-examples/examples/advanced_udaf.rs | 2 + datafusion/aggregate-functions/src/lib.rs | 9 ++- datafusion/core/src/physical_planner.rs | 2 - .../physical-expr/src/aggregate/first_last.rs | 45 ++++++-------- datafusion/physical-plan/src/udaf.rs | 60 ++++++++++--------- 6 files changed, 58 insertions(+), 62 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8073d056d712..aa88da9909dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,13 +73,13 @@ chrono = { version = "0.4.34", default-features = false } ctor = "0.2.0" dashmap = "5.4.0" datafusion = { path = "datafusion/core", version = "36.0.0", default-features = false } +datafusion-aggrefate-functions = { path = "datafusion/aggreagate-functions", version = "36.0.0" } datafusion-common = { path = "datafusion/common", version = "36.0.0", default-features = false } datafusion-common-runtime = { path = "datafusion/common-runtime", version = "36.0.0" } datafusion-execution = { path = "datafusion/execution", version = "36.0.0" } datafusion-expr = { path = "datafusion/expr", version = "36.0.0" } datafusion-functions = { path = "datafusion/functions", version = "36.0.0" } datafusion-functions-array = { path = "datafusion/functions-array", version = "36.0.0" } -datafusion-aggrefate-functions = { path = "datafusion/aggreagate-functions", version = "36.0.0" } datafusion-optimizer = { path = "datafusion/optimizer", version = "36.0.0", default-features = false } datafusion-physical-expr = { path = "datafusion/physical-expr", version = "36.0.0", default-features = false } datafusion-physical-plan = { path = "datafusion/physical-plan", version = "36.0.0" } diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 06995863a245..26476996f50b 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -91,6 +91,8 @@ impl AggregateUDFImpl for GeoMeanUdaf { _arg: &DataType, _sort_exprs: &[Expr], _schema: &Schema, + _ignore_nulls: bool, + _requirement_satisfied: bool, ) -> Result> { Ok(Box::new(GeometricMean::new())) } diff --git a/datafusion/aggregate-functions/src/lib.rs b/datafusion/aggregate-functions/src/lib.rs index 02b0007ad121..b7d9c168e504 100644 --- a/datafusion/aggregate-functions/src/lib.rs +++ b/datafusion/aggregate-functions/src/lib.rs @@ -31,16 +31,15 @@ mod core; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; -use datafusion_expr::AggregateUDF; -use log::debug; -use std::sync::Arc; +// use datafusion_expr::AggregateUDF; +// use std::sync::Arc; /// Fluent-style API for creating `Expr`s pub mod expr_fn {} /// Registers all enabled packages with a [`FunctionRegistry`] -pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { - let functions: Vec> = vec![]; +pub fn register_all(_registry: &mut dyn FunctionRegistry) -> Result<()> { + // let functions: Vec> = vec![]; // functions.into_iter().try_for_each(|udf| { // let existing_udf = registry.register_udf(udf)?; // if let Some(existing_udf) = existing_udf { diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index f5aa9cf93fe8..47e7d9c9d0ff 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1734,11 +1734,9 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( let agg_expr = udaf::create_aggregate_expr_first_value( fun, origin_args, - // &args, &sort_exprs, logical_input_schema, execution_props, - // &ordering_reqs, physical_input_schema, name, ignore_nulls, diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 5bfd225f3a36..7ec44f94f933 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -37,47 +37,46 @@ use datafusion_common::{ use datafusion_expr::{Accumulator, AggregateUDF, Expr}; // Equivalent to AggregateFunctionExpr +// TODO: Make it similar to AggregateFunctionExpr #[derive(Debug, Clone)] pub struct FirstValueUDF { fun: AggregateUDF, name: String, return_type: DataType, - order_by_data_types: Vec, expr: Arc, ordering_req: LexOrdering, - requirement_satisfied: bool, ignore_nulls: bool, state_fields: Vec, sort_exprs: Vec, schema: Schema, + requirement_satisfied: bool, } impl FirstValueUDF { /// Creates a new FIRST_VALUE aggregation function. + #[allow(clippy::too_many_arguments)] // TODO: Fix this clippy if better pub fn new( fun: AggregateUDF, expr: Arc, name: impl Into, return_type: DataType, ordering_req: LexOrdering, - order_by_data_types: Vec, state_fields: Vec, sort_exprs: Vec, schema: Schema, + requirement_satisfied: bool, ) -> Self { - let requirement_satisfied = sort_exprs.is_empty(); Self { fun, name: name.into(), return_type, - order_by_data_types, expr, ordering_req, - requirement_satisfied, ignore_nulls: false, state_fields, sort_exprs, schema, + requirement_satisfied, } } @@ -96,11 +95,6 @@ impl FirstValueUDF { &self.return_type } - /// Returns the data types of the order-by columns. - pub fn order_by_data_types(&self) -> &Vec { - &self.order_by_data_types - } - /// Returns the expression associated with the aggregate function. pub fn expr(&self) -> &Arc { &self.expr @@ -116,7 +110,7 @@ impl FirstValueUDF { self } - pub fn convert_to_last(self) -> LastValue { + pub fn convert_to_last(self) -> Result { let name = if self.name.starts_with("FIRST") { format!("LAST{}", &self.name[5..]) } else { @@ -126,16 +120,21 @@ impl FirstValueUDF { expr, return_type, ordering_req, - order_by_data_types, .. } = self; - LastValue::new( + + let order_by_data_types = ordering_req + .iter() + .map(|e| e.expr.data_type(&self.schema)) + .collect::>>()?; + + Ok(LastValue::new( expr, name, return_type, reverse_order_bys(&ordering_req), order_by_data_types, - ) + )) } } @@ -176,19 +175,14 @@ impl AggregateExpr for FirstValueUDF { } fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone().convert_to_last())) + self.clone() + .convert_to_last() + .ok() + .map(|l| Arc::new(l) as _) } fn create_sliding_accumulator(&self) -> Result> { - FirstValueAccumulator::try_new( - &self.return_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, - ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) + self.create_accumulator() } } @@ -199,7 +193,6 @@ impl PartialEq for FirstValueUDF { .map(|x| { self.name == x.name && self.return_type == x.return_type - && self.order_by_data_types == x.order_by_data_types && self.expr.eq(&x.expr) }) .unwrap_or(false) diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index f97452e4d7b6..868539eb9049 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -67,32 +67,8 @@ pub fn create_aggregate_expr( })) } -// TODO: Duplicated functoin from `datafusion/core/src/physical_planner.rs`, remove one of them -// TODO: Maybe move to physical-expr -/// Create a physical sort expression from a logical expression -pub fn create_physical_sort_expr( - e: &Expr, - input_dfschema: &DFSchema, - execution_props: &ExecutionProps, -) -> Result { - if let Expr::Sort(expr::Sort { - expr, - asc, - nulls_first, - }) = e - { - Ok(PhysicalSortExpr { - expr: create_physical_expr(expr, input_dfschema, execution_props)?, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } else { - internal_err!("Expects a sort expression") - } -} - +//TODO: fix this clippy +#[allow(clippy::too_many_arguments)] pub fn create_aggregate_expr_first_value( fun: &AggregateUDF, args: &[Expr], @@ -136,19 +112,21 @@ pub fn create_aggregate_expr_first_value( let state_fields = fun.state_fields(value_field, ordering_fields)?; + let requirement_satisfied = sort_exprs.is_empty(); let first_value = expressions::FirstValueUDF::new( fun.clone(), args[0].clone(), name, return_type, ordering_req.to_vec(), - ordering_types, state_fields, sort_exprs.to_vec(), schema.clone(), + requirement_satisfied, ) .with_ignore_nulls(ignore_nulls); - return Ok(Arc::new(first_value)); + + Ok(Arc::new(first_value)) } // TODO: Duplicated functoin. @@ -171,6 +149,32 @@ fn ordering_fields( .collect() } +// TODO: Duplicated functoin from `datafusion/core/src/physical_planner.rs`, remove one of them +// TODO: Maybe move to physical-expr +/// Create a physical sort expression from a logical expression +pub fn create_physical_sort_expr( + e: &Expr, + input_dfschema: &DFSchema, + execution_props: &ExecutionProps, +) -> Result { + if let Expr::Sort(expr::Sort { + expr, + asc, + nulls_first, + }) = e + { + Ok(PhysicalSortExpr { + expr: create_physical_expr(expr, input_dfschema, execution_props)?, + options: SortOptions { + descending: !asc, + nulls_first: *nulls_first, + }, + }) + } else { + internal_err!("Expects a sort expression") + } +} + /// Physical aggregate expression of a UDAF. #[derive(Debug)] pub struct AggregateFunctionExpr { From d8ab6c57b83a4556b036d0188b25a67633397810 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 21:18:05 +0800 Subject: [PATCH 28/46] make state fields more flexible Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 15 -- datafusion/expr/src/expr_fn.rs | 14 +- datafusion/expr/src/udaf.rs | 25 ++- .../physical-expr/src/aggregate/utils.rs | 2 +- datafusion/physical-plan/src/udaf.rs | 151 +++--------------- 5 files changed, 49 insertions(+), 158 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 47e7d9c9d0ff..9ade578ed82f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1729,21 +1729,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( None => None, }; - // TODO: fix the name - if fun.name() == "my_first" { - let agg_expr = udaf::create_aggregate_expr_first_value( - fun, - origin_args, - &sort_exprs, - logical_input_schema, - execution_props, - physical_input_schema, - name, - ignore_nulls, - )?; - return Ok((agg_expr, filter, order_by)); - } - let ordering_reqs: Vec = order_by.clone().unwrap_or(vec![]); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index d4c9874a0794..30af228d4aa5 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -24,6 +24,7 @@ use crate::expr::{ use crate::function::{ AccumulatorFactoryFunctionForFirstValue, PartitionEvaluatorFactory, }; +use crate::udaf::format_state_name; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, @@ -971,11 +972,16 @@ impl AggregateUDFImpl for FirstValue { fn state_fields( &self, - value_field: Field, - ordering_field: Vec, + name: &str, + value_type: DataType, + ordering_fields: Vec, ) -> Result> { - let mut fields = vec![value_field]; - fields.extend(ordering_field); + let mut fields = vec![Field::new( + format_state_name(name, "first_value"), + value_type, + true, + )]; + fields.extend(ordering_fields); fields.push(Field::new("is_set", DataType::Boolean, true)); Ok(fields) } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 7d8458fd0570..19cdf468c057 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -194,10 +194,11 @@ impl AggregateUDF { pub fn state_fields( &self, - value_field: Field, - ordering_field: Vec, + name: &str, + value_type: DataType, + ordering_fields: Vec, ) -> Result> { - self.inner.state_fields(value_field, ordering_field) + self.inner.state_fields(name, value_type, ordering_fields) } /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details. @@ -316,11 +317,18 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Default fields including the value field and ordering fields fn state_fields( &self, - value_field: Field, - ordering_field: Vec, + name: &str, + value_type: DataType, + ordering_fields: Vec, ) -> Result> { + let value_field = Field::new( + format_state_name(name, "default_state_name"), + value_type, + true, + ); + let mut fields = vec![value_field]; - fields.extend(ordering_field); + fields.extend(ordering_fields); Ok(fields) } @@ -480,3 +488,8 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { // not_impl_err!("state_fields not implemented for legacy AggregateUDF") // } } + +/// returns the name of the state +pub(crate) fn format_state_name(name: &str, state_name: &str) -> String { + format!("{name}[{state_name}]") +} diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index d3f494fe6356..296669ac3a99 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -189,7 +189,7 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { // TOOD: Remove /// Construct corresponding fields for lexicographical ordering requirement expression -pub(crate) fn ordering_fields( +pub fn ordering_fields( ordering_req: &[PhysicalSortExpr], // Data type of each expression in the ordering requirement data_types: &[DataType], diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 868539eb9049..a17209202b4d 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -17,23 +17,19 @@ //! This module contains functions and structs supporting user-defined aggregate functions. -use arrow_schema::SortOptions; -use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::{expr, Expr, GroupsAccumulator}; +use datafusion_expr::{Expr, GroupsAccumulator}; use fmt::Debug; use std::any::Any; use std::fmt; use arrow::datatypes::{DataType, Field, Schema}; -use super::{expressions::format_state_name, Accumulator, AggregateExpr}; -use datafusion_common::{internal_err, not_impl_err, DFSchema, Result}; +use super::{Accumulator, AggregateExpr}; +use datafusion_common::{not_impl_err, Result}; pub use datafusion_expr::AggregateUDF; -use datafusion_physical_expr::{ - create_physical_expr, expressions, LexOrdering, PhysicalExpr, PhysicalSortExpr, -}; +use datafusion_physical_expr::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; -use datafusion_physical_expr::aggregate::utils::down_cast_any_ref; +use datafusion_physical_expr::aggregate::utils::{down_cast_any_ref, ordering_fields}; use std::sync::Arc; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. @@ -54,6 +50,13 @@ pub fn create_aggregate_expr( let requirement_satisfied = ordering_req.is_empty(); + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(schema)) + .collect::>>()?; + + let ordering_fields = ordering_fields(ordering_req, &ordering_types); + Ok(Arc::new(AggregateFunctionExpr { fun: fun.clone(), args: input_phy_exprs.to_vec(), @@ -64,117 +67,10 @@ pub fn create_aggregate_expr( ordering_req: ordering_req.to_vec(), ignore_nulls, requirement_satisfied, + ordering_fields, })) } -//TODO: fix this clippy -#[allow(clippy::too_many_arguments)] -pub fn create_aggregate_expr_first_value( - fun: &AggregateUDF, - args: &[Expr], - // input_phy_exprs: &[Arc], - sort_exprs: &[Expr], - dfschema: &DFSchema, - execution_props: &ExecutionProps, - // ordering_req: &[PhysicalSortExpr], - schema: &Schema, - name: impl Into, - ignore_nulls: bool, -) -> Result> { - let args = args - .iter() - .map(|e| create_physical_expr(e, dfschema, execution_props)) - .collect::>>()?; - let input_exprs_types = args - .iter() - .map(|arg| arg.data_type(schema)) - .collect::>>()?; - - let ordering_req = sort_exprs - .iter() - .map(|e| create_physical_sort_expr(e, dfschema, execution_props)) - .collect::>>()?; - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(schema)) - .collect::>>()?; - - let name: String = name.into(); - - let return_type = fun.return_type(&input_exprs_types)?; - - let value_field = Field::new( - format_state_name(&name, "first_value"), - return_type.clone(), - true, - ); - let ordering_fields = ordering_fields(&ordering_req, &ordering_types); - - let state_fields = fun.state_fields(value_field, ordering_fields)?; - - let requirement_satisfied = sort_exprs.is_empty(); - let first_value = expressions::FirstValueUDF::new( - fun.clone(), - args[0].clone(), - name, - return_type, - ordering_req.to_vec(), - state_fields, - sort_exprs.to_vec(), - schema.clone(), - requirement_satisfied, - ) - .with_ignore_nulls(ignore_nulls); - - Ok(Arc::new(first_value)) -} - -// TODO: Duplicated functoin. -fn ordering_fields( - ordering_req: &[PhysicalSortExpr], - // Data type of each expression in the ordering requirement - data_types: &[DataType], -) -> Vec { - ordering_req - .iter() - .zip(data_types.iter()) - .map(|(sort_expr, dtype)| { - Field::new( - sort_expr.expr.to_string().as_str(), - dtype.clone(), - // Multi partitions may be empty hence field should be nullable. - true, - ) - }) - .collect() -} - -// TODO: Duplicated functoin from `datafusion/core/src/physical_planner.rs`, remove one of them -// TODO: Maybe move to physical-expr -/// Create a physical sort expression from a logical expression -pub fn create_physical_sort_expr( - e: &Expr, - input_dfschema: &DFSchema, - execution_props: &ExecutionProps, -) -> Result { - if let Expr::Sort(expr::Sort { - expr, - asc, - nulls_first, - }) = e - { - Ok(PhysicalSortExpr { - expr: create_physical_expr(expr, input_dfschema, execution_props)?, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } else { - internal_err!("Expects a sort expression") - } -} - /// Physical aggregate expression of a UDAF. #[derive(Debug)] pub struct AggregateFunctionExpr { @@ -190,6 +86,7 @@ pub struct AggregateFunctionExpr { ordering_req: LexOrdering, ignore_nulls: bool, requirement_satisfied: bool, + ordering_fields: Vec, } impl AggregateFunctionExpr { @@ -210,21 +107,11 @@ impl AggregateExpr for AggregateFunctionExpr { } fn state_fields(&self) -> Result> { - let fields = self - .fun - .state_type(&self.data_type)? - .iter() - .enumerate() - .map(|(i, data_type)| { - Field::new( - format_state_name(&self.name, &format!("{i}")), - data_type.clone(), - true, - ) - }) - .collect::>(); - - Ok(fields) + self.fun.state_fields( + self.name(), + self.data_type.clone(), + self.ordering_fields.clone(), + ) } fn field(&self) -> Result { From a3bff42bdeddeb2fdc984940167c72a689f5f99e Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 21:50:53 +0800 Subject: [PATCH 29/46] cleanup Signed-off-by: jayzhan211 --- Cargo.toml | 2 - datafusion/aggregate-functions/Cargo.toml | 49 -------------------- datafusion/aggregate-functions/src/core.rs | 16 ------- datafusion/aggregate-functions/src/lib.rs | 52 ---------------------- datafusion/core/src/physical_planner.rs | 35 ++++++--------- 5 files changed, 14 insertions(+), 140 deletions(-) delete mode 100644 datafusion/aggregate-functions/Cargo.toml delete mode 100644 datafusion/aggregate-functions/src/core.rs delete mode 100644 datafusion/aggregate-functions/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index aa88da9909dc..c3dade8bc6c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,6 @@ members = [ "datafusion/execution", "datafusion/functions", "datafusion/functions-array", - "datafusion/aggregate-functions", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", @@ -73,7 +72,6 @@ chrono = { version = "0.4.34", default-features = false } ctor = "0.2.0" dashmap = "5.4.0" datafusion = { path = "datafusion/core", version = "36.0.0", default-features = false } -datafusion-aggrefate-functions = { path = "datafusion/aggreagate-functions", version = "36.0.0" } datafusion-common = { path = "datafusion/common", version = "36.0.0", default-features = false } datafusion-common-runtime = { path = "datafusion/common-runtime", version = "36.0.0" } datafusion-execution = { path = "datafusion/execution", version = "36.0.0" } diff --git a/datafusion/aggregate-functions/Cargo.toml b/datafusion/aggregate-functions/Cargo.toml deleted file mode 100644 index b3be7410934e..000000000000 --- a/datafusion/aggregate-functions/Cargo.toml +++ /dev/null @@ -1,49 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "datafusion-aggregate-functions" -description = "Aggregate Function packages for the DataFusion query engine" -keywords = ["datafusion", "logical", "plan", "expressions"] -readme = "README.md" -version = { workspace = true } -edition = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } -license = { workspace = true } -authors = { workspace = true } -rust-version = { workspace = true } - -[lib] -name = "datafusion_aggregate_functions" -path = "src/lib.rs" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -arrow = { workspace = true } -arrow-array = { workspace = true } -arrow-buffer = { workspace = true } -arrow-ord = { workspace = true } -arrow-schema = { workspace = true } -datafusion-common = { workspace = true } -datafusion-execution = { workspace = true } -datafusion-expr = { workspace = true } -datafusion-functions = { workspace = true } -itertools = { version = "0.12", features = ["use_std"] } -log = { workspace = true } -paste = "1.0.14" diff --git a/datafusion/aggregate-functions/src/core.rs b/datafusion/aggregate-functions/src/core.rs deleted file mode 100644 index b248758bc120..000000000000 --- a/datafusion/aggregate-functions/src/core.rs +++ /dev/null @@ -1,16 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. diff --git a/datafusion/aggregate-functions/src/lib.rs b/datafusion/aggregate-functions/src/lib.rs deleted file mode 100644 index b7d9c168e504..000000000000 --- a/datafusion/aggregate-functions/src/lib.rs +++ /dev/null @@ -1,52 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Array Functions for [DataFusion]. -//! -//! This crate contains a collection of array functions implemented using the -//! extension API. -//! -//! [DataFusion]: https://crates.io/crates/datafusion -//! -//! You can register the functions in this crate using the [`register_all`] function. -//! -//! -//! - -mod core; - -use datafusion_common::Result; -use datafusion_execution::FunctionRegistry; -// use datafusion_expr::AggregateUDF; -// use std::sync::Arc; - -/// Fluent-style API for creating `Expr`s -pub mod expr_fn {} - -/// Registers all enabled packages with a [`FunctionRegistry`] -pub fn register_all(_registry: &mut dyn FunctionRegistry) -> Result<()> { - // let functions: Vec> = vec![]; - // functions.into_iter().try_for_each(|udf| { - // let existing_udf = registry.register_udf(udf)?; - // if let Some(existing_udf) = existing_udf { - // debug!("Overwrite existing UDF: {}", existing_udf.name()); - // } - // Ok(()) as Result<()> - // })?; - - Ok(()) -} diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 9ade578ed82f..9f8428a8ebce 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1661,11 +1661,16 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( Expr::AggregateFunction(AggregateFunction { func_def, distinct, - args: origin_args, + args, filter, order_by, null_treatment, }) => { + let args = args + .iter() + .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) + .collect::>>()?; + let filter = match filter { Some(e) => Some(create_physical_expr( e, @@ -1680,13 +1685,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { - let args = origin_args - .iter() - .map(|e| { - create_physical_expr(e, logical_input_schema, execution_props) - }) - .collect::>>()?; - let order_by = match order_by { + let physical_sort_exprs = match order_by { Some(e) => Some( e.iter() .map(|expr| { @@ -1700,7 +1699,9 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ), None => None, }; - let ordering_reqs = order_by.clone().unwrap_or(vec![]); + + let ordering_reqs: Vec = + physical_sort_exprs.clone().unwrap_or(vec![]); let agg_expr = aggregates::create_aggregate_expr( fun, *distinct, @@ -1710,11 +1711,11 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( name, ignore_nulls, )?; - (agg_expr, filter, order_by) + (agg_expr, filter, physical_sort_exprs) } AggregateFunctionDefinition::UDF(fun) => { let sort_exprs = order_by.clone().unwrap_or(vec![]); - let order_by = match order_by { + let physical_sort_exprs = match order_by { Some(e) => Some( e.iter() .map(|expr| { @@ -1730,15 +1731,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( }; let ordering_reqs: Vec = - order_by.clone().unwrap_or(vec![]); - - let args = origin_args - .iter() - .map(|e| { - create_physical_expr(e, logical_input_schema, execution_props) - }) - .collect::>>()?; - + physical_sort_exprs.clone().unwrap_or(vec![]); let agg_expr = udaf::create_aggregate_expr( fun, &args, @@ -1748,7 +1741,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( name, ignore_nulls, )?; - (agg_expr, filter, order_by) + (agg_expr, filter, physical_sort_exprs) } AggregateFunctionDefinition::Name(_) => { return internal_err!( From 53465fd009ccf83df9d07e4535bf81ef6cc459ee Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 22:08:38 +0800 Subject: [PATCH 30/46] replace done Signed-off-by: jayzhan211 --- datafusion/core/src/execution/context/mod.rs | 11 +- .../physical-expr/src/aggregate/first_last.rs | 165 +----------------- .../physical-expr/src/aggregate/utils.rs | 1 - .../physical-expr/src/expressions/mod.rs | 4 +- datafusion/sqllogictest/test_files/test1.slt | 7 +- 5 files changed, 6 insertions(+), 182 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 11e975a95c8a..f06e69bb5341 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1460,17 +1460,12 @@ impl SessionState { datafusion_functions_array::register_all(&mut new_self) .expect("can not register array expressions"); - // TODO: FIX the name - let my_first = create_first_value( - "my_first", - // vec![DataType::Int32], - // Arc::new(DataType::Int32), - // Volatility::Immutable, + let first_value = create_first_value( + "FIRST_VALUE", Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), Arc::new(create_first_value_accumulator), - // Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]), ); - let _ = new_self.register_udaf(Arc::new(my_first)); + let _ = new_self.register_udaf(Arc::new(first_value)); new_self } diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 7ec44f94f933..25e2156ff562 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -34,170 +34,7 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::{Accumulator, AggregateUDF, Expr}; - -// Equivalent to AggregateFunctionExpr -// TODO: Make it similar to AggregateFunctionExpr -#[derive(Debug, Clone)] -pub struct FirstValueUDF { - fun: AggregateUDF, - name: String, - return_type: DataType, - expr: Arc, - ordering_req: LexOrdering, - ignore_nulls: bool, - state_fields: Vec, - sort_exprs: Vec, - schema: Schema, - requirement_satisfied: bool, -} - -impl FirstValueUDF { - /// Creates a new FIRST_VALUE aggregation function. - #[allow(clippy::too_many_arguments)] // TODO: Fix this clippy if better - pub fn new( - fun: AggregateUDF, - expr: Arc, - name: impl Into, - return_type: DataType, - ordering_req: LexOrdering, - state_fields: Vec, - sort_exprs: Vec, - schema: Schema, - requirement_satisfied: bool, - ) -> Self { - Self { - fun, - name: name.into(), - return_type, - expr, - ordering_req, - ignore_nulls: false, - state_fields, - sort_exprs, - schema, - requirement_satisfied, - } - } - - pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { - self.ignore_nulls = ignore_nulls; - self - } - - /// Returns the name of the aggregate expression. - pub fn name(&self) -> &str { - &self.name - } - - /// Returns the input data type of the aggregate expression. - pub fn return_type(&self) -> &DataType { - &self.return_type - } - - /// Returns the expression associated with the aggregate function. - pub fn expr(&self) -> &Arc { - &self.expr - } - - /// Returns the lexical ordering requirements of the aggregate expression. - pub fn ordering_req(&self) -> &LexOrdering { - &self.ordering_req - } - - pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { - self.requirement_satisfied = requirement_satisfied; - self - } - - pub fn convert_to_last(self) -> Result { - let name = if self.name.starts_with("FIRST") { - format!("LAST{}", &self.name[5..]) - } else { - format!("LAST_VALUE({})", self.expr) - }; - let FirstValueUDF { - expr, - return_type, - ordering_req, - .. - } = self; - - let order_by_data_types = ordering_req - .iter() - .map(|e| e.expr.data_type(&self.schema)) - .collect::>>()?; - - Ok(LastValue::new( - expr, - name, - return_type, - reverse_order_bys(&ordering_req), - order_by_data_types, - )) - } -} - -impl AggregateExpr for FirstValueUDF { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.return_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - self.fun.accumulator( - &self.return_type, - &self.sort_exprs, - &self.schema, - self.ignore_nulls, - self.requirement_satisfied, - ) - } - - fn state_fields(&self) -> Result> { - Ok(self.state_fields.clone()) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } - - fn name(&self) -> &str { - &self.name - } - - fn reverse_expr(&self) -> Option> { - self.clone() - .convert_to_last() - .ok() - .map(|l| Arc::new(l) as _) - } - - fn create_sliding_accumulator(&self) -> Result> { - self.create_accumulator() - } -} - -impl PartialEq for FirstValueUDF { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.return_type == x.return_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} +use datafusion_expr::{Accumulator, Expr}; /// FIRST_VALUE aggregate expression #[derive(Debug, Clone)] diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index 296669ac3a99..613f6118e907 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -187,7 +187,6 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { } } -// TOOD: Remove /// Construct corresponding fields for lexicographical ordering requirement expression pub fn ordering_fields( ordering_req: &[PhysicalSortExpr], diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 62b03a727581..fcd656173355 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -53,9 +53,7 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; -pub use crate::aggregate::first_last::{ - FirstValue, FirstValueAccumulator, FirstValueUDF, LastValue, -}; +pub use crate::aggregate::first_last::{FirstValue, FirstValueAccumulator, LastValue}; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; diff --git a/datafusion/sqllogictest/test_files/test1.slt b/datafusion/sqllogictest/test_files/test1.slt index 21099576309b..ed132d2676d2 100644 --- a/datafusion/sqllogictest/test_files/test1.slt +++ b/datafusion/sqllogictest/test_files/test1.slt @@ -4,9 +4,4 @@ CREATE TABLE t AS VALUES (null::bigint), (3), (4); query I SELECT first_value(column1) FROM t; ---- -NULL - -query I -SELECT my_first(column1) FROM t; ----- -NULL +NULL \ No newline at end of file From cc214963865ae7921ec9c6a28d93e70ceee87418 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 22:23:08 +0800 Subject: [PATCH 31/46] cleanup Signed-off-by: jayzhan211 --- .../user_defined/user_defined_aggregates.rs | 102 +------------ datafusion/expr/src/expr_fn.rs | 143 +----------------- datafusion/expr/src/function.rs | 2 + datafusion/expr/src/udaf.rs | 4 +- datafusion/physical-plan/src/udaf.rs | 2 +- 5 files changed, 8 insertions(+), 245 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 184ed99c3e02..6852f3dc75be 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -45,11 +45,9 @@ use datafusion::{ }; use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; use datafusion_expr::{ - create_udaf, create_udaf_with_ordering, AggregateUDFImpl, Expr, GroupsAccumulator, - SimpleAggregateUDF, + create_udaf, AggregateUDFImpl, Expr, GroupsAccumulator, SimpleAggregateUDF, }; -use datafusion_physical_expr::expressions::{self, FirstValueAccumulator}; -use datafusion_physical_expr::{expressions::AvgAccumulator, PhysicalSortExpr}; +use datafusion_physical_expr::expressions::AvgAccumulator; /// Test to show the contents of the setup #[tokio::test] @@ -211,102 +209,6 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } -#[tokio::test] -async fn simple_udaf_order() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3, 4])), - Arc::new(Int32Array::from(vec![1, 1, 2, 2])), - ], - )?; - - let ctx = SessionContext::new(); - - let provider = MemTable::try_new(Arc::new(schema.clone()), vec![vec![batch]])?; - ctx.register_table("t", Arc::new(provider))?; - - fn create_accumulator( - data_type: &DataType, - order_by: &[Expr], - schema: &Schema, - ) -> Result> { - let mut all_sort_orders = vec![]; - - // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = vec![]; - for expr in order_by { - if let Expr::Sort(sort) = expr { - if let Expr::Column(col) = sort.expr.as_ref() { - let name = &col.name; - let e = expressions::col(name, schema)?; - sort_exprs.push(PhysicalSortExpr { - expr: e, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); - } - } - } - if !sort_exprs.is_empty() { - all_sort_orders.extend(sort_exprs); - } - - let ordering_req = all_sort_orders; - - let ordering_dtypes = ordering_req - .iter() - .map(|e| e.expr.data_type(schema)) - .collect::>>()?; - - let acc = FirstValueAccumulator::try_new( - data_type, - &ordering_dtypes, - ordering_req, - false, - )?; - Ok(Box::new(acc)) - } - - // define a udaf, using a DataFusion's accumulator - let my_first = create_udaf_with_ordering( - "my_first", - vec![DataType::Int32], - Arc::new(DataType::Int32), - Volatility::Immutable, - Arc::new(create_accumulator), - Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]), - ); - - ctx.register_udaf(my_first); - - // Should be the same as `SELECT FIRST_VALUE(a order by a) FROM t group by b order by b` - let result = ctx - .sql("SELECT MY_FIRST(a order by a desc) FROM t group by b order by b") - .await? - .collect() - .await?; - - let expected = [ - "+---------------+", - "| my_first(t.a) |", - "+---------------+", - "| 2 |", - "| 4 |", - "+---------------+", - ]; - assert_batches_eq!(expected, &result); - - Ok(()) -} - /// tests the creation, registration and usage of a UDAF #[tokio::test] async fn simple_udaf() -> Result<()> { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 30af228d4aa5..707bbf9cb9f4 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -749,52 +749,14 @@ pub fn create_udaf( )) } -/// Creates a new UDAF with a specific signature, state type and return type. -/// The signature and state type must match the `Accumulator's implementation`. -pub fn create_udaf_with_ordering( - name: &str, - input_type: Vec, - return_type: Arc, - volatility: Volatility, - accumulator: AccumulatorFactoryFunction, - state_type: Arc>, -) -> AggregateUDF { - let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); - let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); - - AggregateUDF::from(SimpleOrderedAggregateUDF::new( - name, - input_type, - return_type, - volatility, - accumulator, - state_type, - )) -} - /// Creates a new UDAF with a specific signature, state type and return type. /// The signature and state type must match the `Accumulator's implementation`. pub fn create_first_value( name: &str, - // input_type: Vec, - // return_type: Arc, - // volatility: Volatility, signature: Signature, accumulator: AccumulatorFactoryFunctionForFirstValue, - // state_type: Arc>, ) -> AggregateUDF { - // let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); - // let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); - - AggregateUDF::from(FirstValue::new( - name, - signature, - // input_type, - // return_type, - // volatility, - accumulator, - // state_type, - )) + AggregateUDF::from(FirstValue::new(name, signature, accumulator)) } /// Implements [`AggregateUDFImpl`] for functions that have a single signature and @@ -888,23 +850,12 @@ impl AggregateUDFImpl for SimpleAggregateUDF { fn state_type(&self, _return_type: &DataType) -> Result> { Ok(self.state_type.clone()) } - - // fn state_fields(&self) -> Result> { - // Ok(self - // .state_type - // .iter() - // .enumerate() - // .map(|(i, data_type)| Field::new(format!("{i}"), data_type.clone(), true)) - // .collect()) - // } } pub struct FirstValue { name: String, signature: Signature, - // return_type: DataType, accumulator: AccumulatorFactoryFunctionForFirstValue, - // state_type: Vec, } impl Debug for FirstValue { @@ -919,21 +870,14 @@ impl Debug for FirstValue { impl FirstValue { pub fn new( name: impl Into, - // input_type: Vec, - // return_type: DataType, - // volatility: Volatility, signature: Signature, accumulator: AccumulatorFactoryFunctionForFirstValue, - // state_type: Vec, ) -> Self { let name = name.into(); - // let signature = Signature::exact(input_type, volatility); Self { name, signature, - // return_type, accumulator, - // state_type, } } } @@ -987,91 +931,6 @@ impl AggregateUDFImpl for FirstValue { } } -/// Implements [`AggregateUDFImpl`] for functions that have a single signature and -/// return type. -pub struct SimpleOrderedAggregateUDF { - name: String, - signature: Signature, - return_type: DataType, - accumulator: AccumulatorFactoryFunction, - state_type: Vec, -} - -impl Debug for SimpleOrderedAggregateUDF { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("AggregateUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } -} - -impl SimpleOrderedAggregateUDF { - /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and - /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility - pub fn new( - name: impl Into, - input_type: Vec, - return_type: DataType, - volatility: Volatility, - accumulator: AccumulatorFactoryFunction, - state_type: Vec, - ) -> Self { - let name = name.into(); - let signature = Signature::exact(input_type, volatility); - Self { - name, - signature, - return_type, - accumulator, - state_type, - } - } -} - -impl AggregateUDFImpl for SimpleOrderedAggregateUDF { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - &self.name - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(self.return_type.clone()) - } - - fn accumulator( - &self, - arg: &DataType, - sort_exprs: &[Expr], - schema: &Schema, - _ignore_nulls: bool, - _requirement_satisfied: bool, - ) -> Result> { - (self.accumulator)(arg, sort_exprs, schema) - } - - fn state_type(&self, _return_type: &DataType) -> Result> { - Ok(self.state_type.clone()) - } - - // fn state_fields(&self) -> Result> { - // Ok(self - // .state_type - // .iter() - // .enumerate() - // .map(|(i, data_type)| Field::new(format!("{i}"), data_type.clone(), true)) - // .collect()) - // } -} - /// Creates a new UDWF with a specific signature, state type and return type. /// /// The signature and state type must match the [`PartitionEvaluator`]'s implementation`. diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index bf5385e04a70..5e63138a0c24 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -46,6 +46,8 @@ pub type AccumulatorFactoryFunction = Arc< /// Factory that returns an accumulator for the given aggregate, given /// its return datatype, the sorting expressions and the schema for ordering. +/// FirstValue needs additional `ignore_nulls` and `requirement_satisfied` flags. +// TODO: It would be nice if we can have flexible design for arbitrary arguments. pub type AccumulatorFactoryFunctionForFirstValue = Arc< dyn Fn(&DataType, &[Expr], &Schema, bool, bool) -> Result> + Send diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 19cdf468c057..7da55bc6265e 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -171,14 +171,14 @@ impl AggregateUDF { /// Return an accumulator the given aggregate, given its return datatype pub fn accumulator( &self, - return_type: &DataType, + arg: &DataType, sort_exprs: &[Expr], schema: &Schema, ignore_nulls: bool, requirement_satisfied: bool, ) -> Result> { self.inner.accumulator( - return_type, + arg, sort_exprs, schema, ignore_nulls, diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index a17209202b4d..9989a51f4680 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -121,7 +121,7 @@ impl AggregateExpr for AggregateFunctionExpr { fn create_accumulator(&self) -> Result> { self.fun.accumulator( &self.data_type, - self.sort_exprs.as_slice(), + &self.sort_exprs, &self.schema, self.ignore_nulls, self.requirement_satisfied, From b62544ff92ff27ab188848e0a4c017c19c0e9cd2 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 22:26:19 +0800 Subject: [PATCH 32/46] cleanup Signed-off-by: jayzhan211 --- .../core/tests/user_defined/user_defined_aggregates.rs | 2 +- datafusion/expr/src/udaf.rs | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 6852f3dc75be..2577db926fc4 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -20,7 +20,7 @@ use arrow::{array::AsArray, datatypes::Fields}; use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; -use arrow_schema::{Schema, SortOptions}; +use arrow_schema::Schema; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 7da55bc6265e..3fdad5d1bc44 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -483,13 +483,10 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { let res = (self.state_type)(return_type)?; Ok(res.as_ref().clone()) } - - // fn state_fields(&self, value_field: Field, ordering_field: Vec) -> Result> { - // not_impl_err!("state_fields not implemented for legacy AggregateUDF") - // } } /// returns the name of the state +/// TODO: Remove duplicated function in physical-expr pub(crate) fn format_state_name(name: &str, state_name: &str) -> String { format!("{name}[{state_name}]") } From 4b809b08bb744af185ff4e0b8623bcf4042acfca Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 22:31:53 +0800 Subject: [PATCH 33/46] rm comments Signed-off-by: jayzhan211 --- datafusion/expr/src/udaf.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 3fdad5d1bc44..b65f7f30935f 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -415,10 +415,6 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.inner.state_type(return_type) } - // fn state_fields(&self, value_field: Field, ordering_field: Vec) -> Result> { - // self.inner.state_fields(value_field, ordering_field) - // } - fn aliases(&self) -> &[String] { &self.aliases } From 253472706b86b06c825300d55eb3c248b92d17c2 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 22:41:29 +0800 Subject: [PATCH 34/46] cleanup Signed-off-by: jayzhan211 --- datafusion/core/tests/user_defined/user_defined_aggregates.rs | 4 ---- datafusion/expr/src/expr_fn.rs | 4 ++-- datafusion/physical-expr/src/aggregate/first_last.rs | 1 - 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 2577db926fc4..bfb93fb18389 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -740,10 +740,6 @@ impl AggregateUDFImpl for TestGroupsAccumulator { fn create_groups_accumulator(&self) -> Result> { Ok(Box::new(self.clone())) } - - // fn state_fields(&self) -> Result> { - // Ok(vec![Field::new("item", DataType::UInt64, true)]) - // } } impl Accumulator for TestGroupsAccumulator { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 00799802a155..f2e1471e0574 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -33,7 +33,7 @@ use crate::{ }; use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::{Column, Result}; +use datafusion_common::{internal_err, Column, Result}; use std::any::Any; use std::fmt::Debug; use std::ops::Not; @@ -884,7 +884,7 @@ impl AggregateUDFImpl for FirstValue { } fn state_type(&self, _return_type: &DataType) -> Result> { - unreachable!() + internal_err!("FirstValue does not have a state type") } fn state_fields( diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 0ed84f536766..bbe1b3140191 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -436,7 +436,6 @@ pub fn create_first_value_accumulator( ignore_nulls, ) .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) - // Ok(Box::new(acc)) } /// LAST_VALUE aggregate expression From 17378dd2460989751200d2f3661e629deff69fcf Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 30 Mar 2024 22:42:49 +0800 Subject: [PATCH 35/46] rm test1 Signed-off-by: jayzhan211 --- datafusion/sqllogictest/test_files/test1.slt | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 datafusion/sqllogictest/test_files/test1.slt diff --git a/datafusion/sqllogictest/test_files/test1.slt b/datafusion/sqllogictest/test_files/test1.slt deleted file mode 100644 index ed132d2676d2..000000000000 --- a/datafusion/sqllogictest/test_files/test1.slt +++ /dev/null @@ -1,7 +0,0 @@ -statement ok -CREATE TABLE t AS VALUES (null::bigint), (3), (4); - -query I -SELECT first_value(column1) FROM t; ----- -NULL \ No newline at end of file From dd1c4babe7138c2701b5f13eaac489380d4c1373 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 31 Mar 2024 09:26:34 +0800 Subject: [PATCH 36/46] fix state fields Signed-off-by: jayzhan211 --- datafusion/expr/src/udaf.rs | 22 ++++++++-------------- datafusion/physical-plan/src/udaf.rs | 27 +++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b65f7f30935f..bba3faa19cf0 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -264,7 +264,7 @@ where /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. -/// fn accumulator(&self, _arg: &DataType, _sort_exprs: &[Expr], _schema: &Schema) -> Result> { unimplemented!() } +/// fn accumulator(&self, _arg: &DataType, _sort_exprs: &[Expr], _schema: &Schema, _ignore_nulls: bool, _requirement_satisfied: bool) -> Result> { unimplemented!() } /// fn state_type(&self, _return_type: &DataType) -> Result> { /// Ok(vec![DataType::Float64, DataType::UInt32]) /// } @@ -314,22 +314,16 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// See [`Accumulator::state()`] for more details fn state_type(&self, return_type: &DataType) -> Result>; - /// Default fields including the value field and ordering fields + /// Return the fields of the intermediate state. It is mutually exclusive with [`Self::state_type`]. + /// If you define `state_type`, you don't need to define `state_fields` and vice versa. + /// If you want empty fields, you should define empty `state_type` fn state_fields( &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, + _name: &str, + _value_type: DataType, + _ordering_fields: Vec, ) -> Result> { - let value_field = Field::new( - format_state_name(name, "default_state_name"), - value_type, - true, - ); - - let mut fields = vec![value_field]; - fields.extend(ordering_fields); - Ok(fields) + Ok(vec![]) } /// If the aggregate expression has a specialized diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 9989a51f4680..9d7d1f9c4f11 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -18,6 +18,7 @@ //! This module contains functions and structs supporting user-defined aggregate functions. use datafusion_expr::{Expr, GroupsAccumulator}; +use datafusion_physical_expr::expressions::format_state_name; use fmt::Debug; use std::any::Any; use std::fmt; @@ -107,11 +108,33 @@ impl AggregateExpr for AggregateFunctionExpr { } fn state_fields(&self) -> Result> { - self.fun.state_fields( + let fields = self.fun.state_fields( self.name(), self.data_type.clone(), self.ordering_fields.clone(), - ) + )?; + + if !fields.is_empty() { + return Ok(fields) + } + + // If fields is empty, we will use the default state fields + let fields = self + .fun + .state_type(&self.data_type)? + .iter() + .enumerate() + .map(|(i, data_type)| { + Field::new( + format_state_name(&self.name, &format!("{i}")), + data_type.clone(), + true, + ) + }) + .collect::>(); + + Ok(fields) + } fn field(&self) -> Result { From 5d5d310a424958065d8062050dd1ca9a84f32d57 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 31 Mar 2024 09:31:34 +0800 Subject: [PATCH 37/46] fmt Signed-off-by: jayzhan211 --- datafusion/physical-plan/src/udaf.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 9d7d1f9c4f11..04b228e2f825 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -115,7 +115,7 @@ impl AggregateExpr for AggregateFunctionExpr { )?; if !fields.is_empty() { - return Ok(fields) + return Ok(fields); } // If fields is empty, we will use the default state fields @@ -134,7 +134,6 @@ impl AggregateExpr for AggregateFunctionExpr { .collect::>(); Ok(fields) - } fn field(&self) -> Result { From 23f20f99f57ce52dd8e03ff68ef3cc7181afbc06 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 31 Mar 2024 22:40:31 +0800 Subject: [PATCH 38/46] args struct for accumulator Signed-off-by: jayzhan211 --- datafusion-examples/examples/advanced_udaf.rs | 12 +--- datafusion-examples/examples/simple_udaf.rs | 2 +- datafusion/core/src/execution/context/mod.rs | 12 +++- datafusion/core/src/physical_planner.rs | 32 ++++------ .../user_defined/user_defined_aggregates.rs | 24 +++----- .../user_defined_scalar_functions.rs | 2 +- datafusion/expr/src/expr_fn.rs | 33 ++++------ datafusion/expr/src/function.rs | 40 ++++++++---- datafusion/expr/src/udaf.rs | 61 +++++-------------- .../optimizer/src/analyzer/type_coercion.rs | 4 +- .../optimizer/src/common_subexpr_eliminate.rs | 3 +- .../physical-expr/src/aggregate/first_last.rs | 23 ++++--- .../physical-expr/src/expressions/mod.rs | 2 +- datafusion/physical-plan/src/udaf.rs | 22 +++---- datafusion/proto/src/bytes/mod.rs | 2 +- .../tests/cases/roundtrip_logical_plan.rs | 4 +- .../tests/cases/roundtrip_physical_plan.rs | 3 +- .../tests/cases/roundtrip_logical_plan.rs | 2 +- 18 files changed, 119 insertions(+), 164 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 26476996f50b..85d4a2f757e8 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -31,7 +31,8 @@ use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::{cast::as_float64_array, ScalarValue}; use datafusion_expr::{ - Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, + function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl, + GroupsAccumulator, Signature, }; /// This example shows how to use the full AggregateUDFImpl API to implement a user @@ -86,14 +87,7 @@ impl AggregateUDFImpl for GeoMeanUdaf { /// is supported, DataFusion will use this row oriented /// accumulator when the aggregate function is used as a window function /// or when there are only aggregates (no GROUP BY columns) in the plan. - fn accumulator( - &self, - _arg: &DataType, - _sort_exprs: &[Expr], - _schema: &Schema, - _ignore_nulls: bool, - _requirement_satisfied: bool, - ) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(GeometricMean::new())) } diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 3a62e28b0568..0996a67245a8 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -150,7 +150,7 @@ async fn main() -> Result<()> { Arc::new(DataType::Float64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_, _, _| Ok(Box::new(GeometricMean::new()))), + Arc::new(|_| Ok(Box::new(GeometricMean::new()))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 4e7ae32df74f..8ae2762240d0 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -85,6 +85,7 @@ use datafusion_sql::{ use async_trait::async_trait; use chrono::{DateTime, Utc}; +use log::debug; use parking_lot::RwLock; use sqlparser::dialect::dialect_from_str; use url::Url; @@ -1465,7 +1466,16 @@ impl SessionState { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), Arc::new(create_first_value_accumulator), ); - let _ = new_self.register_udaf(Arc::new(first_value)); + + match new_self.register_udaf(Arc::new(first_value)) { + Ok(Some(existing_udaf)) => { + debug!("Overwrite existing UDF: {}", existing_udaf.name()); + } + Ok(None) => {} + Err(err) => { + panic!("Failed to register UDF: {}", err); + } + } new_self } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d68615e8a5d3..9bafc4200118 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1666,11 +1666,9 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( )?), None => None, }; - let ignore_nulls = null_treatment - .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) - == NullTreatment::IgnoreNulls; - let (agg_expr, filter, order_by) = match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { + + let try_create_physical_sort_expr = + |order_by: &Option>| -> Result>> { let physical_sort_exprs = match order_by { Some(e) => Some( e.iter() @@ -1685,7 +1683,15 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ), None => None, }; + Ok(physical_sort_exprs) + }; + let ignore_nulls = null_treatment + .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) + == NullTreatment::IgnoreNulls; + let (agg_expr, filter, order_by) = match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let physical_sort_exprs = try_create_physical_sort_expr(order_by)?; let ordering_reqs: Vec = physical_sort_exprs.clone().unwrap_or(vec![]); let agg_expr = aggregates::create_aggregate_expr( @@ -1701,21 +1707,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( } AggregateFunctionDefinition::UDF(fun) => { let sort_exprs = order_by.clone().unwrap_or(vec![]); - let physical_sort_exprs = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), - None => None, - }; - + let physical_sort_exprs = try_create_physical_sort_expr(order_by)?; let ordering_reqs: Vec = physical_sort_exprs.clone().unwrap_or(vec![]); let agg_expr = udaf::create_aggregate_expr( diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index bfb93fb18389..63dffb7161e4 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -45,7 +45,8 @@ use datafusion::{ }; use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; use datafusion_expr::{ - create_udaf, AggregateUDFImpl, Expr, GroupsAccumulator, SimpleAggregateUDF, + create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, + SimpleAggregateUDF, }; use datafusion_physical_expr::expressions::AvgAccumulator; @@ -234,7 +235,7 @@ async fn simple_udaf() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_, _, _| Ok(Box::::default())), + Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); @@ -262,7 +263,7 @@ async fn deregister_udaf() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_, _, _| Ok(Box::::default())), + Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); @@ -290,7 +291,7 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_, _, _| Ok(Box::::default())), + Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); @@ -333,7 +334,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_, _, _| Ok(Box::::default())), + Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ) .with_aliases(vec!["dummy_alias"]); @@ -497,7 +498,7 @@ impl TimeSum { let captured_state = Arc::clone(&test_state); let accumulator: AccumulatorFactoryFunction = - Arc::new(move |_, _, _| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); + Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); let time_sum = AggregateUDF::from(SimpleAggregateUDF::new( name, @@ -596,7 +597,7 @@ impl FirstSelector { let signatures = vec![TypeSignature::Exact(Self::input_datatypes())]; let accumulator: AccumulatorFactoryFunction = - Arc::new(|_, _, _| Ok(Box::new(Self::new()))); + Arc::new(|_| Ok(Box::new(Self::new()))); let volatility = Volatility::Immutable; @@ -717,14 +718,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator { Ok(DataType::UInt64) } - fn accumulator( - &self, - _arg: &DataType, - _sort_exprs: &[Expr], - _schema: &Schema, - _ignore_nulls: bool, - _requirement_satisfied: bool, - ) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { // should use groups accumulator panic!("accumulator shouldn't invoke"); } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 540c256cfe00..86be887198ae 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -291,7 +291,7 @@ async fn udaf_as_window_func() -> Result<()> { vec![DataType::Int32], Arc::new(DataType::Int32), Volatility::Immutable, - Arc::new(|_, _, _| Ok(Box::new(MyAccumulator))), + Arc::new(|_| Ok(Box::new(MyAccumulator))), Arc::new(vec![DataType::Int32]), ); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index f2e1471e0574..b4a46a22d994 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -22,17 +22,16 @@ use crate::expr::{ Placeholder, ScalarFunction, TryCast, }; use crate::function::{ - AccumulatorFactoryFunctionForFirstValue, PartitionEvaluatorFactory, + AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, }; use crate::udaf::format_state_name; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, - logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, - BuiltinScalarFunction, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, - ScalarUDF, Signature, Volatility, + logical_plan::Subquery, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, + Operator, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{internal_err, Column, Result}; use std::any::Any; use std::fmt::Debug; @@ -727,7 +726,7 @@ pub fn create_udaf( pub fn create_first_value( name: &str, signature: Signature, - accumulator: AccumulatorFactoryFunctionForFirstValue, + accumulator: AccumulatorFactoryFunction, ) -> AggregateUDF { AggregateUDF::from(FirstValue::new(name, signature, accumulator)) } @@ -744,7 +743,7 @@ pub struct SimpleAggregateUDF { impl Debug for SimpleAggregateUDF { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("AggregateUDF") + f.debug_struct("FirstValue") .field("name", &self.name) .field("signature", &self.signature) .field("fun", &"") @@ -811,13 +810,9 @@ impl AggregateUDFImpl for SimpleAggregateUDF { fn accumulator( &self, - arg: &DataType, - sort_exprs: &[Expr], - schema: &Schema, - _ignore_nulls: bool, - _requirement_satisfied: bool, + acc_args: AccumulatorArgs, ) -> Result> { - (self.accumulator)(arg, sort_exprs, schema) + (self.accumulator)(acc_args) } fn state_type(&self, _return_type: &DataType) -> Result> { @@ -828,7 +823,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { pub struct FirstValue { name: String, signature: Signature, - accumulator: AccumulatorFactoryFunctionForFirstValue, + accumulator: AccumulatorFactoryFunction, } impl Debug for FirstValue { @@ -844,7 +839,7 @@ impl FirstValue { pub fn new( name: impl Into, signature: Signature, - accumulator: AccumulatorFactoryFunctionForFirstValue, + accumulator: AccumulatorFactoryFunction, ) -> Self { let name = name.into(); Self { @@ -874,13 +869,9 @@ impl AggregateUDFImpl for FirstValue { fn accumulator( &self, - arg: &DataType, - sort_exprs: &[Expr], - schema: &Schema, - ignore_nulls: bool, - requirement_satisfied: bool, + acc_args: AccumulatorArgs, ) -> Result> { - (self.accumulator)(arg, sort_exprs, schema, ignore_nulls, requirement_satisfied) + (self.accumulator)(acc_args) } fn state_type(&self, _return_type: &DataType) -> Result> { diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 5e63138a0c24..3bddc3aef050 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -38,21 +38,37 @@ pub type ScalarFunctionImplementation = pub type ReturnTypeFunction = Arc Result> + Send + Sync>; -/// Factory that returns an accumulator for the given aggregate, given -/// its return datatype, the sorting expressions and the schema for ordering. -pub type AccumulatorFactoryFunction = Arc< - dyn Fn(&DataType, &[Expr], &Schema) -> Result> + Send + Sync, ->; +/// Arguments passed to create an accumulator +pub struct AccumulatorArgs<'a> { + // default arguments + pub data_type: &'a DataType, + pub schema: &'a Schema, + pub ignore_nulls: bool, + + // ordering arguments + pub sort_exprs: &'a [Expr], +} + +impl<'a> AccumulatorArgs<'a> { + pub fn new( + data_type: &'a DataType, + schema: &'a Schema, + ignore_nulls: bool, + sort_exprs: &'a [Expr], + ) -> Self { + Self { + data_type, + schema, + ignore_nulls, + sort_exprs, + } + } +} /// Factory that returns an accumulator for the given aggregate, given /// its return datatype, the sorting expressions and the schema for ordering. -/// FirstValue needs additional `ignore_nulls` and `requirement_satisfied` flags. -// TODO: It would be nice if we can have flexible design for arbitrary arguments. -pub type AccumulatorFactoryFunctionForFirstValue = Arc< - dyn Fn(&DataType, &[Expr], &Schema, bool, bool) -> Result> - + Send - + Sync, ->; +pub type AccumulatorFactoryFunction = + Arc Result> + Send + Sync>; /// Factory that creates a PartitionEvaluator for the given window /// function diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index bba3faa19cf0..a2f6e17f9ab7 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,12 +17,13 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions +use crate::function::AccumulatorArgs; use crate::groups_accumulator::GroupsAccumulator; use crate::{Accumulator, Expr}; use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{not_impl_err, Result}; use std::any::Any; use std::fmt::{self, Debug, Formatter}; @@ -169,21 +170,8 @@ impl AggregateUDF { } /// Return an accumulator the given aggregate, given its return datatype - pub fn accumulator( - &self, - arg: &DataType, - sort_exprs: &[Expr], - schema: &Schema, - ignore_nulls: bool, - requirement_satisfied: bool, - ) -> Result> { - self.inner.accumulator( - arg, - sort_exprs, - schema, - ignore_nulls, - requirement_satisfied, - ) + pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + self.inner.accumulator(acc_args) } /// Return the type of the intermediate state used by this aggregator, given @@ -294,21 +282,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Return a new [`Accumulator`] that aggregates values for a specific /// group during query execution. /// - /// `arg`: the type of the argument to this accumulator + /// `data_type`: the type of the argument to this accumulator /// /// `sort_exprs`: contains a list of `Expr::SortExpr`s if the /// aggregate is called with an explicit `ORDER BY`. For example, /// `ARRAY_AGG(x ORDER BY y ASC)`. In this case, `sort_exprs` would contain `[y ASC]` /// /// `schema` is the input schema to the udaf - fn accumulator( - &self, - arg: &DataType, - sort_exprs: &[Expr], - schema: &Schema, - ignore_nulls: bool, - requirement_satisfied: bool, - ) -> Result>; + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result>; /// Return the type used to serialize the [`Accumulator`]'s intermediate state. /// See [`Accumulator::state()`] for more details @@ -388,21 +369,8 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.inner.return_type(arg_types) } - fn accumulator( - &self, - arg: &DataType, - sort_exprs: &[Expr], - schema: &Schema, - ignore_nulls: bool, - requirement_satisfied: bool, - ) -> Result> { - self.inner.accumulator( - arg, - sort_exprs, - schema, - ignore_nulls, - requirement_satisfied, - ) + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + self.inner.accumulator(acc_args) } fn state_type(&self, return_type: &DataType) -> Result> { @@ -460,13 +428,14 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { fn accumulator( &self, - arg: &DataType, - sort_exprs: &[Expr], - schema: &Schema, - _ignore_nulls: bool, - _requirement_satisfied: bool, + acc_args: AccumulatorArgs, + // data_type: &DataType, + // schema: &Schema, + // _sort_exprs: &[Expr], + // _ignore_nulls: bool, ) -> Result> { - (self.accumulator)(arg, sort_exprs, schema) + // let acc_args = AccumulatorArgs::new(data_type, schema); + (self.accumulator)(acc_args) } fn state_type(&self, return_type: &DataType) -> Result> { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 147706b636f6..6e3e14ea0fdf 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -892,7 +892,7 @@ mod test { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_, _, _| Ok(Box::::default())), + Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( @@ -914,7 +914,7 @@ mod test { let return_type = DataType::Float64; let state_type = vec![DataType::UInt64, DataType::Float64]; let accumulator: AccumulatorFactoryFunction = - Arc::new(|_, _, _| Ok(Box::::default())); + Arc::new(|_| Ok(Box::::default())); let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( "MY_AVG", Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 75c71d0aa298..c5dc64416d7b 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -963,8 +963,7 @@ mod test { let table_scan = test_table_scan()?; let return_type = DataType::UInt32; - let accumulator: AccumulatorFactoryFunction = - Arc::new(|_, _, _| unimplemented!()); + let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); let state_type = vec![DataType::UInt32]; let udf_agg = |inner: Expr| { Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index bbe1b3140191..26bd219f65f0 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -29,11 +29,12 @@ use crate::{ use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; use arrow::compute::{self, lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field}; -use arrow_schema::{Schema, SortOptions}; +use arrow_schema::SortOptions; use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{Accumulator, Expr}; /// FIRST_VALUE aggregate expression @@ -218,7 +219,7 @@ impl PartialEq for FirstValue { } #[derive(Debug)] -pub struct FirstValueAccumulator { +struct FirstValueAccumulator { first: ScalarValue, // At the beginning, `is_set` is false, which means `first` is not seen yet. // Once we see the first value, we set the `is_set` flag and do not update `first` anymore. @@ -393,21 +394,17 @@ impl Accumulator for FirstValueAccumulator { } pub fn create_first_value_accumulator( - data_type: &DataType, - order_by: &[Expr], - schema: &Schema, - ignore_nulls: bool, - requirement_satisfied: bool, + acc_args: AccumulatorArgs, ) -> Result> { let mut all_sort_orders = vec![]; // Construct PhysicalSortExpr objects from Expr objects: let mut sort_exprs = vec![]; - for expr in order_by { + for expr in acc_args.sort_exprs { if let Expr::Sort(sort) = expr { if let Expr::Column(col) = sort.expr.as_ref() { let name = &col.name; - let e = expressions::col(name, schema)?; + let e = expressions::col(name, acc_args.schema)?; sort_exprs.push(PhysicalSortExpr { expr: e, options: SortOptions { @@ -426,14 +423,16 @@ pub fn create_first_value_accumulator( let ordering_dtypes = ordering_req .iter() - .map(|e| e.expr.data_type(schema)) + .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; + let requirement_satisfied = ordering_req.is_empty(); + FirstValueAccumulator::try_new( - data_type, + acc_args.data_type, &ordering_dtypes, ordering_req, - ignore_nulls, + acc_args.ignore_nulls, ) .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index fcd656173355..7c4ea07dfbcb 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -53,7 +53,7 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; -pub use crate::aggregate::first_last::{FirstValue, FirstValueAccumulator, LastValue}; +pub use crate::aggregate::first_last::{FirstValue, LastValue}; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 04b228e2f825..46d4604ba4f7 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -17,6 +17,7 @@ //! This module contains functions and structs supporting user-defined aggregate functions. +use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{Expr, GroupsAccumulator}; use datafusion_physical_expr::expressions::format_state_name; use fmt::Debug; @@ -49,8 +50,6 @@ pub fn create_aggregate_expr( .map(|arg| arg.data_type(schema)) .collect::>>()?; - let requirement_satisfied = ordering_req.is_empty(); - let ordering_types = ordering_req .iter() .map(|e| e.expr.data_type(schema)) @@ -67,7 +66,6 @@ pub fn create_aggregate_expr( sort_exprs: sort_exprs.to_vec(), ordering_req: ordering_req.to_vec(), ignore_nulls, - requirement_satisfied, ordering_fields, })) } @@ -86,7 +84,6 @@ pub struct AggregateFunctionExpr { // The physical order by expressions ordering_req: LexOrdering, ignore_nulls: bool, - requirement_satisfied: bool, ordering_fields: Vec, } @@ -141,23 +138,18 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_accumulator(&self) -> Result> { - self.fun.accumulator( + let acc_args = AccumulatorArgs::new( &self.data_type, - &self.sort_exprs, &self.schema, self.ignore_nulls, - self.requirement_satisfied, - ) + &self.sort_exprs, + ); + + self.fun.accumulator(acc_args) } fn create_sliding_accumulator(&self) -> Result> { - let accumulator = self.fun.accumulator( - &self.data_type, - &self.sort_exprs, - &self.schema, - self.ignore_nulls, - self.requirement_satisfied, - )?; + let accumulator = self.create_accumulator()?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 4c570d343574..610c533d574c 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -127,7 +127,7 @@ impl Serializeable for Expr { vec![arrow::datatypes::DataType::Null], Arc::new(arrow::datatypes::DataType::Null), Volatility::Immutable, - Arc::new(|_, _, _| unimplemented!()), + Arc::new(|_| unimplemented!()), Arc::new(vec![]), ))) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 4cd3eb9286d6..8a8bfe6112d3 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1762,7 +1762,7 @@ fn roundtrip_aggregate_udf() { Arc::new(DataType::Float64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_, _, _| Ok(Box::new(Dummy {}))), + Arc::new(|_| Ok(Box::new(Dummy {}))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); @@ -1989,7 +1989,7 @@ fn roundtrip_window() { Arc::new(DataType::Float64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_, _, _| Ok(Box::new(DummyAggr {}))), + Arc::new(|_| Ok(Box::new(DummyAggr {}))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index a032aaa451f3..e95e033a878d 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -411,8 +411,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { } let return_type = DataType::Int64; - let accumulator: AccumulatorFactoryFunction = - Arc::new(|_, _, _| Ok(Box::new(Example))); + let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::new(Example))); let state_type = vec![DataType::Int64]; let udaf = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index a24a84ad76be..bc9cc66b7626 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -750,7 +750,7 @@ async fn roundtrip_aggregate_udf() -> Result<()> { Arc::new(DataType::Int64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_, _, _| Ok(Box::new(Dummy {}))), + Arc::new(|_| Ok(Box::new(Dummy {}))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); From b2ba8c3f4ec0851ac640d1be74d07c068423feeb Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 31 Mar 2024 22:46:20 +0800 Subject: [PATCH 39/46] simplify Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 9bafc4200118..a3bdaf51c1dc 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1653,11 +1653,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( order_by, null_treatment, }) => { - let args = args - .iter() - .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) - .collect::>>()?; - + let args = + create_physical_exprs(args, logical_input_schema, execution_props)?; let filter = match filter { Some(e) => Some(create_physical_expr( e, From 75aa2fe91f392638a49f865613e5a3316978b7e4 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 31 Mar 2024 22:48:50 +0800 Subject: [PATCH 40/46] add sig Signed-off-by: jayzhan211 --- datafusion/expr/src/expr_fn.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index b4a46a22d994..92bcd206fdf7 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -743,7 +743,7 @@ pub struct SimpleAggregateUDF { impl Debug for SimpleAggregateUDF { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("FirstValue") + f.debug_struct("AggregateUDF") .field("name", &self.name) .field("signature", &self.signature) .field("fun", &"") @@ -828,8 +828,9 @@ pub struct FirstValue { impl Debug for FirstValue { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("AggregateUDF") + f.debug_struct("FirstValue") .field("name", &self.name) + .field("signature", &self.signature) .field("fun", &"") .finish() } From 5b9625f6a2b9c1052c22dd23c8b4de43fd054312 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 31 Mar 2024 22:54:30 +0800 Subject: [PATCH 41/46] add comments Signed-off-by: jayzhan211 --- datafusion/expr/src/function.rs | 11 +++++------ datafusion/expr/src/udaf.rs | 5 ----- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 3bddc3aef050..cfc581f816f1 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -41,12 +41,12 @@ pub type ReturnTypeFunction = /// Arguments passed to create an accumulator pub struct AccumulatorArgs<'a> { // default arguments - pub data_type: &'a DataType, - pub schema: &'a Schema, - pub ignore_nulls: bool, + pub data_type: &'a DataType, // the return type of the function + pub schema: &'a Schema, // the schema of the input arguments + pub ignore_nulls: bool, // whether to ignore nulls // ordering arguments - pub sort_exprs: &'a [Expr], + pub sort_exprs: &'a [Expr], // the expressions of `order by` } impl<'a> AccumulatorArgs<'a> { @@ -65,8 +65,7 @@ impl<'a> AccumulatorArgs<'a> { } } -/// Factory that returns an accumulator for the given aggregate, given -/// its return datatype, the sorting expressions and the schema for ordering. +/// Factory that returns an accumulator for the given aggregate function. pub type AccumulatorFactoryFunction = Arc Result> + Send + Sync>; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index a2f6e17f9ab7..bebc108f606e 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -429,12 +429,7 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { fn accumulator( &self, acc_args: AccumulatorArgs, - // data_type: &DataType, - // schema: &Schema, - // _sort_exprs: &[Expr], - // _ignore_nulls: bool, ) -> Result> { - // let acc_args = AccumulatorArgs::new(data_type, schema); (self.accumulator)(acc_args) } From d5c3f6f76c1e343b45e3360d915efcb08f584dad Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 31 Mar 2024 22:56:24 +0800 Subject: [PATCH 42/46] fmt Signed-off-by: jayzhan211 --- datafusion/expr/src/function.rs | 8 ++++---- datafusion/expr/src/udaf.rs | 5 +---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index cfc581f816f1..d267e15b88c2 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -41,12 +41,12 @@ pub type ReturnTypeFunction = /// Arguments passed to create an accumulator pub struct AccumulatorArgs<'a> { // default arguments - pub data_type: &'a DataType, // the return type of the function - pub schema: &'a Schema, // the schema of the input arguments - pub ignore_nulls: bool, // whether to ignore nulls + pub data_type: &'a DataType, // the return type of the function + pub schema: &'a Schema, // the schema of the input arguments + pub ignore_nulls: bool, // whether to ignore nulls // ordering arguments - pub sort_exprs: &'a [Expr], // the expressions of `order by` + pub sort_exprs: &'a [Expr], // the expressions of `order by` } impl<'a> AccumulatorArgs<'a> { diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index bebc108f606e..d61147785b1e 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -426,10 +426,7 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { Ok(res.as_ref().clone()) } - fn accumulator( - &self, - acc_args: AccumulatorArgs, - ) -> Result> { + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { (self.accumulator)(acc_args) } From dc9549a2800b011b8ca7c06d954db51e089854cf Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 1 Apr 2024 08:18:58 +0800 Subject: [PATCH 43/46] fix docs Signed-off-by: jayzhan211 --- datafusion/expr/src/udaf.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index d61147785b1e..4a467f571dbe 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -225,7 +225,7 @@ where /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr}; -/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator}; +/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs}; /// # use arrow::datatypes::Schema; /// #[derive(Debug, Clone)] /// struct GeoMeanUdf { @@ -252,7 +252,7 @@ where /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. -/// fn accumulator(&self, _arg: &DataType, _sort_exprs: &[Expr], _schema: &Schema, _ignore_nulls: bool, _requirement_satisfied: bool) -> Result> { unimplemented!() } +/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } /// fn state_type(&self, _return_type: &DataType) -> Result> { /// Ok(vec![DataType::Float64, DataType::UInt32]) /// } From 49b4a7657b2ac17a8fba04ac593b08be65782379 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 1 Apr 2024 19:35:40 +0800 Subject: [PATCH 44/46] use exprs utils Signed-off-by: jayzhan211 --- datafusion/core/src/execution/context/mod.rs | 4 +-- datafusion/core/src/physical_planner.rs | 37 +++++++++----------- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 8ae2762240d0..e66486cfa4e3 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1469,11 +1469,11 @@ impl SessionState { match new_self.register_udaf(Arc::new(first_value)) { Ok(Some(existing_udaf)) => { - debug!("Overwrite existing UDF: {}", existing_udaf.name()); + debug!("Overwrite existing UDAF: {}", existing_udaf.name()); } Ok(None) => {} Err(err) => { - panic!("Failed to register UDF: {}", err); + panic!("Failed to register UDAF: {}", err); } } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index a3bdaf51c1dc..6e12ce6d5ad9 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1664,31 +1664,19 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( None => None, }; - let try_create_physical_sort_expr = - |order_by: &Option>| -> Result>> { - let physical_sort_exprs = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), - None => None, - }; - Ok(physical_sort_exprs) - }; - let ignore_nulls = null_treatment .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { - let physical_sort_exprs = try_create_physical_sort_expr(order_by)?; + let physical_sort_exprs = match order_by { + Some(exprs) => Some(create_physical_sort_exprs( + exprs, + logical_input_schema, + execution_props, + )?), + None => None, + }; let ordering_reqs: Vec = physical_sort_exprs.clone().unwrap_or(vec![]); let agg_expr = aggregates::create_aggregate_expr( @@ -1704,7 +1692,14 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( } AggregateFunctionDefinition::UDF(fun) => { let sort_exprs = order_by.clone().unwrap_or(vec![]); - let physical_sort_exprs = try_create_physical_sort_expr(order_by)?; + let physical_sort_exprs = match order_by { + Some(exprs) => Some(create_physical_sort_exprs( + exprs, + logical_input_schema, + execution_props, + )?), + None => None, + }; let ordering_reqs: Vec = physical_sort_exprs.clone().unwrap_or(vec![]); let agg_expr = udaf::create_aggregate_expr( From d70cce54c7bf2fdd0516493f845dc592021793b6 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 2 Apr 2024 08:55:44 +0800 Subject: [PATCH 45/46] rm state type Signed-off-by: jayzhan211 --- datafusion-examples/examples/advanced_udaf.rs | 15 ++-- .../user_defined/user_defined_aggregates.rs | 15 ++-- datafusion/expr/src/expr_fn.rs | 35 +++++---- datafusion/expr/src/function.rs | 12 ++-- datafusion/expr/src/udaf.rs | 71 ++++++++----------- .../optimizer/src/analyzer/type_coercion.rs | 8 ++- .../optimizer/src/common_subexpr_eliminate.rs | 3 +- datafusion/physical-plan/src/udaf.rs | 26 +------ .../tests/cases/roundtrip_physical_plan.rs | 3 +- 9 files changed, 86 insertions(+), 102 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 85d4a2f757e8..342a23b6e73d 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::Schema; +use arrow_schema::{Field, Schema}; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; use datafusion_physical_expr::NullState; use std::{any::Any, sync::Arc}; @@ -92,8 +92,16 @@ impl AggregateUDFImpl for GeoMeanUdaf { } /// This is the description of the state. accumulator's state() must match the types here. - fn state_type(&self, _return_type: &DataType) -> Result> { - Ok(vec![DataType::Float64, DataType::UInt32]) + fn state_fields( + &self, + _name: &str, + value_type: DataType, + _ordering_fields: Vec, + ) -> Result> { + Ok(vec![ + Field::new("prod", value_type, true), + Field::new("n", DataType::UInt32, true), + ]) } /// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator` @@ -193,7 +201,6 @@ impl Accumulator for GeometricMean { // create local session context with an in-memory table fn create_context() -> Result { - use datafusion::arrow::datatypes::Field; use datafusion::datasource::MemTable; // define a schema. let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 63dffb7161e4..6085fca8761f 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -492,7 +492,7 @@ impl TimeSum { // Returns the same type as its input let return_type = timestamp_type.clone(); - let state_type = vec![timestamp_type.clone()]; + let state_fields = vec![Field::new("sum", timestamp_type, true)]; let volatility = Volatility::Immutable; @@ -506,7 +506,7 @@ impl TimeSum { return_type, volatility, accumulator, - state_type, + state_fields, )); // register the selector as "time_sum" @@ -592,6 +592,11 @@ impl FirstSelector { fn register(ctx: &mut SessionContext) { let return_type = Self::output_datatype(); let state_type = Self::state_datatypes(); + let state_fields = state_type + .into_iter() + .enumerate() + .map(|(i, t)| Field::new(format!("{i}"), t, true)) + .collect::>(); // Possible input signatures let signatures = vec![TypeSignature::Exact(Self::input_datatypes())]; @@ -608,7 +613,7 @@ impl FirstSelector { Signature::one_of(signatures, volatility), return_type, accumulator, - state_type, + state_fields, )); // register the selector as "first" @@ -723,10 +728,6 @@ impl AggregateUDFImpl for TestGroupsAccumulator { panic!("accumulator shouldn't invoke"); } - fn state_type(&self, _return_type: &DataType) -> Result> { - Ok(vec![DataType::UInt64]) - } - fn groups_accumulator_supported(&self) -> bool { true } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index fb5f35612108..8a7c3f70bf6c 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -32,7 +32,7 @@ use crate::{ }; use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{internal_err, Column, Result}; +use datafusion_common::{Column, Result}; use std::any::Any; use std::fmt::Debug; use std::ops::Not; @@ -702,18 +702,24 @@ pub fn create_udaf( ) -> AggregateUDF { let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); + let state_fields = state_type + .into_iter() + .enumerate() + .map(|(i, t)| Field::new(format!("{i}"), t, true)) + .collect::>(); AggregateUDF::from(SimpleAggregateUDF::new( name, input_type, return_type, volatility, accumulator, - state_type, + state_fields, )) } /// Creates a new UDAF with a specific signature, state type and return type. /// The signature and state type must match the `Accumulator's implementation`. +/// TOOD: We plan to move aggregate function to its own crate. This function will be deprecated then. pub fn create_first_value( name: &str, signature: Signature, @@ -729,7 +735,7 @@ pub struct SimpleAggregateUDF { signature: Signature, return_type: DataType, accumulator: AccumulatorFactoryFunction, - state_type: Vec, + state_fields: Vec, } impl Debug for SimpleAggregateUDF { @@ -751,7 +757,7 @@ impl SimpleAggregateUDF { return_type: DataType, volatility: Volatility, accumulator: AccumulatorFactoryFunction, - state_type: Vec, + state_fields: Vec, ) -> Self { let name = name.into(); let signature = Signature::exact(input_type, volatility); @@ -760,7 +766,7 @@ impl SimpleAggregateUDF { signature, return_type, accumulator, - state_type, + state_fields, } } @@ -769,7 +775,7 @@ impl SimpleAggregateUDF { signature: Signature, return_type: DataType, accumulator: AccumulatorFactoryFunction, - state_type: Vec, + state_fields: Vec, ) -> Self { let name = name.into(); Self { @@ -777,7 +783,7 @@ impl SimpleAggregateUDF { signature, return_type, accumulator, - state_type, + state_fields, } } } @@ -806,8 +812,13 @@ impl AggregateUDFImpl for SimpleAggregateUDF { (self.accumulator)(acc_args) } - fn state_type(&self, _return_type: &DataType) -> Result> { - Ok(self.state_type.clone()) + fn state_fields( + &self, + _name: &str, + _value_type: DataType, + _ordering_fields: Vec, + ) -> Result> { + Ok(self.state_fields.clone()) } } @@ -822,7 +833,7 @@ impl Debug for FirstValue { f.debug_struct("FirstValue") .field("name", &self.name) .field("signature", &self.signature) - .field("fun", &"") + .field("accumulator", &"") .finish() } } @@ -866,10 +877,6 @@ impl AggregateUDFImpl for FirstValue { (self.accumulator)(acc_args) } - fn state_type(&self, _return_type: &DataType) -> Result> { - internal_err!("FirstValue does not have a state type") - } - fn state_fields( &self, name: &str, diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index d267e15b88c2..235642682899 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -41,12 +41,16 @@ pub type ReturnTypeFunction = /// Arguments passed to create an accumulator pub struct AccumulatorArgs<'a> { // default arguments - pub data_type: &'a DataType, // the return type of the function - pub schema: &'a Schema, // the schema of the input arguments - pub ignore_nulls: bool, // whether to ignore nulls + /// the return type of the function + pub data_type: &'a DataType, + /// the schema of the input arguments + pub schema: &'a Schema, + /// whether to ignore nulls + pub ignore_nulls: bool, // ordering arguments - pub sort_exprs: &'a [Expr], // the expressions of `order by` + /// the expressions of `order by` + pub sort_exprs: &'a [Expr], } impl<'a> AccumulatorArgs<'a> { diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 4a467f571dbe..ba80f39dde43 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -20,14 +20,13 @@ use crate::function::AccumulatorArgs; use crate::groups_accumulator::GroupsAccumulator; use crate::{Accumulator, Expr}; -use crate::{ - AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, -}; +use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{not_impl_err, Result}; use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; +use std::vec; /// Logical representation of a user-defined [aggregate function] (UDAF). /// @@ -91,14 +90,12 @@ impl AggregateUDF { signature: &Signature, return_type: &ReturnTypeFunction, accumulator: &AccumulatorFactoryFunction, - state_type: &StateTypeFunction, ) -> Self { Self::new_from_impl(AggregateUDFLegacyWrapper { name: name.to_owned(), signature: signature.clone(), return_type: return_type.clone(), accumulator: accumulator.clone(), - state_type: state_type.clone(), }) } @@ -174,12 +171,9 @@ impl AggregateUDF { self.inner.accumulator(acc_args) } - /// Return the type of the intermediate state used by this aggregator, given - /// its return datatype. Supports multi-phase aggregations - pub fn state_type(&self, return_type: &DataType) -> Result> { - self.inner.state_type(return_type) - } - + /// Return the fields of the intermediate state used by this aggregator, given + /// its state name, value type and ordering fields. See [`AggregateUDFImpl::state_fields`] + /// for more details. Supports multi-phase aggregations pub fn state_fields( &self, name: &str, @@ -227,6 +221,7 @@ where /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs}; /// # use arrow::datatypes::Schema; +/// # use arrow::datatypes::Field; /// #[derive(Debug, Clone)] /// struct GeoMeanUdf { /// signature: Signature @@ -253,8 +248,11 @@ where /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. /// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } -/// fn state_type(&self, _return_type: &DataType) -> Result> { -/// Ok(vec![DataType::Float64, DataType::UInt32]) +/// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec) -> Result> { +/// Ok(vec![ +/// Field::new("value", value_type, true), +/// Field::new("ordering", DataType::UInt32, true) +/// ]) /// } /// } /// @@ -282,29 +280,29 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Return a new [`Accumulator`] that aggregates values for a specific /// group during query execution. /// - /// `data_type`: the type of the argument to this accumulator - /// - /// `sort_exprs`: contains a list of `Expr::SortExpr`s if the - /// aggregate is called with an explicit `ORDER BY`. For example, - /// `ARRAY_AGG(x ORDER BY y ASC)`. In this case, `sort_exprs` would contain `[y ASC]` - /// - /// `schema` is the input schema to the udaf + /// `acc_args`: the arguments to the accumulator. See [`AccumulatorArgs`] for more details. fn accumulator(&self, acc_args: AccumulatorArgs) -> Result>; - /// Return the type used to serialize the [`Accumulator`]'s intermediate state. - /// See [`Accumulator::state()`] for more details - fn state_type(&self, return_type: &DataType) -> Result>; - - /// Return the fields of the intermediate state. It is mutually exclusive with [`Self::state_type`]. - /// If you define `state_type`, you don't need to define `state_fields` and vice versa. - /// If you want empty fields, you should define empty `state_type` + /// Return the fields of the intermediate state. + /// + /// name: the name of the state + /// + /// value_type: the type of the value, it should be the result of the `return_type` + /// + /// ordering_fields: the fields used for ordering, empty if no ordering expression is provided fn state_fields( &self, - _name: &str, - _value_type: DataType, - _ordering_fields: Vec, + name: &str, + value_type: DataType, + ordering_fields: Vec, ) -> Result> { - Ok(vec![]) + let value_fields = vec![Field::new( + format_state_name(name, "value"), + value_type, + true, + )]; + + Ok(value_fields.into_iter().chain(ordering_fields).collect()) } /// If the aggregate expression has a specialized @@ -373,10 +371,6 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.inner.accumulator(acc_args) } - fn state_type(&self, return_type: &DataType) -> Result> { - self.inner.state_type(return_type) - } - fn aliases(&self) -> &[String] { &self.aliases } @@ -393,8 +387,6 @@ pub struct AggregateUDFLegacyWrapper { return_type: ReturnTypeFunction, /// actual implementation accumulator: AccumulatorFactoryFunction, - /// the accumulator's state's description as a function of the return type - state_type: StateTypeFunction, } impl Debug for AggregateUDFLegacyWrapper { @@ -429,11 +421,6 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { (self.accumulator)(acc_args) } - - fn state_type(&self, return_type: &DataType) -> Result> { - let res = (self.state_type)(return_type)?; - Ok(res.as_ref().clone()) - } } /// returns the name of the state diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 6e3e14ea0fdf..de1d68294e35 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -761,7 +761,7 @@ mod test { }; use crate::test::assert_analyzed_plan_eq; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; @@ -912,7 +912,6 @@ mod test { fn agg_udaf_invalid_input() -> Result<()> { let empty = empty(); let return_type = DataType::Float64; - let state_type = vec![DataType::UInt64, DataType::Float64]; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::::default())); let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( @@ -920,7 +919,10 @@ mod test { Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), return_type, accumulator, - state_type, + vec![ + Field::new("count", DataType::UInt64, true), + Field::new("avg", DataType::Float64, true), + ], )); let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 8dd530d1eae7..6dfd0ee2c5e3 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -801,7 +801,6 @@ mod test { let return_type = DataType::UInt32; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); - let state_type = vec![DataType::UInt32]; let udf_agg = |inner: Expr| { Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature( @@ -809,7 +808,7 @@ mod test { Signature::exact(vec![DataType::UInt32], Volatility::Stable), return_type.clone(), accumulator.clone(), - state_type.clone(), + vec![Field::new("value", DataType::UInt32, true)], ))), vec![inner], false, diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 46d4604ba4f7..74a5603c0c81 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -19,7 +19,6 @@ use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{Expr, GroupsAccumulator}; -use datafusion_physical_expr::expressions::format_state_name; use fmt::Debug; use std::any::Any; use std::fmt; @@ -105,32 +104,11 @@ impl AggregateExpr for AggregateFunctionExpr { } fn state_fields(&self) -> Result> { - let fields = self.fun.state_fields( + self.fun.state_fields( self.name(), self.data_type.clone(), self.ordering_fields.clone(), - )?; - - if !fields.is_empty() { - return Ok(fields); - } - - // If fields is empty, we will use the default state fields - let fields = self - .fun - .state_type(&self.data_type)? - .iter() - .enumerate() - .map(|(i, data_type)| { - Field::new( - format_state_name(&self.name, &format!("{i}")), - data_type.clone(), - true, - ) - }) - .collect::>(); - - Ok(fields) + ) } fn field(&self) -> Result { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index e95e033a878d..9ce26bc6ed9b 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -412,14 +412,13 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let return_type = DataType::Int64; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::new(Example))); - let state_type = vec![DataType::Int64]; let udaf = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( "example", Signature::exact(vec![DataType::Int64], Volatility::Immutable), return_type, accumulator, - state_type, + vec![Field::new("value", DataType::Int64, true)], )); let ctx = SessionContext::new(); From 29c4018774e1547f431b400d0e3c04a395daa431 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 2 Apr 2024 09:02:57 +0800 Subject: [PATCH 46/46] add comment Signed-off-by: jayzhan211 --- datafusion/expr/src/function.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 235642682899..7598c805adf6 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -49,7 +49,7 @@ pub struct AccumulatorArgs<'a> { pub ignore_nulls: bool, // ordering arguments - /// the expressions of `order by` + /// the expressions of `order by`, if no ordering is required, this will be an empty slice pub sort_exprs: &'a [Expr], }