From f4e519f9df9ab5972638d3f2743da01887a52668 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Sat, 3 Aug 2024 07:18:10 -0400 Subject: [PATCH] Move min and max to user defined aggregate function, remove `AggregateFunction` / `AggregateFunctionDefinition::BuiltIn` (#11013) * Moving min and max to new API and removing from protobuf * Using input_type rather than data_type * Adding type coercion * Fixed doctests * Implementing feedback from code review * Implementing feedback from code review * Fixed wrong name * Fixing name --- .../examples/dataframe_subquery.rs | 1 + datafusion/core/src/dataframe/mod.rs | 8 +- .../src/datasource/file_format/parquet.rs | 2 +- datafusion/core/src/datasource/statistics.rs | 2 +- datafusion/core/src/execution/context/mod.rs | 1 + datafusion/core/src/lib.rs | 1 + .../aggregate_statistics.rs | 15 +- datafusion/core/src/physical_planner.rs | 28 +- datafusion/core/tests/dataframe/mod.rs | 8 +- .../core/tests/fuzz_cases/window_fuzz.rs | 21 +- datafusion/expr/src/aggregate_function.rs | 156 ---- datafusion/expr/src/expr.rs | 59 +- datafusion/expr/src/expr_fn.rs | 30 +- datafusion/expr/src/expr_rewriter/order_by.rs | 6 +- datafusion/expr/src/expr_schema.rs | 8 - datafusion/expr/src/lib.rs | 2 - datafusion/expr/src/test/function_stub.rs | 174 ++++ datafusion/expr/src/tree_node.rs | 10 - .../expr/src/type_coercion/aggregates.rs | 66 +- datafusion/expr/src/utils.rs | 19 +- datafusion/functions-aggregate/Cargo.toml | 3 + datafusion/functions-aggregate/src/lib.rs | 12 +- .../src}/min_max.rs | 856 ++++++++++++------ datafusion/functions-nested/src/planner.rs | 7 +- .../src/analyzer/count_wildcard_rule.rs | 6 +- .../optimizer/src/analyzer/type_coercion.rs | 58 +- datafusion/optimizer/src/decorrelate.rs | 3 - .../optimizer/src/optimize_projections/mod.rs | 12 +- datafusion/optimizer/src/push_down_limit.rs | 4 +- .../optimizer/src/scalar_subquery_to_join.rs | 4 +- .../simplify_expressions/simplify_exprs.rs | 6 +- .../src/single_distinct_to_groupby.rs 
| 115 +-- .../physical-expr/src/aggregate/build_in.rs | 208 ----- .../src/aggregate/groups_accumulator/mod.rs | 4 - datafusion/physical-expr/src/aggregate/mod.rs | 3 - .../physical-expr/src/expressions/mod.rs | 5 - .../physical-plan/src/aggregates/mod.rs | 3 - datafusion/physical-plan/src/windows/mod.rs | 18 - datafusion/proto/gen/src/main.rs | 7 +- datafusion/proto/proto/datafusion.proto | 50 - datafusion/proto/src/generated/pbjson.rs | 293 ------ datafusion/proto/src/generated/prost.rs | 90 +- .../proto/src/logical_plan/from_proto.rs | 54 +- datafusion/proto/src/logical_plan/to_proto.rs | 44 +- .../proto/src/physical_plan/from_proto.rs | 9 - datafusion/proto/src/physical_plan/mod.rs | 24 +- .../proto/src/physical_plan/to_proto.rs | 75 +- .../tests/cases/roundtrip_logical_plan.rs | 17 +- .../tests/cases/roundtrip_physical_plan.rs | 19 +- datafusion/sql/src/expr/function.rs | 68 +- datafusion/sql/tests/cases/plan_to_sql.rs | 6 +- datafusion/sql/tests/sql_integration.rs | 5 +- .../substrait/src/logical_plan/consumer.rs | 10 +- .../substrait/src/logical_plan/producer.rs | 32 - docs/source/user-guide/dataframe.md | 1 + docs/source/user-guide/example-usage.md | 2 + 56 files changed, 937 insertions(+), 1813 deletions(-) delete mode 100644 datafusion/expr/src/aggregate_function.rs rename datafusion/{physical-expr/src/aggregate => functions-aggregate/src}/min_max.rs (60%) delete mode 100644 datafusion/physical-expr/src/aggregate/build_in.rs diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs index e798751b3353..3e3d0c1b5a84 100644 --- a/datafusion-examples/examples/dataframe_subquery.rs +++ b/datafusion-examples/examples/dataframe_subquery.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use datafusion::error::Result; use datafusion::functions_aggregate::average::avg; +use datafusion::functions_aggregate::min_max::max; use datafusion::prelude::*; use datafusion::test_util::arrow_test_data; use 
datafusion_common::ScalarValue; diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 8feccfb43d6b..cacfa4c6f2aa 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -53,9 +53,11 @@ use datafusion_common::{ }; use datafusion_expr::{case, is_null, lit}; use datafusion_expr::{ - max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, + utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, +}; +use datafusion_functions_aggregate::expr_fn::{ + avg, count, max, median, min, stddev, sum, }; -use datafusion_functions_aggregate::expr_fn::{avg, count, median, stddev, sum}; use async_trait::async_trait; use datafusion_catalog::Session; @@ -144,6 +146,7 @@ impl Default for DataFrameWriteOptions { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; +/// # use datafusion::functions_aggregate::expr_fn::min; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); @@ -407,6 +410,7 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion::functions_aggregate::expr_fn::min; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 829b69c297ee..f233f3842c8c 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -50,7 +50,7 @@ use datafusion_common::{ use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator}; +use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; use 
datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; diff --git a/datafusion/core/src/datasource/statistics.rs b/datafusion/core/src/datasource/statistics.rs index a243a1c3558f..8c789e461b08 100644 --- a/datafusion/core/src/datasource/statistics.rs +++ b/datafusion/core/src/datasource/statistics.rs @@ -18,7 +18,7 @@ use super::listing::PartitionedFile; use crate::arrow::datatypes::{Schema, SchemaRef}; use crate::error::Result; -use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; +use crate::functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; use crate::physical_plan::{Accumulator, ColumnStatistics, Statistics}; use arrow_schema::DataType; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 9b889c37ab52..24704bc794c2 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -144,6 +144,7 @@ where /// /// ``` /// use datafusion::prelude::*; +/// # use datafusion::functions_aggregate::expr_fn::min; /// # use datafusion::{error::Result, assert_batches_eq}; /// # #[tokio::main] /// # async fn main() -> Result<()> { diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index cf5a184e3416..3bb0636652c0 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -52,6 +52,7 @@ //! ```rust //! # use datafusion::prelude::*; //! # use datafusion::error::Result; +//! # use datafusion::functions_aggregate::expr_fn::min; //! # use datafusion::arrow::record_batch::RecordBatch; //! //! 
# #[tokio::main] diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index a8332d1d55e4..a0f6f6a65b1f 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -272,39 +272,28 @@ fn is_non_distinct_count(agg_expr: &dyn AggregateExpr) -> bool { return true; } } - false } // TODO: Move this check into AggregateUDFImpl // https://github.com/apache/datafusion/issues/11153 fn is_min(agg_expr: &dyn AggregateExpr) -> bool { - if agg_expr.as_any().is::() { - return true; - } - if let Some(agg_expr) = agg_expr.as_any().downcast_ref::() { - if agg_expr.fun().name() == "min" { + if agg_expr.fun().name().to_lowercase() == "min" { return true; } } - false } // TODO: Move this check into AggregateUDFImpl // https://github.com/apache/datafusion/issues/11153 fn is_max(agg_expr: &dyn AggregateExpr) -> bool { - if agg_expr.as_any().is::() { - return true; - } - if let Some(agg_expr) = agg_expr.as_any().downcast_ref::() { - if agg_expr.fun().name() == "max" { + if agg_expr.fun().name().to_lowercase() == "max" { return true; } } - false } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 329d343f13fc..03e20b886e2c 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -59,8 +59,8 @@ use crate::physical_plan::unnest::UnnestExec; use crate::physical_plan::values::ValuesExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ - aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan, - ExecutionPlanProperties, InputOrderMode, Partitioning, PhysicalExpr, WindowExpr, + displayable, udaf, windows, AggregateExpr, ExecutionPlan, ExecutionPlanProperties, + InputOrderMode, Partitioning, PhysicalExpr, WindowExpr, }; use arrow::compute::SortOptions; @@ -1812,7 
+1812,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( e: &Expr, name: impl Into, logical_input_schema: &DFSchema, - physical_input_schema: &Schema, + _physical_input_schema: &Schema, execution_props: &ExecutionProps, ) -> Result { match e { @@ -1840,28 +1840,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - let physical_sort_exprs = match order_by { - Some(exprs) => Some(create_physical_sort_exprs( - exprs, - logical_input_schema, - execution_props, - )?), - None => None, - }; - let ordering_reqs: Vec = - physical_sort_exprs.clone().unwrap_or(vec![]); - let agg_expr = aggregates::create_aggregate_expr( - fun, - *distinct, - &physical_args, - &ordering_reqs, - physical_input_schema, - name, - ignore_nulls, - )?; - (agg_expr, filter, physical_sort_exprs) - } AggregateFunctionDefinition::UDF(fun) => { let sort_exprs = order_by.clone().unwrap_or(vec![]); let physical_sort_exprs = match order_by { diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index d83a47ceb069..86cacbaa06d8 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -54,11 +54,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, - scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder, scalar_subquery, + when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::{array_agg, 
avg, count, sum}; +use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, max, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index c97621ec4d01..813862c4cc2f 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -32,13 +32,13 @@ use datafusion::physical_plan::{collect, InputOrderMode}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_common_runtime::SpawnedTask; -use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf; use datafusion_expr::{ - AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunctionDefinition, + BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -361,14 +361,14 @@ fn get_random_function( window_fn_map.insert( "min", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![arg.clone()], ), ); window_fn_map.insert( "max", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![arg.clone()], ), ); @@ -465,16 +465,7 @@ fn get_random_function( let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, args) = window_fn_map.values().collect::>()[rand_fn_idx]; let mut args = args.clone(); - if let 
WindowFunctionDefinition::AggregateFunction(f) = window_fn { - if !args.is_empty() { - // Do type coercion first argument - let a = args[0].clone(); - let dt = a.data_type(schema.as_ref()).unwrap(); - let sig = f.signature(); - let coerced = coerce_types(f, &[dt], &sig).unwrap(); - args[0] = cast(a, schema, coerced[0].clone()).unwrap(); - } - } else if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn { + if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn { if !args.is_empty() { // Do type coercion first argument let a = args[0].clone(); diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs deleted file mode 100644 index 4037e3c5db9b..000000000000 --- a/datafusion/expr/src/aggregate_function.rs +++ /dev/null @@ -1,156 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Aggregate function module contains all built-in aggregate functions definitions - -use std::{fmt, str::FromStr}; - -use crate::utils; -use crate::{type_coercion::aggregates::*, Signature, Volatility}; - -use arrow::datatypes::DataType; -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; - -use strum_macros::EnumIter; - -/// Enum of all built-in aggregate functions -// Contributor's guide for adding new aggregate functions -// https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] -pub enum AggregateFunction { - /// Minimum - Min, - /// Maximum - Max, -} - -impl AggregateFunction { - pub fn name(&self) -> &str { - use AggregateFunction::*; - match self { - Min => "MIN", - Max => "MAX", - } - } -} - -impl fmt::Display for AggregateFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.name()) - } -} - -impl FromStr for AggregateFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name { - // general - "max" => AggregateFunction::Max, - "min" => AggregateFunction::Min, - _ => { - return plan_err!("There is no built-in function named {name}"); - } - }) - } -} - -impl AggregateFunction { - /// Returns the datatype of the aggregate function given its argument types - /// - /// This is used to get the returned data type for aggregate expr. - pub fn return_type( - &self, - input_expr_types: &[DataType], - _input_expr_nullable: &[bool], - ) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. 
- - let coerced_data_types = coerce_types(self, input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message - .map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - ) - ) - })?; - - match self { - AggregateFunction::Max | AggregateFunction::Min => { - // For min and max agg function, the returned type is same as input type. - // The coerced_data_types is same with input_types. - Ok(coerced_data_types[0].clone()) - } - } - } - - /// Returns if the return type of the aggregate function is nullable given its argument - /// nullability - pub fn nullable(&self) -> Result { - match self { - AggregateFunction::Max | AggregateFunction::Min => Ok(true), - } - } -} - -impl AggregateFunction { - /// the signatures supported by the function `fun`. - pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. - match self { - AggregateFunction::Min | AggregateFunction::Max => { - let valid = STRINGS - .iter() - .chain(NUMERICS.iter()) - .chain(TIMESTAMPS.iter()) - .chain(DATES.iter()) - .chain(TIMES.iter()) - .chain(BINARYS.iter()) - .cloned() - .collect::>(); - Signature::uniform(1, valid, Volatility::Immutable) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use strum::IntoEnumIterator; - - #[test] - // Test for AggregateFunction's Display and from_str() implementations. - // For each variant in AggregateFunction, it converts the variant to a string - // and then back to a variant. The test asserts that the original variant and - // the reconstructed variant are the same. This assertion is also necessary for - // function suggestion. 
See https://github.com/apache/datafusion/issues/8082 - fn test_display_and_from_str() { - for func_original in AggregateFunction::iter() { - let func_name = func_original.to_string(); - let func_from_str = - AggregateFunction::from_str(func_name.to_lowercase().as_str()).unwrap(); - assert_eq!(func_from_str, func_original); - } - } -} diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 68d5504eea48..708843494814 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -28,8 +28,8 @@ use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::expr_to_columns; use crate::{ - aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction, - ExprSchemable, Operator, Signature, WindowFrame, WindowUDF, + built_in_window_function, udaf, BuiltInWindowFunction, ExprSchemable, Operator, + Signature, WindowFrame, WindowUDF, }; use crate::{window_frame, Volatility}; @@ -630,7 +630,6 @@ impl Sort { #[derive(Debug, Clone, PartialEq, Eq, Hash)] /// Defines which implementation of an aggregate function DataFusion should call. 
pub enum AggregateFunctionDefinition { - BuiltIn(aggregate_function::AggregateFunction), /// Resolved to a user defined aggregate function UDF(Arc), } @@ -639,7 +638,6 @@ impl AggregateFunctionDefinition { /// Function's name for display pub fn name(&self) -> &str { match self { - AggregateFunctionDefinition::BuiltIn(fun) => fun.name(), AggregateFunctionDefinition::UDF(udf) => udf.name(), } } @@ -666,24 +664,6 @@ pub struct AggregateFunction { } impl AggregateFunction { - pub fn new( - fun: aggregate_function::AggregateFunction, - args: Vec, - distinct: bool, - filter: Option>, - order_by: Option>, - null_treatment: Option, - ) -> Self { - Self { - func_def: AggregateFunctionDefinition::BuiltIn(fun), - args, - distinct, - filter, - order_by, - null_treatment, - } - } - /// Create a new AggregateFunction expression with a user-defined function (UDF) pub fn new_udf( udf: Arc, @@ -709,7 +689,6 @@ impl AggregateFunction { /// Defines which implementation of an aggregate function DataFusion should call. pub enum WindowFunctionDefinition { /// A built in aggregate function that leverages an aggregate function - AggregateFunction(aggregate_function::AggregateFunction), /// A a built-in window function BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction), /// A user defined aggregate function @@ -723,12 +702,9 @@ impl WindowFunctionDefinition { pub fn return_type( &self, input_expr_types: &[DataType], - input_expr_nullable: &[bool], + _input_expr_nullable: &[bool], ) -> Result { match self { - WindowFunctionDefinition::AggregateFunction(fun) => { - fun.return_type(input_expr_types, input_expr_nullable) - } WindowFunctionDefinition::BuiltInWindowFunction(fun) => { fun.return_type(input_expr_types) } @@ -742,7 +718,6 @@ impl WindowFunctionDefinition { /// the signatures supported by the function `fun`. 
pub fn signature(&self) -> Signature { match self { - WindowFunctionDefinition::AggregateFunction(fun) => fun.signature(), WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(), WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(), WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(), @@ -754,7 +729,6 @@ impl WindowFunctionDefinition { match self { WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.name(), WindowFunctionDefinition::WindowUDF(fun) => fun.name(), - WindowFunctionDefinition::AggregateFunction(fun) => fun.name(), WindowFunctionDefinition::AggregateUDF(fun) => fun.name(), } } @@ -763,9 +737,6 @@ impl WindowFunctionDefinition { impl fmt::Display for WindowFunctionDefinition { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - WindowFunctionDefinition::AggregateFunction(fun) => { - std::fmt::Display::fmt(fun, f) - } WindowFunctionDefinition::BuiltInWindowFunction(fun) => { std::fmt::Display::fmt(fun, f) } @@ -775,12 +746,6 @@ impl fmt::Display for WindowFunctionDefinition { } } -impl From for WindowFunctionDefinition { - fn from(value: aggregate_function::AggregateFunction) -> Self { - Self::AggregateFunction(value) - } -} - impl From for WindowFunctionDefinition { fn from(value: BuiltInWindowFunction) -> Self { Self::BuiltInWindowFunction(value) @@ -866,10 +831,6 @@ pub fn find_df_window_func(name: &str) -> Option { Some(WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, )) - } else if let Ok(aggregate) = - aggregate_function::AggregateFunction::from_str(name.as_str()) - { - Some(WindowFunctionDefinition::AggregateFunction(aggregate)) } else { None } @@ -2589,8 +2550,6 @@ mod test { "first_value", "last_value", "nth_value", - "min", - "max", ]; for name in names { let fun = find_df_window_func(name).unwrap(); @@ -2607,18 +2566,6 @@ mod test { #[test] fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("max"), - 
Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Max - )) - ); - assert_eq!( - find_df_window_func("min"), - Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Min - )) - ); assert_eq!( find_df_window_func("cume_dist"), Some(WindowFunctionDefinition::BuiltInWindowFunction( diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1f51cded2239..e9c5485656c8 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -26,9 +26,9 @@ use crate::function::{ StateFieldsArgs, }; use crate::{ - aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, - AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, - Signature, Volatility, + conditional_expressions::CaseBuilder, logical_plan::Subquery, AggregateUDF, Expr, + LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, + Volatility, }; use crate::{ AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, @@ -150,30 +150,6 @@ pub fn not(expr: Expr) -> Expr { expr.not() } -/// Create an expression to represent the min() aggregate function -pub fn min(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Min, - vec![expr], - false, - None, - None, - None, - )) -} - -/// Create an expression to represent the max() aggregate function -pub fn max(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Max, - vec![expr], - false, - None, - None, - None, - )) -} - /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 4b56ca3d1c2e..2efdcae1a790 100644 --- 
a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -156,11 +156,13 @@ mod test { use arrow::datatypes::{DataType, Field, Schema}; use crate::{ - cast, col, lit, logical_plan::builder::LogicalTableSource, min, - test::function_stub::avg, try_cast, LogicalPlanBuilder, + cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast, + LogicalPlanBuilder, }; use super::*; + use crate::test::function_stub::avg; + use crate::test::function_stub::min; #[test] fn rewrite_sort_cols_by_agg() { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 5e0571f712ee..6344b892adb7 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -198,14 +198,7 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - let nullability = args - .iter() - .map(|e| e.nullable(schema)) - .collect::>>()?; match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - fun.return_type(&data_types, &nullability) - } AggregateFunctionDefinition::UDF(fun) => { let new_types = data_types_with_aggregate_udf(&data_types, fun) .map_err(|err| { @@ -338,7 +331,6 @@ impl ExprSchemable for Expr { Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), Expr::AggregateFunction(AggregateFunction { func_def, .. 
}) => { match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => fun.nullable(), // TODO: UDF should be able to customize nullability AggregateFunctionDefinition::UDF(udf) if udf.name() == "count" => { Ok(false) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 0a5cf4653a22..f5460918fa70 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -39,7 +39,6 @@ mod udaf; mod udf; mod udwf; -pub mod aggregate_function; pub mod conditional_expressions; pub mod execution_props; pub mod expr; @@ -64,7 +63,6 @@ pub mod window_function; pub mod window_state; pub use accumulator::Accumulator; -pub use aggregate_function::AggregateFunction; pub use built_in_window_function::BuiltInWindowFunction; pub use columnar_value::ColumnarValue; pub use expr::{ diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index 3e0760b5c0de..72b73ccee44f 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -289,6 +289,180 @@ impl AggregateUDFImpl for Count { } } +create_func!(Min, min_udaf); + +pub fn min(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + min_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +/// Testing stub implementation of Min aggregate +pub struct Min { + signature: Signature, + aliases: Vec, +} + +impl std::fmt::Debug for Min { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Min") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Min { + fn default() -> Self { + Self::new() + } +} + +impl Min { + pub fn new() -> Self { + Self { + aliases: vec!["min".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Min { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "MIN" + } + + fn signature(&self) -> 
&Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + fn is_descending(&self) -> Option { + Some(false) + } +} + +create_func!(Max, max_udaf); + +pub fn max(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + max_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +/// Testing stub implementation of MAX aggregate +pub struct Max { + signature: Signature, + aliases: Vec, +} + +impl std::fmt::Debug for Max { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Max") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Max { + fn default() -> Self { + Self::new() + } +} + +impl Max { + pub fn new() -> Self { + Self { + aliases: vec!["max".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Max { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "MAX" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_groups_accumulator( + &self, + _args: 
AccumulatorArgs, + ) -> Result> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + fn is_descending(&self) -> Option { + Some(true) + } +} + /// Testing stub implementation of avg aggregate #[derive(Debug)] pub struct Avg { diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index a97b9f010f79..a8062c0c07ee 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -318,16 +318,6 @@ impl TreeNode for Expr { )? .map_data( |(new_args, new_filter, new_order_by)| match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun, - new_args, - distinct, - new_filter, - new_order_by, - null_treatment, - ))) - } AggregateFunctionDefinition::UDF(fun) => { Ok(Expr::AggregateFunction(AggregateFunction::new_udf( fun, diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index a024401e18d5..e7e58bf84362 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::ops::Deref; - +use crate::TypeSignature; use arrow::datatypes::{ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, @@ -24,8 +23,6 @@ use arrow::datatypes::{ use datafusion_common::{internal_err, plan_err, Result}; -use crate::{AggregateFunction, Signature, TypeSignature}; - pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8]; pub static SIGNED_INTEGERS: &[DataType] = &[ @@ -84,25 +81,6 @@ pub static TIMES: &[DataType] = &[ DataType::Time64(TimeUnit::Nanosecond), ]; -/// Returns the coerced data type for each `input_types`. -/// Different aggregate function with different input data type will get corresponding coerced data type. 
-pub fn coerce_types( - agg_fun: &AggregateFunction, - input_types: &[DataType], - signature: &Signature, -) -> Result> { - // Validate input_types matches (at least one of) the func signature. - check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; - - match agg_fun { - AggregateFunction::Min | AggregateFunction::Max => { - // min and max support the dictionary data type - // unpack the dictionary to get the value - get_min_max_result_type(input_types) - } - } -} - /// Validate the length of `input_types` matches the `signature` for `agg_fun`. /// /// This method DOES NOT validate the argument types - only that (at least one, @@ -163,22 +141,6 @@ pub fn check_arg_count( Ok(()) } -fn get_min_max_result_type(input_types: &[DataType]) -> Result> { - // make sure that the input types only has one element. - assert_eq!(input_types.len(), 1); - // min and max support the dictionary data type - // unpack the dictionary to get the value - match &input_types[0] { - DataType::Dictionary(_, dict_value_type) => { - // TODO add checker, if the value type is complex data type - Ok(vec![dict_value_type.deref().clone()]) - } - // TODO add checker for datatype which min and max supported - // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function - _ => Ok(input_types.to_vec()), - } -} - /// function return type of a sum pub fn sum_return_type(arg_type: &DataType) -> Result { match arg_type { @@ -348,32 +310,6 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result Result<()> { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 2ef1597abfd1..683a8e170ed4 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1253,8 +1253,9 @@ mod tests { use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::sum_udaf, AggregateFunction, Cast, ExprFunctionExt, - WindowFrame, WindowFunctionDefinition, + 
test::function_stub::max_udaf, test::function_stub::min_udaf, + test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFrame, + WindowFunctionDefinition, }; #[test] @@ -1268,15 +1269,15 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( @@ -1299,18 +1300,18 @@ mod tests { let created_at_desc = Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) @@ -1352,7 +1353,7 @@ mod tests { fn test_find_sort_exprs() -> Result<()> { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( - 
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![ diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 26630a0352d5..43ddd37cfb6f 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -48,3 +48,6 @@ datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.14" sqlparser = { workspace = true } + +[dev-dependencies] +rand = { workspace = true } diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 171186966644..b54cd181a0cb 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -65,6 +65,7 @@ pub mod covariance; pub mod first_last; pub mod hyperloglog; pub mod median; +pub mod min_max; pub mod regr; pub mod stddev; pub mod sum; @@ -110,7 +111,8 @@ pub mod expr_fn { pub use super::first_last::last_value; pub use super::grouping::grouping; pub use super::median::median; - pub use super::nth_value::nth_value; + pub use super::min_max::max; + pub use super::min_max::min; pub use super::regr::regr_avgx; pub use super::regr::regr_avgy; pub use super::regr::regr_count; @@ -137,6 +139,8 @@ pub fn all_default_aggregate_functions() -> Vec> { covariance::covar_pop_udaf(), correlation::corr_udaf(), sum::sum_udaf(), + min_max::max_udaf(), + min_max::min_udaf(), median::median_udaf(), count::count_udaf(), regr::regr_slope_udaf(), @@ -192,11 +196,11 @@ mod tests { #[test] fn test_no_duplicate_name() -> Result<()> { let mut names = HashSet::new(); + let migrated_functions = ["array_agg", "count", "max", "min"]; for func in all_default_aggregate_functions() { // TODO: remove this - // These functions are in intermidiate migration state, skip them - let name_lower_case = func.name().to_lowercase(); - if name_lower_case == "count" || 
name_lower_case == "array_agg" { + // These functions are in intermediate migration state, skip them + if migrated_functions.contains(&func.name().to_lowercase().as_str()) { continue; } assert!( diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs similarity index 60% rename from datafusion/physical-expr/src/aggregate/min_max.rs rename to datafusion/functions-aggregate/src/min_max.rs index f9362db30196..4d743983411d 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -2,7 +2,6 @@ // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // @@ -15,103 +14,107 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions that can evaluated at runtime during query execution +//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function +//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function -use std::any::Any; -use std::sync::Arc; +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the
See the License for the +// specific language governing permissions and limitations +// under the License. -use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use crate::{AggregateExpr, PhysicalExpr}; +use arrow::array::{ + ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, + Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array, + Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, + IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, StringArray, + StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, +}; use arrow::compute; use arrow::datatypes::{ - DataType, Date32Type, Date64Type, IntervalUnit, Time32MillisecondType, - Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, -}; -use arrow::{ - array::{ - ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, - LargeBinaryArray, LargeStringArray, StringArray, Time32MillisecondArray, - Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }, - datatypes::Field, -}; -use arrow_array::types::{ - Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, + DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, 
UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use arrow_array::{BinaryViewArray, StringViewArray}; -use datafusion_common::internal_err; -use datafusion_common::ScalarValue; -use datafusion_common::{downcast_value, DataFusionError, Result}; -use datafusion_expr::{Accumulator, GroupsAccumulator}; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use arrow::array::Array; -use arrow::array::Decimal128Array; -use arrow::array::Decimal256Array; -use arrow::datatypes::i256; -use arrow::datatypes::Decimal256Type; +use arrow_schema::IntervalUnit; +use datafusion_common::{downcast_value, internal_err, DataFusionError, Result}; +use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use std::fmt::Debug; -use super::moving_min_max; +use arrow::datatypes::i256; +use arrow::datatypes::{ + Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; -// Min/max aggregation can take Dictionary encode input but always produces unpacked -// (aka non Dictionary) output. We need to adjust the output data type to reflect this. -// The reason min/max aggregate produces unpacked output because there is only one -// min/max value per group; there is no needs to keep them Dictionary encode -fn min_max_aggregate_data_type(input_type: DataType) -> DataType { - if let DataType::Dictionary(_, value_type) = input_type { - *value_type - } else { - input_type +use datafusion_common::ScalarValue; +use datafusion_expr::GroupsAccumulator; +use datafusion_expr::{ + function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, +}; +use std::ops::Deref; + +fn get_min_max_result_type(input_types: &[DataType]) -> Result> { + // make sure that the input types only has one element. 
+ assert_eq!(input_types.len(), 1); + // min and max support the dictionary data type + // unpack the dictionary to get the value + match &input_types[0] { + DataType::Dictionary(_, dict_value_type) => { + // TODO add checker, if the value type is complex data type + Ok(vec![dict_value_type.deref().clone()]) + } + // TODO add checker for datatype which min and max supported + // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function + _ => Ok(input_types.to_vec()), } } -/// MAX aggregate expression -#[derive(Debug, Clone)] +// MAX aggregate UDF +#[derive(Debug)] pub struct Max { - name: String, - data_type: DataType, - nullable: bool, - expr: Arc, + aliases: Vec, + signature: Signature, } impl Max { - /// Create a new MAX aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { + pub fn new() -> Self { Self { - name: name.into(), - expr, - data_type: min_max_aggregate_data_type(data_type), - nullable: true, + aliases: vec!["max".to_owned()], + signature: Signature::user_defined(Volatility::Immutable), } } } + +impl Default for Max { + fn default() -> Self { + Self::new() + } +} /// Creates a [`PrimitiveGroupsAccumulator`] for computing `MAX` /// the specified [`ArrowPrimitiveType`]. /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType macro_rules! instantiate_max_accumulator { - ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{ + ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( - PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( - &$SELF.data_type, - |cur, new| { - if *cur < new { - *cur = new - } - }, - ) + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| { + if *cur < new { + *cur = new + } + }) // Initialize each accumulator to $NATIVE::MIN .with_starting_value($NATIVE::MIN), )) @@ -124,60 +127,48 @@ macro_rules! instantiate_max_accumulator { /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType macro_rules! 
instantiate_min_accumulator { - ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{ + ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( - PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( - &$SELF.data_type, - |cur, new| { - if *cur > new { - *cur = new - } - }, - ) + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| { + if *cur > new { + *cur = new + } + }) // Initialize each accumulator to $NATIVE::MAX .with_starting_value($NATIVE::MAX), )) }}; } -impl AggregateExpr for Max { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { +impl AggregateUDFImpl for Max { + fn as_any(&self) -> &dyn std::any::Any { self } - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) + fn name(&self) -> &str { + "MAX" } - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "max"), - self.data_type.clone(), - true, - )]) + fn signature(&self) -> &Signature { + &self.signature } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].to_owned()) } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(MaxAccumulator::try_new(&self.data_type)?)) + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(MaxAccumulator::try_new(acc_args.data_type)?)) } - fn name(&self) -> &str { - &self.name + fn aliases(&self) -> &[String] { + &self.aliases } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - self.data_type, + args.data_type, Int8 | Int16 | Int32 | Int64 @@ -197,97 +188,92 @@ impl AggregateExpr for Max { ) } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { use DataType::*; use TimeUnit::*; - - match self.data_type { - 
Int8 => instantiate_max_accumulator!(self, i8, Int8Type), - Int16 => instantiate_max_accumulator!(self, i16, Int16Type), - Int32 => instantiate_max_accumulator!(self, i32, Int32Type), - Int64 => instantiate_max_accumulator!(self, i64, Int64Type), - UInt8 => instantiate_max_accumulator!(self, u8, UInt8Type), - UInt16 => instantiate_max_accumulator!(self, u16, UInt16Type), - UInt32 => instantiate_max_accumulator!(self, u32, UInt32Type), - UInt64 => instantiate_max_accumulator!(self, u64, UInt64Type), + let data_type = args.data_type; + match data_type { + Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type), + Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type), + Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type), + Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type), + UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type), + UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type), + UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type), + UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type), Float32 => { - instantiate_max_accumulator!(self, f32, Float32Type) + instantiate_max_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_max_accumulator!(self, f64, Float64Type) + instantiate_max_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_max_accumulator!(self, i32, Date32Type), - Date64 => instantiate_max_accumulator!(self, i64, Date64Type), + Date32 => instantiate_max_accumulator!(data_type, i32, Date32Type), + Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_max_accumulator!(self, i32, Time32SecondType) + instantiate_max_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_max_accumulator!(self, i32, Time32MillisecondType) + instantiate_max_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - 
instantiate_max_accumulator!(self, i64, Time64MicrosecondType) + instantiate_max_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_max_accumulator!(self, i64, Time64NanosecondType) + instantiate_max_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - instantiate_max_accumulator!(self, i64, TimestampSecondType) + instantiate_max_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_max_accumulator!(self, i64, TimestampMillisecondType) + instantiate_max_accumulator!(data_type, i64, TimestampMillisecondType) } Timestamp(Microsecond, _) => { - instantiate_max_accumulator!(self, i64, TimestampMicrosecondType) + instantiate_max_accumulator!(data_type, i64, TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_max_accumulator!(self, i64, TimestampNanosecondType) + instantiate_max_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_max_accumulator!(self, i128, Decimal128Type) + instantiate_max_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_max_accumulator!(self, i256, Decimal256Type) + instantiate_max_accumulator!(data_type, i256, Decimal256Type) } // It would be nice to have a fast implementation for Strings as well // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync - _ => internal_err!( - "GroupsAccumulator not supported for max({})", - self.data_type - ), + _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), } } - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(SlidingMaxAccumulator::try_new(args.data_type)?)) } - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(SlidingMaxAccumulator::try_new(&self.data_type)?)) + fn is_descending(&self) 
-> Option { + Some(true) } - - fn get_minmax_desc(&self) -> Option<(Field, bool)> { - Some((self.field().ok()?, true)) + fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { + datafusion_expr::utils::AggregateOrderSensitivity::Insensitive } -} -impl PartialEq for Max { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + get_min_max_result_type(arg_types) + } + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Identical } } -// Statically-typed version of min/max(array) -> ScalarValue for string types. +// Statically-typed version of min/max(array) -> ScalarValue for string types macro_rules! typed_min_max_batch_string { ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ let array = downcast_value!($VALUES, $ARRAYTYPE); @@ -296,8 +282,7 @@ macro_rules! typed_min_max_batch_string { ScalarValue::$SCALAR(value) }}; } - -// Statically-typed version of min/max(array) -> ScalarValue for binary types. +// Statically-typed version of min/max(array) -> ScalarValue for binary types. macro_rules! typed_min_max_batch_binary { ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ let array = downcast_value!($VALUES, $ARRAYTYPE); @@ -545,7 +530,6 @@ macro_rules! typed_min_max { ) }}; } - macro_rules! typed_min_max_float { ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ ScalarValue::$SCALAR(match ($VALUE, $DELTA) { @@ -804,16 +788,6 @@ macro_rules!
min_max { }}; } -/// the minimum of two scalar values -pub fn min(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { - min_max!(lhs, rhs, min) -} - -/// the maximum of two scalar values -pub fn max(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { - min_max!(lhs, rhs, max) -} - /// An accumulator to compute the maximum value #[derive(Debug)] pub struct MaxAccumulator { @@ -833,7 +807,9 @@ impl Accumulator for MaxAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; let delta = &max_batch(values)?; - self.max = max(&self.max, delta)?; + let new_max: Result = + min_max!(&self.max, delta, max); + self.max = new_max?; Ok(()) } @@ -842,9 +818,8 @@ impl Accumulator for MaxAccumulator { } fn state(&mut self) -> Result> { - Ok(vec![self.max.clone()]) + Ok(vec![self.evaluate()?]) } - fn evaluate(&mut self) -> Result { Ok(self.max.clone()) } @@ -854,11 +829,10 @@ impl Accumulator for MaxAccumulator { } } -/// An accumulator to compute the maximum value #[derive(Debug)] pub struct SlidingMaxAccumulator { max: ScalarValue, - moving_max: moving_min_max::MovingMax, + moving_max: MovingMax, } impl SlidingMaxAccumulator { @@ -866,7 +840,7 @@ impl SlidingMaxAccumulator { pub fn try_new(datatype: &DataType) -> Result { Ok(Self { max: ScalarValue::try_from(datatype)?, - moving_max: moving_min_max::MovingMax::::new(), + moving_max: MovingMax::::new(), }) } } @@ -914,69 +888,56 @@ impl Accumulator for SlidingMaxAccumulator { } } -/// MIN aggregate expression -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct Min { - name: String, - data_type: DataType, - nullable: bool, - expr: Arc, + signature: Signature, + aliases: Vec, } impl Min { - /// Create a new MIN aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { + pub fn new() -> Self { Self { - name: name.into(), - expr, - data_type: min_max_aggregate_data_type(data_type), - nullable: true, + signature: 
Signature::user_defined(Volatility::Immutable), + aliases: vec!["min".to_owned()], } } } -impl AggregateExpr for Min { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { +impl Default for Min { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Min { + fn as_any(&self) -> &dyn std::any::Any { self } - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) + fn name(&self) -> &str { + "MIN" } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(MinAccumulator::try_new(&self.data_type)?)) + fn signature(&self) -> &Signature { + &self.signature } - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "min"), - self.data_type.clone(), - true, - )]) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].to_owned()) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(MinAccumulator::try_new(acc_args.data_type)?)) } - fn name(&self) -> &str { - &self.name + fn aliases(&self) -> &[String] { + &self.aliases } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - self.data_type, + args.data_type, Int8 | Int16 | Int32 | Int64 @@ -996,91 +957,92 @@ impl AggregateExpr for Min { ) } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { use DataType::*; use TimeUnit::*; - match self.data_type { - Int8 => instantiate_min_accumulator!(self, i8, Int8Type), - Int16 => instantiate_min_accumulator!(self, i16, Int16Type), - Int32 => instantiate_min_accumulator!(self, i32, Int32Type), - Int64 => instantiate_min_accumulator!(self, i64, Int64Type), - UInt8 => instantiate_min_accumulator!(self, u8, UInt8Type), - UInt16 => 
instantiate_min_accumulator!(self, u16, UInt16Type), - UInt32 => instantiate_min_accumulator!(self, u32, UInt32Type), - UInt64 => instantiate_min_accumulator!(self, u64, UInt64Type), + let data_type = args.data_type; + match data_type { + Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type), + Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type), + Int32 => instantiate_min_accumulator!(data_type, i32, Int32Type), + Int64 => instantiate_min_accumulator!(data_type, i64, Int64Type), + UInt8 => instantiate_min_accumulator!(data_type, u8, UInt8Type), + UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type), + UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type), + UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type), Float32 => { - instantiate_min_accumulator!(self, f32, Float32Type) + instantiate_min_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_min_accumulator!(self, f64, Float64Type) + instantiate_min_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_min_accumulator!(self, i32, Date32Type), - Date64 => instantiate_min_accumulator!(self, i64, Date64Type), + Date32 => instantiate_min_accumulator!(data_type, i32, Date32Type), + Date64 => instantiate_min_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_min_accumulator!(self, i32, Time32SecondType) + instantiate_min_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_min_accumulator!(self, i32, Time32MillisecondType) + instantiate_min_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - instantiate_min_accumulator!(self, i64, Time64MicrosecondType) + instantiate_min_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_min_accumulator!(self, i64, Time64NanosecondType) + instantiate_min_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - 
instantiate_min_accumulator!(self, i64, TimestampSecondType) + instantiate_min_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_min_accumulator!(self, i64, TimestampMillisecondType) + instantiate_min_accumulator!(data_type, i64, TimestampMillisecondType) } Timestamp(Microsecond, _) => { - instantiate_min_accumulator!(self, i64, TimestampMicrosecondType) + instantiate_min_accumulator!(data_type, i64, TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_min_accumulator!(self, i64, TimestampNanosecondType) + instantiate_min_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_min_accumulator!(self, i128, Decimal128Type) + instantiate_min_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_min_accumulator!(self, i256, Decimal256Type) + instantiate_min_accumulator!(data_type, i256, Decimal256Type) } + + // It would be nice to have a fast implementation for Strings as well + // https://github.com/apache/datafusion/issues/6906 + // This is only reached if groups_accumulator_supported is out of sync - _ => internal_err!( - "GroupsAccumulator not supported for min({})", - self.data_type - ), + _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), } } - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(SlidingMinAccumulator::try_new(args.data_type)?)) } - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(SlidingMinAccumulator::try_new(&self.data_type)?)) + fn is_descending(&self) -> Option { + Some(false) } - fn get_minmax_desc(&self) -> Option<(Field, bool)> { - Some((self.field().ok()?, false)) + fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { + datafusion_expr::utils::AggregateOrderSensitivity::Insensitive } -} -impl PartialEq for Min { - fn eq(&self, 
other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + get_min_max_result_type(arg_types) } -} + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Identical + } +} /// An accumulator to compute the minimum value #[derive(Debug)] pub struct MinAccumulator { @@ -1098,13 +1060,15 @@ impl MinAccumulator { impl Accumulator for MinAccumulator { fn state(&mut self) -> Result> { - Ok(vec![self.min.clone()]) + Ok(vec![self.evaluate()?]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; let delta = &min_batch(values)?; - self.min = min(&self.min, delta)?; + let new_min: Result = + min_max!(&self.min, delta, min); + self.min = new_min?; Ok(()) } @@ -1121,19 +1085,17 @@ impl Accumulator for MinAccumulator { } } -/// An accumulator to compute the minimum value #[derive(Debug)] pub struct SlidingMinAccumulator { min: ScalarValue, - moving_min: moving_min_max::MovingMin, + moving_min: MovingMin, } impl SlidingMinAccumulator { - /// new min accumulator pub fn try_new(datatype: &DataType) -> Result { Ok(Self { min: ScalarValue::try_from(datatype)?, - moving_min: moving_min_max::MovingMin::::new(), + moving_min: MovingMin::::new(), }) } } @@ -1186,12 +1148,278 @@ impl Accumulator for SlidingMinAccumulator { } } +// +// Moving min and moving max +// The implementation is taken from https://github.com/spebern/moving_min_max/blob/master/src/lib.rs. + +// Keep track of the minimum or maximum value in a sliding window. +// +// `moving min max` provides one data structure for keeping track of the +// minimum value and one for keeping track of the maximum value in a sliding +// window. +// +// Each element is stored with the current min/max. 
One stack to push and another one for pop. If pop stack is empty, +// push to this stack all elements popped from first stack while updating their current min/max. Now pop from +// the second stack (MovingMin/Max struct works as a queue). To find the minimum element of the queue, +// look at the smallest/largest two elements of the individual stacks, then take the minimum of those two values. +// +// The complexity of the operations are +// - O(1) for getting the minimum/maximum +// - O(1) for push +// - amortized O(1) for pop + +/// ``` +/// # use datafusion_functions_aggregate::min_max::MovingMin; +/// let mut moving_min = MovingMin::::new(); +/// moving_min.push(2); +/// moving_min.push(1); +/// moving_min.push(3); +/// +/// assert_eq!(moving_min.min(), Some(&1)); +/// assert_eq!(moving_min.pop(), Some(2)); +/// +/// assert_eq!(moving_min.min(), Some(&1)); +/// assert_eq!(moving_min.pop(), Some(1)); +/// +/// assert_eq!(moving_min.min(), Some(&3)); +/// assert_eq!(moving_min.pop(), Some(3)); +/// +/// assert_eq!(moving_min.min(), None); +/// assert_eq!(moving_min.pop(), None); +/// ``` +#[derive(Debug)] +pub struct MovingMin { + push_stack: Vec<(T, T)>, + pop_stack: Vec<(T, T)>, +} + +impl Default for MovingMin { + fn default() -> Self { + Self { + push_stack: Vec::new(), + pop_stack: Vec::new(), + } + } +} + +impl MovingMin { + /// Creates a new `MovingMin` to keep track of the minimum in a sliding + /// window. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Creates a new `MovingMin` to keep track of the minimum in a sliding + /// window with `capacity` allocated slots. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + push_stack: Vec::with_capacity(capacity), + pop_stack: Vec::with_capacity(capacity), + } + } + + /// Returns the minimum of the sliding window or `None` if the window is + /// empty. 
+ #[inline] + pub fn min(&self) -> Option<&T> { + match (self.push_stack.last(), self.pop_stack.last()) { + (None, None) => None, + (Some((_, min)), None) => Some(min), + (None, Some((_, min))) => Some(min), + (Some((_, a)), Some((_, b))) => Some(if a < b { a } else { b }), + } + } + + /// Pushes a new element into the sliding window. + #[inline] + pub fn push(&mut self, val: T) { + self.push_stack.push(match self.push_stack.last() { + Some((_, min)) => { + if val > *min { + (val, min.clone()) + } else { + (val.clone(), val) + } + } + None => (val.clone(), val), + }); + } + + /// Removes and returns the last value of the sliding window. + #[inline] + pub fn pop(&mut self) -> Option { + if self.pop_stack.is_empty() { + match self.push_stack.pop() { + Some((val, _)) => { + let mut last = (val.clone(), val); + self.pop_stack.push(last.clone()); + while let Some((val, _)) = self.push_stack.pop() { + let min = if last.1 < val { + last.1.clone() + } else { + val.clone() + }; + last = (val.clone(), min); + self.pop_stack.push(last.clone()); + } + } + None => return None, + } + } + self.pop_stack.pop().map(|(val, _)| val) + } + + /// Returns the number of elements stored in the sliding window. + #[inline] + pub fn len(&self) -> usize { + self.push_stack.len() + self.pop_stack.len() + } + + /// Returns `true` if the moving window contains no elements. 
+ #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} +/// ``` +/// # use datafusion_functions_aggregate::min_max::MovingMax; +/// let mut moving_max = MovingMax::::new(); +/// moving_max.push(2); +/// moving_max.push(3); +/// moving_max.push(1); +/// +/// assert_eq!(moving_max.max(), Some(&3)); +/// assert_eq!(moving_max.pop(), Some(2)); +/// +/// assert_eq!(moving_max.max(), Some(&3)); +/// assert_eq!(moving_max.pop(), Some(3)); +/// +/// assert_eq!(moving_max.max(), Some(&1)); +/// assert_eq!(moving_max.pop(), Some(1)); +/// +/// assert_eq!(moving_max.max(), None); +/// assert_eq!(moving_max.pop(), None); +/// ``` +#[derive(Debug)] +pub struct MovingMax { + push_stack: Vec<(T, T)>, + pop_stack: Vec<(T, T)>, +} + +impl Default for MovingMax { + fn default() -> Self { + Self { + push_stack: Vec::new(), + pop_stack: Vec::new(), + } + } +} + +impl MovingMax { + /// Creates a new `MovingMax` to keep track of the maximum in a sliding window. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Creates a new `MovingMax` to keep track of the maximum in a sliding window with + /// `capacity` allocated slots. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + push_stack: Vec::with_capacity(capacity), + pop_stack: Vec::with_capacity(capacity), + } + } + + /// Returns the maximum of the sliding window or `None` if the window is empty. + #[inline] + pub fn max(&self) -> Option<&T> { + match (self.push_stack.last(), self.pop_stack.last()) { + (None, None) => None, + (Some((_, max)), None) => Some(max), + (None, Some((_, max))) => Some(max), + (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }), + } + } + + /// Pushes a new element into the sliding window. 
+ #[inline] + pub fn push(&mut self, val: T) { + self.push_stack.push(match self.push_stack.last() { + Some((_, max)) => { + if val < *max { + (val, max.clone()) + } else { + (val.clone(), val) + } + } + None => (val.clone(), val), + }); + } + + /// Removes and returns the last value of the sliding window. + #[inline] + pub fn pop(&mut self) -> Option { + if self.pop_stack.is_empty() { + match self.push_stack.pop() { + Some((val, _)) => { + let mut last = (val.clone(), val); + self.pop_stack.push(last.clone()); + while let Some((val, _)) = self.push_stack.pop() { + let max = if last.1 > val { + last.1.clone() + } else { + val.clone() + }; + last = (val.clone(), max); + self.pop_stack.push(last.clone()); + } + } + None => return None, + } + } + self.pop_stack.pop().map(|(val, _)| val) + } + + /// Returns the number of elements stored in the sliding window. + #[inline] + pub fn len(&self) -> usize { + self.push_stack.len() + self.pop_stack.len() + } + + /// Returns `true` if the moving window contains no elements. 
+ #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +make_udaf_expr_and_func!( + Max, + max, + expression, + "Returns the maximum of a group of values.", + max_udaf +); + +make_udaf_expr_and_func!( + Min, + min, + expression, + "Returns the minimum of a group of values.", + min_udaf +); + #[cfg(test)] mod tests { use super::*; use arrow::datatypes::{ IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, }; + use std::sync::Arc; #[test] fn interval_min_max() { @@ -1324,4 +1552,100 @@ mod tests { check(&mut max(), &[&[zero], &[neg_inf]], zero); check(&mut max(), &[&[zero, neg_inf]], zero); } + + use datafusion_common::Result; + use rand::Rng; + + fn get_random_vec_i32(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut input = Vec::with_capacity(len); + for _i in 0..len { + input.push(rng.gen_range(0..100)); + } + input + } + + fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> { + let data = get_random_vec_i32(len); + let mut expected = Vec::with_capacity(len); + let mut moving_min = MovingMin::::new(); + let mut res = Vec::with_capacity(len); + for i in 0..len { + let start = i.saturating_sub(n_sliding_window); + expected.push(*data[start..i + 1].iter().min().unwrap()); + + moving_min.push(data[i]); + if i > n_sliding_window { + moving_min.pop(); + } + res.push(*moving_min.min().unwrap()); + } + assert_eq!(res, expected); + Ok(()) + } + + fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> { + let data = get_random_vec_i32(len); + let mut expected = Vec::with_capacity(len); + let mut moving_max = MovingMax::::new(); + let mut res = Vec::with_capacity(len); + for i in 0..len { + let start = i.saturating_sub(n_sliding_window); + expected.push(*data[start..i + 1].iter().max().unwrap()); + + moving_max.push(data[i]); + if i > n_sliding_window { + moving_max.pop(); + } + res.push(*moving_max.max().unwrap()); + } + assert_eq!(res, expected); + Ok(()) + } + + #[test] + fn 
moving_min_tests() -> Result<()> { + moving_min_i32(100, 10)?; + moving_min_i32(100, 20)?; + moving_min_i32(100, 50)?; + moving_min_i32(100, 100)?; + Ok(()) + } + + #[test] + fn moving_max_tests() -> Result<()> { + moving_max_i32(100, 10)?; + moving_max_i32(100, 20)?; + moving_max_i32(100, 50)?; + moving_max_i32(100, 100)?; + Ok(()) + } + + #[test] + fn test_min_max_coerce_types() { + // the coerced types is same with input types + let funs: Vec> = + vec![Box::new(Min::new()), Box::new(Max::new())]; + let input_types = vec![ + vec![DataType::Int32], + vec![DataType::Decimal128(10, 2)], + vec![DataType::Decimal256(1, 1)], + vec![DataType::Utf8], + ]; + for fun in funs { + for input_type in &input_types { + let result = fun.coerce_types(input_type); + assert_eq!(*input_type, result.unwrap()); + } + } + } + + #[test] + fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> { + let data_type = + DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32)); + let result = get_min_max_result_type(&[data_type])?; + assert_eq!(result, vec![DataType::Int32]); + Ok(()) + } } diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 97c54cc77beb..fee3e83a0d65 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -171,9 +171,6 @@ impl ExprPlanner for FieldAccessPlanner { } fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - if let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def { - return udf.name() == "array_agg"; - } - - false + let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def; + return udf.name() == "array_agg"; } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 338268e299da..6f832966671c 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ 
b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -103,11 +103,11 @@ mod tests { use datafusion_expr::expr::Sort; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ - col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, - out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, - WindowFrameUnits, + col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, out_ref_col, + scalar_subquery, wildcard, WindowFrame, WindowFrameBound, WindowFrameUnits, }; use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::max; use std::sync::Arc; use datafusion_functions_aggregate::expr_fn::{count, sum}; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 75dbb4d1adcd..bcd1cbcce23e 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -47,9 +47,8 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - type_coercion, AggregateFunction, AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, - LogicalPlan, Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, - WindowFrameUnits, + AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, Operator, ScalarUDF, + WindowFrame, WindowFrameBound, WindowFrameUnits, }; use crate::analyzer::AnalyzerRule; @@ -401,24 +400,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { order_by, null_treatment, }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - let new_expr = coerce_agg_exprs_for_signature( - &fun, - args, - self.schema, - &fun.signature(), - )?; - Ok(Transformed::yes(Expr::AggregateFunction( - expr::AggregateFunction::new( - fun, - new_expr, - distinct, - filter, - order_by, - 
null_treatment, - ), - ))) - } AggregateFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, @@ -449,14 +430,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { coerce_window_frame(window_frame, self.schema, &order_by)?; let args = match &fun { - expr::WindowFunctionDefinition::AggregateFunction(fun) => { - coerce_agg_exprs_for_signature( - fun, - args, - self.schema, - &fun.signature(), - )? - } expr::WindowFunctionDefinition::AggregateUDF(udf) => { coerce_arguments_for_signature_with_aggregate_udf( args, @@ -692,33 +665,6 @@ fn coerce_arguments_for_fun( } } -/// Returns the coerced exprs for each `input_exprs`. -/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the -/// data type of `input_exprs` need to be coerced. -fn coerce_agg_exprs_for_signature( - agg_fun: &AggregateFunction, - input_exprs: Vec, - schema: &DFSchema, - signature: &Signature, -) -> Result> { - if input_exprs.is_empty() { - return Ok(input_exprs); - } - let current_types = input_exprs - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - let coerced_types = - type_coercion::aggregates::coerce_types(agg_fun, ¤t_types, signature)?; - - input_exprs - .into_iter() - .enumerate() - .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema)) - .collect() -} - fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { // Given expressions like: // diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index c998e8442548..6dbf1641bd7c 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -436,9 +436,6 @@ fn agg_exprs_evaluation_result_on_empty_batch( Expr::AggregateFunction(expr::AggregateFunction { func_def, .. 
}) => match func_def { - AggregateFunctionDefinition::BuiltIn(_fun) => { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } AggregateFunctionDefinition::UDF(fun) => { if fun.name() == "count" { Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 16abf93f3807..31d59da13323 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -814,13 +814,13 @@ mod tests { expr::{self, Cast}, lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, - max, min, not, try_cast, when, AggregateFunction, BinaryExpr, Expr, Extension, - Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, - WindowFunctionDefinition, + not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, + Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; - use datafusion_functions_aggregate::expr_fn::count; + use datafusion_functions_aggregate::expr_fn::{count, max, min}; + use datafusion_functions_aggregate::min_max::max_udaf; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) @@ -1917,7 +1917,7 @@ mod tests { let table_scan = test_table_scan()?; let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], )) .partition_by(vec![col("test.b")]) @@ -1925,7 +1925,7 @@ mod tests { .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], )); let col1 = col(max1.display_name()?); diff --git 
a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 79980f8fc9ec..d7da3871ee89 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -321,8 +321,8 @@ mod test { use super::*; use crate::test::*; - - use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder, max}; + use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder}; + use datafusion_functions_aggregate::expr_fn::max; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 35691847fb8e..fbec675f6fc4 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -394,7 +394,9 @@ mod tests { use arrow::datatypes::DataType; use datafusion_expr::test::function_stub::sum; - use datafusion_expr::{col, lit, max, min, out_ref_col, scalar_subquery, Between}; + + use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between}; + use datafusion_functions_aggregate::min_max::{max, min}; /// Test multiple correlated subqueries #[test] diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index e650d4c09c23..e44f60d1df22 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -160,6 +160,7 @@ mod tests { ExprSchemable, JoinType, }; use datafusion_expr::{or, BinaryExpr, Cast, Operator}; + use datafusion_functions_aggregate::expr_fn::{max, min}; use crate::test::{assert_fields_eq, test_table_scan_with_name}; use crate::OptimizerContext; @@ -395,10 +396,7 @@ mod tests { .project(vec![col("a"), col("c"), col("b")])? 
.aggregate( vec![col("a"), col("c")], - vec![ - datafusion_expr::max(col("b").eq(lit(true))), - datafusion_expr::min(col("b")), - ], + vec![max(col("b").eq(lit(true))), min(col("b"))], )? .build()?; diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index d776e6598cbe..69c1b505727d 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -28,7 +28,6 @@ use datafusion_common::{ use datafusion_expr::builder::project; use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::{ - aggregate_function::AggregateFunction::{Max, Min}, col, expr::AggregateFunction, logical_plan::{Aggregate, LogicalPlan}, @@ -71,26 +70,6 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn(fun), - distinct, - args, - filter, - order_by, - null_treatment: _, - }) = expr - { - if filter.is_some() || order_by.is_some() { - return Ok(false); - } - aggregate_count += 1; - if *distinct { - for e in args { - fields_set.insert(e); - } - } else if !matches!(fun, Min | Max) { - return Ok(false); - } - } else if let Expr::AggregateFunction(AggregateFunction { func_def: AggregateFunctionDefinition::UDF(fun), distinct, args, @@ -107,7 +86,10 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { for e in args { fields_set.insert(e); } - } else if fun.name() != "sum" && fun.name() != "MIN" && fun.name() != "MAX" { + } else if fun.name() != "sum" + && fun.name().to_lowercase() != "min" + && fun.name().to_lowercase() != "max" + { return Ok(false); } } else { @@ -173,6 +155,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // // First aggregate(from bottom) refers to `test.a` column. 
// Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate. + // If we were to write plan above as below without alias // // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\ @@ -200,55 +183,6 @@ impl OptimizerRule for SingleDistinctToGroupBy { let outer_aggr_exprs = aggr_expr .into_iter() .map(|aggr_expr| match aggr_expr { - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn(fun), - mut args, - distinct, - .. - }) => { - if distinct { - if args.len() != 1 { - return internal_err!("DISTINCT aggregate should have exactly one argument"); - } - let arg = args.swap_remove(0); - - if group_fields_set.insert(arg.display_name()?) { - inner_group_exprs - .push(arg.alias(SINGLE_DISTINCT_ALIAS)); - } - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun, - vec![col(SINGLE_DISTINCT_ALIAS)], - false, // intentional to remove distinct here - None, - None, - None, - ))) - // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation - } else { - index += 1; - let alias_str = format!("alias{}", index); - inner_aggr_exprs.push( - Expr::AggregateFunction(AggregateFunction::new( - fun.clone(), - args, - false, - None, - None, - None, - )) - .alias(&alias_str), - ); - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun, - vec![col(&alias_str)], - false, - None, - None, - None, - ))) - } - } Expr::AggregateFunction(AggregateFunction { func_def: AggregateFunctionDefinition::UDF(udf), mut args, @@ -355,13 +289,23 @@ mod tests { use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; use datafusion_expr::ExprFunctionExt; - use datafusion_expr::{ - lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, - }; + use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use datafusion_functions_aggregate::count::count_udaf; - use 
datafusion_functions_aggregate::expr_fn::{count, count_distinct, sum}; + use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum}; + use datafusion_functions_aggregate::min_max::max_udaf; use datafusion_functions_aggregate::sum::sum_udaf; + fn max_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + max_udaf(), + vec![expr], + true, + None, + None, + None, + )) + } + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(SingleDistinctToGroupBy::new()), @@ -520,17 +464,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], - vec![ - count_distinct(col("b")), - Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Max, - vec![col("b")], - true, - None, - None, - None, - )), - ], + vec![count_distinct(col("b")), max_distinct(col("b"))], )? .build()?; // Should work @@ -587,14 +521,7 @@ mod tests { vec![ sum(col("c")), count_distinct(col("b")), - Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Max, - vec![col("b")], - true, - None, - None, - None, - )), + max_distinct(col("b")), ], )? .build()?; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs deleted file mode 100644 index bdc41ff0a9bc..000000000000 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ /dev/null @@ -1,208 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Declaration of built-in (aggregate) functions. -//! This module contains built-in aggregates' enumeration and metadata. -//! -//! Generally, an aggregate has: -//! * a signature -//! * a return type, that is a function of the incoming argument's types -//! * the computation, that must accept each valid signature -//! -//! * Signature: see `Signature` -//! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. - -use std::sync::Arc; - -use arrow::datatypes::Schema; - -use datafusion_common::Result; -use datafusion_expr::AggregateFunction; - -use crate::expressions::{self}; -use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; - -/// Create a physical aggregation expression. -/// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. 
-pub fn create_aggregate_expr( - fun: &AggregateFunction, - distinct: bool, - input_phy_exprs: &[Arc], - _ordering_req: &[PhysicalSortExpr], - input_schema: &Schema, - name: impl Into, - _ignore_nulls: bool, -) -> Result> { - let name = name.into(); - // get the result data type for this aggregate function - let input_phy_types = input_phy_exprs - .iter() - .map(|e| e.data_type(input_schema)) - .collect::>>()?; - let data_type = input_phy_types[0].clone(); - let input_phy_exprs = input_phy_exprs.to_vec(); - Ok(match (fun, distinct) { - (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( - Arc::clone(&input_phy_exprs[0]), - name, - data_type, - )), - (AggregateFunction::Max, _) => Arc::new(expressions::Max::new( - Arc::clone(&input_phy_exprs[0]), - name, - data_type, - )), - }) -} - -#[cfg(test)] -mod tests { - use arrow::datatypes::{DataType, Field}; - - use datafusion_common::plan_err; - use datafusion_expr::{type_coercion, Signature}; - - use crate::expressions::{try_cast, Max, Min}; - - use super::*; - - #[test] - fn test_min_max_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::Min => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::Max => { - 
assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - }; - } - } - Ok(()) - } - - #[test] - fn test_min_max() -> Result<()> { - let observed = AggregateFunction::Min.return_type(&[DataType::Utf8], &[true])?; - assert_eq!(DataType::Utf8, observed); - - let observed = AggregateFunction::Max.return_type(&[DataType::Int32], &[true])?; - assert_eq!(DataType::Int32, observed); - - // test decimal for min - let observed = AggregateFunction::Min - .return_type(&[DataType::Decimal128(10, 6)], &[true])?; - assert_eq!(DataType::Decimal128(10, 6), observed); - - // test decimal for max - let observed = AggregateFunction::Max - .return_type(&[DataType::Decimal128(28, 13)], &[true])?; - assert_eq!(DataType::Decimal128(28, 13), observed); - - Ok(()) - } - - // Helper function - // Create aggregate expr with type coercion - fn create_physical_agg_expr_for_test( - fun: &AggregateFunction, - distinct: bool, - input_phy_exprs: &[Arc], - input_schema: &Schema, - name: impl Into, - ) -> Result> { - let name = name.into(); - let coerced_phy_exprs = - coerce_exprs_for_test(fun, input_phy_exprs, input_schema, &fun.signature())?; - if coerced_phy_exprs.is_empty() { - return plan_err!( - "Invalid or wrong number of arguments passed to aggregate: '{name}'" - ); - } - create_aggregate_expr( - fun, - distinct, - &coerced_phy_exprs, - &[], - input_schema, - name, - false, - ) - } - - // Returns the coerced exprs for each `input_exprs`. - // Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the - // data type of `input_exprs` need to be coerced. 
- fn coerce_exprs_for_test( - agg_fun: &AggregateFunction, - input_exprs: &[Arc], - schema: &Schema, - signature: &Signature, - ) -> Result>> { - if input_exprs.is_empty() { - return Ok(vec![]); - } - let input_types = input_exprs - .iter() - .map(|e| e.data_type(schema)) - .collect::>>()?; - - // get the coerced data types - let coerced_types = - type_coercion::aggregates::coerce_types(agg_fun, &input_types, signature)?; - - // try cast if need - input_exprs - .iter() - .zip(coerced_types) - .map(|(expr, coerced_type)| try_cast(Arc::clone(expr), schema, coerced_type)) - .collect::>>() - } -} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs index 1944e2b2d415..3c0f3a28fedb 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs @@ -25,7 +25,3 @@ pub(crate) mod accumulate { } pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; - -pub(crate) mod prim_op { - pub use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -} diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 264c48513050..0760986a87c6 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -15,12 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -pub(crate) mod min_max; pub(crate) mod groups_accumulator; pub(crate) mod stats; -pub mod build_in; pub mod moving_min_max; pub mod utils { pub use datafusion_physical_expr_common::aggregate::utils::{ diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 7cbe4e796844..cbb697b5f304 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -31,11 +31,6 @@ mod try_cast; mod unknown_column; /// Module with some convenient methods used in expression building -pub mod helpers { - pub use crate::aggregate::min_max::{max, min}; -} -pub use crate::aggregate::build_in::create_aggregate_expr; -pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::stats::StatsType; pub use crate::window::cume_dist::{cume_dist, CumeDist}; pub use crate::window::lead_lag::{lag, lead, WindowShift}; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index d1152038eb2a..43f9f98283bb 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -55,9 +55,6 @@ mod row_hash; mod topk; mod topk_stream; -pub use datafusion_expr::AggregateFunction; -pub use datafusion_physical_expr::expressions::create_aggregate_expr; - /// Hash aggregate modes #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AggregateMode { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index a462430ca381..65cef28efc45 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -21,7 +21,6 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::{ - aggregates, expressions::{ cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, PhysicalSortExpr, RowNumber, @@ -104,23 +103,6 @@ pub fn create_window_expr( 
ignore_nulls: bool, ) -> Result> { Ok(match fun { - WindowFunctionDefinition::AggregateFunction(fun) => { - let aggregate = aggregates::create_aggregate_expr( - fun, - false, - args, - &[], - input_schema, - name, - ignore_nulls, - )?; - window_expr_from_aggregate_expr( - partition_by, - order_by, - window_frame, - aggregate, - ) - } WindowFunctionDefinition::BuiltInWindowFunction(fun) => { Arc::new(BuiltInWindowExpr::new( create_built_in_window_expr(fun, args, input_schema, name, ignore_nulls)?, diff --git a/datafusion/proto/gen/src/main.rs b/datafusion/proto/gen/src/main.rs index d38a41a01ac2..d3b3c92f6065 100644 --- a/datafusion/proto/gen/src/main.rs +++ b/datafusion/proto/gen/src/main.rs @@ -33,6 +33,7 @@ fn main() -> Result<(), String> { .file_descriptor_set_path(&descriptor_path) .out_dir(out_dir) .compile_well_known_types() + .protoc_arg("--experimental_allow_proto3_optional") .extern_path(".google.protobuf", "::pbjson_types") .compile_protos(&[proto_path], &["proto"]) .map_err(|e| format!("protobuf compilation failed: {e}"))?; @@ -52,7 +53,11 @@ fn main() -> Result<(), String> { let prost = proto_dir.join("src/datafusion.rs"); let pbjson = proto_dir.join("src/datafusion.serde.rs"); let common_path = proto_dir.join("src/datafusion_common.rs"); - + println!( + "Copying {} to {}", + prost.clone().display(), + proto_dir.join("src/generated/prost.rs").display() + ); std::fs::copy(prost, proto_dir.join("src/generated/prost.rs")).unwrap(); std::fs::copy(pbjson, proto_dir.join("src/generated/pbjson.rs")).unwrap(); std::fs::copy( diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 4c90297263c4..819130b08e86 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -311,8 +311,6 @@ message LogicalExprNode { // binary expressions BinaryExprNode binary_expr = 4; - // aggregate expressions - AggregateExprNode aggregate_expr = 5; // null checks IsNull is_null_expr = 6; @@ -466,51 
+464,6 @@ message InListNode { bool negated = 3; } -enum AggregateFunction { - MIN = 0; - MAX = 1; - // SUM = 2; - // AVG = 3; - // COUNT = 4; - // APPROX_DISTINCT = 5; - // ARRAY_AGG = 6; - // VARIANCE = 7; - // VARIANCE_POP = 8; - // COVARIANCE = 9; - // COVARIANCE_POP = 10; - // STDDEV = 11; - // STDDEV_POP = 12; - // CORRELATION = 13; - // APPROX_PERCENTILE_CONT = 14; - // APPROX_MEDIAN = 15; - // APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; - // GROUPING = 17; - // MEDIAN = 18; - // BIT_AND = 19; - // BIT_OR = 20; - // BIT_XOR = 21; - // BOOL_AND = 22; - // BOOL_OR = 23; - // REGR_SLOPE = 26; - // REGR_INTERCEPT = 27; - // REGR_COUNT = 28; - // REGR_R2 = 29; - // REGR_AVGX = 30; - // REGR_AVGY = 31; - // REGR_SXX = 32; - // REGR_SYY = 33; - // REGR_SXY = 34; - // STRING_AGG = 35; - // NTH_VALUE_AGG = 36; -} - -message AggregateExprNode { - AggregateFunction aggr_function = 1; - repeated LogicalExprNode expr = 2; - bool distinct = 3; - LogicalExprNode filter = 4; - repeated LogicalExprNode order_by = 5; -} message AggregateUDFExprNode { string fun_name = 1; @@ -543,7 +496,6 @@ enum BuiltInWindowFunction { message WindowExprNode { oneof window_function { - AggregateFunction aggr_function = 1; BuiltInWindowFunction built_in_function = 2; string udaf = 3; string udwf = 9; @@ -853,7 +805,6 @@ message PhysicalScalarUdfNode { message PhysicalAggregateExprNode { oneof AggregateFunction { - AggregateFunction aggr_function = 1; string user_defined_aggr_function = 4; } repeated PhysicalExprNode expr = 2; @@ -865,7 +816,6 @@ message PhysicalAggregateExprNode { message PhysicalWindowExprNode { oneof window_function { - AggregateFunction aggr_function = 1; BuiltInWindowFunction built_in_function = 2; string user_defined_aggr_function = 3; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 163a4c044aeb..521a0d90c1ed 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ 
-362,240 +362,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { deserializer.deserialize_struct("datafusion.AggregateExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for AggregateExprNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.aggr_function != 0 { - len += 1; - } - if !self.expr.is_empty() { - len += 1; - } - if self.distinct { - len += 1; - } - if self.filter.is_some() { - len += 1; - } - if !self.order_by.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExprNode", len)?; - if self.aggr_function != 0 { - let v = AggregateFunction::try_from(self.aggr_function) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.aggr_function)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; - } - if self.distinct { - struct_ser.serialize_field("distinct", &self.distinct)?; - } - if let Some(v) = self.filter.as_ref() { - struct_ser.serialize_field("filter", v)?; - } - if !self.order_by.is_empty() { - struct_ser.serialize_field("orderBy", &self.order_by)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for AggregateExprNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "aggr_function", - "aggrFunction", - "expr", - "distinct", - "filter", - "order_by", - "orderBy", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - AggrFunction, - Expr, - Distinct, - Filter, - OrderBy, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for 
GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), - "expr" => Ok(GeneratedField::Expr), - "distinct" => Ok(GeneratedField::Distinct), - "filter" => Ok(GeneratedField::Filter), - "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AggregateExprNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.AggregateExprNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut aggr_function__ = None; - let mut expr__ = None; - let mut distinct__ = None; - let mut filter__ = None; - let mut order_by__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::AggrFunction => { - if aggr_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); - } - aggr_function__ = Some(map_.next_value::()? 
as i32); - } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = Some(map_.next_value()?); - } - GeneratedField::Distinct => { - if distinct__.is_some() { - return Err(serde::de::Error::duplicate_field("distinct")); - } - distinct__ = Some(map_.next_value()?); - } - GeneratedField::Filter => { - if filter__.is_some() { - return Err(serde::de::Error::duplicate_field("filter")); - } - filter__ = map_.next_value()?; - } - GeneratedField::OrderBy => { - if order_by__.is_some() { - return Err(serde::de::Error::duplicate_field("orderBy")); - } - order_by__ = Some(map_.next_value()?); - } - } - } - Ok(AggregateExprNode { - aggr_function: aggr_function__.unwrap_or_default(), - expr: expr__.unwrap_or_default(), - distinct: distinct__.unwrap_or_default(), - filter: filter__, - order_by: order_by__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.AggregateExprNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for AggregateFunction { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::Min => "MIN", - Self::Max => "MAX", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for AggregateFunction { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "MIN", - "MAX", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AggregateFunction; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - 
serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "MIN" => Ok(AggregateFunction::Min), - "MAX" => Ok(AggregateFunction::Max), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) - } -} impl serde::Serialize for AggregateMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -9488,9 +9254,6 @@ impl serde::Serialize for LogicalExprNode { logical_expr_node::ExprType::BinaryExpr(v) => { struct_ser.serialize_field("binaryExpr", v)?; } - logical_expr_node::ExprType::AggregateExpr(v) => { - struct_ser.serialize_field("aggregateExpr", v)?; - } logical_expr_node::ExprType::IsNullExpr(v) => { struct_ser.serialize_field("isNullExpr", v)?; } @@ -9592,8 +9355,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "literal", "binary_expr", "binaryExpr", - "aggregate_expr", - "aggregateExpr", "is_null_expr", "isNullExpr", "is_not_null_expr", @@ -9647,7 +9408,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { Alias, Literal, BinaryExpr, - AggregateExpr, IsNullExpr, IsNotNullExpr, NotExpr, @@ -9701,7 +9461,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "alias" => Ok(GeneratedField::Alias), "literal" => Ok(GeneratedField::Literal), "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), - "aggregateExpr" | "aggregate_expr" => Ok(GeneratedField::AggregateExpr), "isNullExpr" | "is_null_expr" => Ok(GeneratedField::IsNullExpr), "isNotNullExpr" | "is_not_null_expr" => Ok(GeneratedField::IsNotNullExpr), "notExpr" | "not_expr" => 
Ok(GeneratedField::NotExpr), @@ -9778,13 +9537,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { return Err(serde::de::Error::duplicate_field("binaryExpr")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::BinaryExpr) -; - } - GeneratedField::AggregateExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("aggregateExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateExpr) ; } GeneratedField::IsNullExpr => { @@ -12708,11 +12460,6 @@ impl serde::Serialize for PhysicalAggregateExprNode { } if let Some(v) = self.aggregate_function.as_ref() { match v { - physical_aggregate_expr_node::AggregateFunction::AggrFunction(v) => { - let v = AggregateFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => { struct_ser.serialize_field("userDefinedAggrFunction", v)?; } @@ -12736,8 +12483,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { "ignoreNulls", "fun_definition", "funDefinition", - "aggr_function", - "aggrFunction", "user_defined_aggr_function", "userDefinedAggrFunction", ]; @@ -12749,7 +12494,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { Distinct, IgnoreNulls, FunDefinition, - AggrFunction, UserDefinedAggrFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -12777,7 +12521,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { "distinct" => Ok(GeneratedField::Distinct), "ignoreNulls" | "ignore_nulls" => Ok(GeneratedField::IgnoreNulls), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "userDefinedAggrFunction" | "user_defined_aggr_function" => 
Ok(GeneratedField::UserDefinedAggrFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -12838,12 +12581,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) ; } - GeneratedField::AggrFunction => { - if aggregate_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); - } - aggregate_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_aggregate_expr_node::AggregateFunction::AggrFunction(x as i32)); - } GeneratedField::UserDefinedAggrFunction => { if aggregate_function__.is_some() { return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction")); @@ -15948,11 +15685,6 @@ impl serde::Serialize for PhysicalWindowExprNode { } if let Some(v) = self.window_function.as_ref() { match v { - physical_window_expr_node::WindowFunction::AggrFunction(v) => { - let v = AggregateFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } physical_window_expr_node::WindowFunction::BuiltInFunction(v) => { let v = BuiltInWindowFunction::try_from(*v) .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; @@ -15983,8 +15715,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "name", "fun_definition", "funDefinition", - "aggr_function", - "aggrFunction", "built_in_function", "builtInFunction", "user_defined_aggr_function", @@ -15999,7 +15729,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { WindowFrame, Name, FunDefinition, - AggrFunction, BuiltInFunction, UserDefinedAggrFunction, } @@ -16029,7 +15758,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), "name" => Ok(GeneratedField::Name), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), - 
"aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -16098,12 +15826,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) ; } - GeneratedField::AggrFunction => { - if window_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); - } - window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::AggrFunction(x as i32)); - } GeneratedField::BuiltInFunction => { if window_function__.is_some() { return Err(serde::de::Error::duplicate_field("builtInFunction")); @@ -20483,11 +20205,6 @@ impl serde::Serialize for WindowExprNode { } if let Some(v) = self.window_function.as_ref() { match v { - window_expr_node::WindowFunction::AggrFunction(v) => { - let v = AggregateFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } window_expr_node::WindowFunction::BuiltInFunction(v) => { let v = BuiltInWindowFunction::try_from(*v) .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; @@ -20520,8 +20237,6 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { "windowFrame", "fun_definition", "funDefinition", - "aggr_function", - "aggrFunction", "built_in_function", "builtInFunction", "udaf", @@ -20535,7 +20250,6 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { OrderBy, WindowFrame, FunDefinition, - AggrFunction, BuiltInFunction, Udaf, Udwf, @@ -20565,7 +20279,6 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), "windowFrame" | 
"window_frame" => Ok(GeneratedField::WindowFrame), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), "udaf" => Ok(GeneratedField::Udaf), "udwf" => Ok(GeneratedField::Udwf), @@ -20628,12 +20341,6 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) ; } - GeneratedField::AggrFunction => { - if window_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); - } - window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| window_expr_node::WindowFunction::AggrFunction(x as i32)); - } GeneratedField::BuiltInFunction => { if window_function__.is_some() { return Err(serde::de::Error::duplicate_field("builtInFunction")); diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 606fe3c1699f..070c9b31d3d4 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -488,7 +488,7 @@ pub struct SubqueryAliasNode { pub struct LogicalExprNode { #[prost( oneof = "logical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" )] pub expr_type: ::core::option::Option, } @@ -508,9 +508,6 @@ pub mod logical_expr_node { /// binary expressions #[prost(message, tag = "4")] BinaryExpr(super::BinaryExprNode), - /// aggregate expressions - #[prost(message, tag = "5")] - AggregateExpr(::prost::alloc::boxed::Box), /// null checks #[prost(message, tag = "6")] IsNullExpr(::prost::alloc::boxed::Box), @@ -733,20 +730,6 @@ pub struct InListNode { } 
#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct AggregateExprNode { - #[prost(enumeration = "AggregateFunction", tag = "1")] - pub aggr_function: i32, - #[prost(message, repeated, tag = "2")] - pub expr: ::prost::alloc::vec::Vec, - #[prost(bool, tag = "3")] - pub distinct: bool, - #[prost(message, optional, boxed, tag = "4")] - pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, repeated, tag = "5")] - pub order_by: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] pub struct AggregateUdfExprNode { #[prost(string, tag = "1")] pub fun_name: ::prost::alloc::string::String, @@ -785,7 +768,7 @@ pub struct WindowExprNode { pub window_frame: ::core::option::Option, #[prost(bytes = "vec", optional, tag = "10")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, - #[prost(oneof = "window_expr_node::WindowFunction", tags = "1, 2, 3, 9")] + #[prost(oneof = "window_expr_node::WindowFunction", tags = "2, 3, 9")] pub window_function: ::core::option::Option, } /// Nested message and enum types in `WindowExprNode`. 
@@ -793,8 +776,6 @@ pub mod window_expr_node { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum WindowFunction { - #[prost(enumeration = "super::AggregateFunction", tag = "1")] - AggrFunction(i32), #[prost(enumeration = "super::BuiltInWindowFunction", tag = "2")] BuiltInFunction(i32), #[prost(string, tag = "3")] @@ -1301,7 +1282,7 @@ pub struct PhysicalAggregateExprNode { pub ignore_nulls: bool, #[prost(bytes = "vec", optional, tag = "7")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, - #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "1, 4")] + #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "4")] pub aggregate_function: ::core::option::Option< physical_aggregate_expr_node::AggregateFunction, >, @@ -1311,8 +1292,6 @@ pub mod physical_aggregate_expr_node { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum AggregateFunction { - #[prost(enumeration = "super::AggregateFunction", tag = "1")] - AggrFunction(i32), #[prost(string, tag = "4")] UserDefinedAggrFunction(::prost::alloc::string::String), } @@ -1332,7 +1311,7 @@ pub struct PhysicalWindowExprNode { pub name: ::prost::alloc::string::String, #[prost(bytes = "vec", optional, tag = "9")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, - #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "1, 2, 3")] + #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "2, 3")] pub window_function: ::core::option::Option< physical_window_expr_node::WindowFunction, >, @@ -1342,8 +1321,6 @@ pub mod physical_window_expr_node { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum WindowFunction { - #[prost(enumeration = "super::AggregateFunction", tag = "1")] - AggrFunction(i32), #[prost(enumeration = "super::BuiltInWindowFunction", tag = "2")] BuiltInFunction(i32), 
#[prost(string, tag = "3")] @@ -1941,65 +1918,6 @@ pub struct PartitionStats { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] -pub enum AggregateFunction { - Min = 0, - /// SUM = 2; - /// AVG = 3; - /// COUNT = 4; - /// APPROX_DISTINCT = 5; - /// ARRAY_AGG = 6; - /// VARIANCE = 7; - /// VARIANCE_POP = 8; - /// COVARIANCE = 9; - /// COVARIANCE_POP = 10; - /// STDDEV = 11; - /// STDDEV_POP = 12; - /// CORRELATION = 13; - /// APPROX_PERCENTILE_CONT = 14; - /// APPROX_MEDIAN = 15; - /// APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; - /// GROUPING = 17; - /// MEDIAN = 18; - /// BIT_AND = 19; - /// BIT_OR = 20; - /// BIT_XOR = 21; - /// BOOL_AND = 22; - /// BOOL_OR = 23; - /// REGR_SLOPE = 26; - /// REGR_INTERCEPT = 27; - /// REGR_COUNT = 28; - /// REGR_R2 = 29; - /// REGR_AVGX = 30; - /// REGR_AVGY = 31; - /// REGR_SXX = 32; - /// REGR_SYY = 33; - /// REGR_SXY = 34; - /// STRING_AGG = 35; - /// NTH_VALUE_AGG = 36; - Max = 1, -} -impl AggregateFunction { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - AggregateFunction::Min => "MIN", - AggregateFunction::Max => "MAX", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. 
- pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "MIN" => Some(Self::Min), - "MAX" => Some(Self::Max), - _ => None, - } - } -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] pub enum BuiltInWindowFunction { RowNumber = 0, Rank = 1, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 5e9b9af49ae9..6c4c07428bd3 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -22,11 +22,13 @@ use datafusion_common::{ exec_datafusion_err, internal_err, plan_datafusion_err, Result, ScalarValue, TableReference, UnnestOptions, }; +use datafusion_expr::expr::Unnest; +use datafusion_expr::expr::{Alias, Placeholder}; +use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ - expr::{self, Alias, InList, Placeholder, Sort, Unnest, WindowFunction}, + expr::{self, InList, Sort, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, - AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr, - ExprFunctionExt, GroupingSet, + Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -136,15 +138,6 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { } } -impl From for AggregateFunction { - fn from(agg_fun: protobuf::AggregateFunction) -> Self { - match agg_fun { - protobuf::AggregateFunction::Min => Self::Min, - protobuf::AggregateFunction::Max => Self::Max, - } - } -} - impl From for BuiltInWindowFunction { fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self { match built_in_function { @@ -231,12 +224,6 @@ impl From for JoinConstraint { } } -pub fn parse_i32_to_aggregate_function(value: &i32) -> Result { - protobuf::AggregateFunction::try_from(*value) - .map(|a| a.into()) 
- .map_err(|_| Error::unknown("AggregateFunction", *value)) -} - pub fn parse_expr( proto: &protobuf::LogicalExprNode, registry: &dyn FunctionRegistry, @@ -297,24 +284,6 @@ pub fn parse_expr( // TODO: support proto for null treatment match window_function { - window_expr_node::WindowFunction::AggrFunction(i) => { - let aggr_function = parse_i32_to_aggregate_function(i)?; - - Expr::WindowFunction(WindowFunction::new( - expr::WindowFunctionDefinition::AggregateFunction(aggr_function), - vec![parse_required_expr( - expr.expr.as_deref(), - registry, - "expr", - codec, - )?], - )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .build() - .map_err(Error::DataFusionError) - } window_expr_node::WindowFunction::BuiltInFunction(i) => { let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i) .map_err(|_| Error::unknown("BuiltInWindowFunction", *i))? @@ -379,19 +348,6 @@ pub fn parse_expr( } } } - ExprType::AggregateExpr(expr) => { - let fun = parse_i32_to_aggregate_function(&expr.aggr_function)?; - - Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - parse_exprs(&expr.expr, registry, codec)?, - expr.distinct, - parse_optional_expr(expr.filter.as_deref(), registry, codec)? 
- .map(Box::new), - parse_vec_expr(&expr.order_by, registry, codec)?, - None, - ))) - } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, alias diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index c2441892e8a8..74d9d61b3a7f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -25,9 +25,9 @@ use datafusion_expr::expr::{ InList, Like, Placeholder, ScalarFunction, Sort, Unnest, }; use datafusion_expr::{ - logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, - BuiltInWindowFunction, Expr, JoinConstraint, JoinType, TryCast, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + logical_plan::PlanType, logical_plan::StringifiedPlan, BuiltInWindowFunction, Expr, + JoinConstraint, JoinType, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, }; use crate::protobuf::{ @@ -111,15 +111,6 @@ impl From<&StringifiedPlan> for protobuf::StringifiedPlan { } } -impl From<&AggregateFunction> for protobuf::AggregateFunction { - fn from(value: &AggregateFunction) -> Self { - match value { - AggregateFunction::Min => Self::Min, - AggregateFunction::Max => Self::Max, - } - } -} - impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction { fn from(value: &BuiltInWindowFunction) -> Self { match value { @@ -319,12 +310,6 @@ pub fn serialize_expr( null_treatment: _, }) => { let (window_function, fun_definition) = match fun { - WindowFunctionDefinition::AggregateFunction(fun) => ( - protobuf::window_expr_node::WindowFunction::AggrFunction( - protobuf::AggregateFunction::from(fun).into(), - ), - None, - ), WindowFunctionDefinition::BuiltInWindowFunction(fun) => ( protobuf::window_expr_node::WindowFunction::BuiltInFunction( protobuf::BuiltInWindowFunction::from(fun).into(), @@ -383,29 +368,6 @@ pub fn 
serialize_expr( ref order_by, null_treatment: _, }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - let aggr_function = match fun { - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, - }; - - let aggregate_expr = protobuf::AggregateExprNode { - aggr_function: aggr_function.into(), - expr: serialize_exprs(args, codec)?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(serialize_expr(e, codec)?)), - None => None, - }, - order_by: match order_by { - Some(e) => serialize_exprs(e, codec)?, - None => vec![], - }, - }; - protobuf::LogicalExprNode { - expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), - } - } AggregateFunctionDefinition::UDF(fun) => { let mut buf = Vec::new(); let _ = codec.try_encode_udaf(fun, &mut buf); diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 5ecca5147805..bc0a19336bae 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -145,15 +145,6 @@ pub fn parse_physical_window_expr( let fun = if let Some(window_func) = proto.window_function.as_ref() { match window_func { - protobuf::physical_window_expr_node::WindowFunction::AggrFunction(n) => { - let f = protobuf::AggregateFunction::try_from(*n).map_err(|_| { - proto_error(format!( - "Received an unknown window aggregate function: {n}" - )) - })?; - - WindowFunctionDefinition::AggregateFunction(f.into()) - } protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| { proto_error(format!( diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 1f433ff01d12..fbb9e442980b 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -35,7 +35,7 @@ use 
datafusion::datasource::physical_plan::{AvroExec, CsvExec}; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::FunctionRegistry; use datafusion::physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; -use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateMode}; +use datafusion::physical_plan::aggregates::AggregateMode; use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; @@ -477,30 +477,10 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { ExprType::AggregateExpr(agg_node) => { let input_phy_expr: Vec> = agg_node.expr.iter() .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec)).collect::>>()?; - let ordering_req: Vec = agg_node.ordering_req.iter() + let _ordering_req: Vec = agg_node.ordering_req.iter() .map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec)).collect::>>()?; agg_node.aggregate_function.as_ref().map(|func| { match func { - AggregateFunction::AggrFunction(i) => { - let aggr_function = protobuf::AggregateFunction::try_from(*i) - .map_err( - |_| { - proto_error(format!( - "Received an unknown aggregate function: {i}" - )) - }, - )?; - - create_aggregate_expr( - &aggr_function.into(), - agg_node.distinct, - input_phy_expr.as_slice(), - &ordering_req, - &physical_schema, - name.to_string(), - agg_node.ignore_nulls, - ) - } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = match &agg_node.fun_definition { Some(buf) => extension_codec.try_decode_udaf(udaf_name, buf)?, diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 140482b9903c..57cd22a99ae1 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -24,8 +24,8 @@ use 
datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, InListExpr, IsNotNullExpr, - IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, Ntile, Rank, - RankType, RowNumber, TryCastExpr, WindowShift, + IsNullExpr, Literal, NegativeExpr, NotExpr, NthValue, Ntile, Rank, RankType, + RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -60,7 +60,7 @@ pub fn serialize_physical_aggr_expr( let name = a.fun().name().to_string(); let mut buf = Vec::new(); codec.try_encode_udaf(a.fun(), &mut buf)?; - return Ok(protobuf::PhysicalExprNode { + Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), @@ -71,35 +71,15 @@ pub fn serialize_physical_aggr_expr( fun_definition: (!buf.is_empty()).then_some(buf) }, )), - }); + }) + } else { + unreachable!("No other types exists besides AggergationFunctionExpr"); } - - let AggrFn { - inner: aggr_function, - distinct, - } = aggr_expr_to_aggr_fn(aggr_expr.as_ref())?; - - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( - protobuf::PhysicalAggregateExprNode { - aggregate_function: Some( - physical_aggregate_expr_node::AggregateFunction::AggrFunction( - aggr_function as i32, - ), - ), - expr: expressions, - ordering_req, - distinct, - ignore_nulls: false, - fun_definition: None, - }, - )), - }) } fn serialize_physical_window_aggr_expr( aggr_expr: &dyn AggregateExpr, - window_frame: &WindowFrame, + _window_frame: &WindowFrame, codec: &dyn PhysicalExtensionCodec, ) -> 
Result<(physical_window_expr_node::WindowFunction, Option>)> { if let Some(a) = aggr_expr.as_any().downcast_ref::() { @@ -119,23 +99,7 @@ fn serialize_physical_window_aggr_expr( (!buf.is_empty()).then_some(buf), )) } else { - let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(aggr_expr)?; - if distinct { - return not_impl_err!( - "Distinct aggregate functions not supported in window expressions" - ); - } - - if !window_frame.start_bound.is_unbounded() { - return Err(DataFusionError::Internal(format!( - "Unbounded start bound in WindowFrame = {window_frame}" - ))); - } - - Ok(( - physical_window_expr_node::WindowFunction::AggrFunction(inner as i32), - None, - )) + unreachable!("No other types exists besides AggergationFunctionExpr"); } } @@ -252,29 +216,6 @@ pub fn serialize_physical_window_expr( }) } -struct AggrFn { - inner: protobuf::AggregateFunction, - distinct: bool, -} - -fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { - let aggr_expr = expr.as_any(); - - // TODO: remove Min and Max - let inner = if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Min - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Max - } else { - return not_impl_err!("Aggregate function not supported: {expr:?}"); - }; - - Ok(AggrFn { - inner, - distinct: false, - }) -} - pub fn serialize_physical_sort_exprs( sort_exprs: I, codec: &dyn PhysicalExtensionCodec, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index f7ad2b9b6158..d150c474e88f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -42,9 +42,10 @@ use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, - count_distinct, covar_pop, 
covar_samp, first_value, grouping, median, stddev, - stddev_pop, sum, var_pop, var_sample, + count_distinct, covar_pop, covar_samp, first_value, grouping, max, median, min, + stddev, stddev_pop, sum, var_pop, var_sample, }; +use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_nested::map::map; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -61,10 +62,10 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateFunction, AggregateUDF, ColumnarValue, ExprFunctionExt, - ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, - Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, WindowUDF, WindowUDFImpl, + Accumulator, AggregateUDF, ColumnarValue, ExprFunctionExt, ExprSchemable, Literal, + LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, + WindowUDFImpl, }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::expr_fn::{ @@ -875,7 +876,9 @@ async fn roundtrip_expr_api() -> Result<()> { covar_pop(lit(1.5), lit(2.2)), corr(lit(1.5), lit(2.2)), sum(lit(1)), + max(lit(1)), median(lit(2)), + min(lit(2)), var_sample(lit(2.2)), var_pop(lit(2.2)), stddev(lit(2.2)), @@ -2284,7 +2287,7 @@ fn roundtrip_window() { ); let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("col1")], )) .partition_by(vec![col("col1")]) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 3ddc122e3de2..0e2bc9cbb3e2 100644 --- 
a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -25,8 +25,10 @@ use std::vec; use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use datafusion::physical_expr_common::aggregate::AggregateExprBuilder; +use datafusion_functions_aggregate::min_max::max_udaf; use prost::Message; +use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode}; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -43,7 +45,7 @@ use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; use datafusion::physical_expr::aggregate::utils::down_cast_any_ref; -use datafusion::physical_expr::expressions::{Literal, Max}; +use datafusion::physical_expr::expressions::Literal; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -92,8 +94,6 @@ use datafusion_proto::physical_plan::{ }; use datafusion_proto::protobuf; -use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode}; - /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is /// lost during serde because the string representation of a plan often only shows a subset of state. 
@@ -909,11 +909,18 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { )), input, )?); + let aggr_expr = AggregateExprBuilder::new( + max_udaf(), + vec![udf_expr.clone() as Arc], + ) + .schema(schema.clone()) + .name("max") + .build()?; let window = Arc::new(WindowAggExec::try_new( vec![Arc::new(PlainAggregateWindowExpr::new( - Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)), - &[col("author", &schema)?], + aggr_expr.clone(), + &[col("author", &schema.clone())?], &[], Arc::new(WindowFrame::new(None)), ))], @@ -924,7 +931,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::new(vec![], vec![], vec![]), - vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))], + vec![aggr_expr.clone()], vec![None], window, schema.clone(), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 2506ef740fde..d16d08b041ae 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use std::str::FromStr; - use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; @@ -26,8 +24,7 @@ use datafusion_common::{ }; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{ - expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, - WindowFunctionDefinition, + expr, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition, }; use datafusion_expr::{ expr::{ScalarFunction, Unnest}, @@ -38,7 +35,6 @@ use sqlparser::ast::{ FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, NullTreatment, ObjectName, OrderByExpr, WindowType, }; - use strum::IntoEnumIterator; /// Suggest a valid function based on an invalid input function name @@ -51,7 +47,6 @@ pub fn suggest_valid_function( // All aggregate functions and builtin window functions let mut funcs = Vec::new(); - funcs.extend(AggregateFunction::iter().map(|func| func.to_string())); funcs.extend(ctx.udaf_names()); funcs.extend(BuiltInWindowFunction::iter().map(|func| func.to_string())); funcs.extend(ctx.udwf_names()); @@ -62,7 +57,6 @@ pub fn suggest_valid_function( let mut funcs = Vec::new(); funcs.extend(ctx.udf_names()); - funcs.extend(AggregateFunction::iter().map(|func| func.to_string())); funcs.extend(ctx.udaf_names()); funcs @@ -324,31 +318,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; if let Ok(fun) = self.find_window_func(&name) { - return match fun { - WindowFunctionDefinition::AggregateFunction(aggregate_fun) => { - let args = - self.function_args_to_expr(args, schema, planner_context)?; - - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(aggregate_fun), - args, - )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - } - _ => Expr::WindowFunction(expr::WindowFunction::new( - fun, - self.function_args_to_expr(args, schema, planner_context)?, - )) - 
.partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build(), - }; + return Expr::WindowFunction(expr::WindowFunction::new( + fun, + self.function_args_to_expr(args, schema, planner_context)?, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build(); } } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function @@ -375,32 +353,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { null_treatment, ))); } - - // next, aggregate built-ins - if let Ok(fun) = AggregateFunction::from_str(&name) { - let order_by = self.order_by_to_sort_expr( - order_by, - schema, - planner_context, - true, - None, - )?; - let order_by = (!order_by.is_empty()).then_some(order_by); - let args = self.function_args_to_expr(args, schema, planner_context)?; - let filter: Option> = filter - .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) - .transpose()? 
- .map(Box::new); - - return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - args, - distinct, - filter, - order_by, - null_treatment, - ))); - }; } // Could not find the relevant function, so return an error diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index d1ac7a0c96d1..bae3ec2e2779 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -20,7 +20,7 @@ use std::vec; use arrow_schema::*; use datafusion_common::{DFSchema, Result, TableReference}; -use datafusion_expr::test::function_stub::{count_udaf, sum_udaf}; +use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_udaf}; use datafusion_expr::{col, table_scan}; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ @@ -381,7 +381,9 @@ fn roundtrip_statement_with_dialect() -> Result<()> { .parse_statement()?; let context = MockContextProvider::default() - .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) + .with_udaf(max_udaf()) + .with_udaf(min_udaf()); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel .sql_statement_to_plan(statement) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 40a58827b388..c1b2246e4980 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -42,7 +42,8 @@ use datafusion_sql::{ use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::{ - approx_median::approx_median_udaf, count::count_udaf, + approx_median::approx_median_udaf, count::count_udaf, min_max::max_udaf, + min_max::min_udaf, }; use datafusion_functions_aggregate::{average::avg_udaf, grouping::grouping_udaf}; use rstest::rstest; @@ -2764,6 +2765,8 @@ fn logical_plan_with_dialect_and_options( 
.with_udaf(approx_median_udaf()) .with_udaf(count_udaf()) .with_udaf(avg_udaf()) + .with_udaf(min_udaf()) + .with_udaf(max_udaf()) .with_udaf(grouping_udaf()) .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index eebadb239d56..89f2efec66aa 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -30,8 +30,8 @@ use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ - aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case, - EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, Values, + expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr, + ExprSchemable, LogicalPlan, Operator, Projection, Values, }; use substrait::proto::expression::subquery::set_predicate::PredicateOp; use url::Url; @@ -67,7 +67,6 @@ use datafusion::{ scalar::ScalarValue, }; use std::collections::{HashMap, HashSet}; -use std::str::FromStr; use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::literal::user_defined::Val; @@ -1005,11 +1004,6 @@ pub async fn from_substrait_agg_func( Ok(Arc::new(Expr::AggregateFunction( expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), ))) - } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) - { - Ok(Arc::new(Expr::AggregateFunction( - expr::AggregateFunction::new(fun, args, distinct, filter, order_by, None), - ))) } else { not_impl_err!( "Aggregate function {} is not supported: function anchor = {:?}", diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 8263209ffccc..bd6e0e00491a 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ 
b/datafusion/substrait/src/logical_plan/producer.rs @@ -48,7 +48,6 @@ use datafusion::common::{ }; use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] -use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, Sort, WindowFunction, @@ -767,37 +766,6 @@ pub fn to_substrait_agg_measure( match expr { Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by, null_treatment: _, }) => { match func_def { - AggregateFunctionDefinition::BuiltIn (fun) => { - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? - } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); - } - let function_anchor = extensions.register_function(fun.to_string()); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extensions)?), - None => None - } - }) - } AggregateFunctionDefinition::UDF(fun) => { let sorts = if let Some(order_by) = order_by { order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? 
diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index c3d0b6c2d688..96be1bb9e256 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -53,6 +53,7 @@ Here is a minimal example showing the execution of a query using the DataFrame A ```rust use datafusion::prelude::*; use datafusion::error::Result; +use datafusion::functions_aggregate::expr_fn::min; #[tokio::main] async fn main() -> Result<()> { diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 813dbb1bc02a..6108315f398a 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -60,6 +60,7 @@ library guide for more information on the SQL API. ```rust use datafusion::prelude::*; +use datafusion::functions_aggregate::expr_fn::min; #[tokio::main] async fn main() -> datafusion::error::Result<()> { @@ -148,6 +149,7 @@ async fn main() -> datafusion::error::Result<()> { ```rust use datafusion::prelude::*; +use datafusion::functions_aggregate::expr_fn::min; #[tokio::main] async fn main() -> datafusion::error::Result<()> {