From b4069a65a9bb207370d382bdde93f1c98d69b9eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 5 Aug 2024 08:15:57 +0800 Subject: [PATCH] Remove `AggregateFunctionDefinition` (#11803) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Remove `if udf.name() == "count" => {` * Apply review suggestions --- datafusion/core/src/physical_planner.rs | 69 +++++++++---------- datafusion/expr/src/expr.rs | 34 +++------ datafusion/expr/src/expr_schema.rs | 47 ++++++------- datafusion/expr/src/tree_node.rs | 31 ++++----- datafusion/functions-nested/src/planner.rs | 4 +- .../src/analyzer/count_wildcard_rule.rs | 8 +-- .../optimizer/src/analyzer/type_coercion.rs | 42 ++++++----- datafusion/optimizer/src/decorrelate.rs | 20 ++---- .../simplify_expressions/expr_simplifier.rs | 8 +-- .../src/single_distinct_to_groupby.rs | 17 +++-- datafusion/proto/src/logical_plan/to_proto.rs | 50 +++++++------- datafusion/sql/src/unparser/expr.rs | 2 +- .../substrait/src/logical_plan/producer.rs | 11 +-- 13 files changed, 144 insertions(+), 199 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 03e20b886e2c..378a892111c5 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -74,8 +74,8 @@ use datafusion_common::{ }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ - self, AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, - Cast, GroupingSet, InList, Like, TryCast, WindowFunction, + self, AggregateFunction, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, + TryCast, WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::expr_vec_fmt; @@ -223,18 +223,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { create_function_physical_name(&fun.to_string(), false, args, Some(order_by)) }
Expr::AggregateFunction(AggregateFunction { - func_def, + func, distinct, args, filter: _, order_by, null_treatment: _, - }) => create_function_physical_name( - func_def.name(), - *distinct, - args, - order_by.as_ref(), - ), + }) => { + create_function_physical_name(func.name(), *distinct, args, order_by.as_ref()) + } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( "ROLLUP ({})", @@ -1817,7 +1814,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ) -> Result { match e { Expr::AggregateFunction(AggregateFunction { - func_def, + func, distinct, args, filter, @@ -1839,36 +1836,34 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; - let (agg_expr, filter, order_by) = match func_def { - AggregateFunctionDefinition::UDF(fun) => { - let sort_exprs = order_by.clone().unwrap_or(vec![]); - let physical_sort_exprs = match order_by { - Some(exprs) => Some(create_physical_sort_exprs( - exprs, - logical_input_schema, - execution_props, - )?), - None => None, - }; + let (agg_expr, filter, order_by) = { + let sort_exprs = order_by.clone().unwrap_or(vec![]); + let physical_sort_exprs = match order_by { + Some(exprs) => Some(create_physical_sort_exprs( + exprs, + logical_input_schema, + execution_props, + )?), + None => None, + }; - let ordering_reqs: Vec = - physical_sort_exprs.clone().unwrap_or(vec![]); + let ordering_reqs: Vec = + physical_sort_exprs.clone().unwrap_or(vec![]); - let agg_expr = udaf::create_aggregate_expr_with_dfschema( - fun, - &physical_args, - args, - &sort_exprs, - &ordering_reqs, - logical_input_schema, - name, - ignore_nulls, - *distinct, - false, - )?; + let agg_expr = udaf::create_aggregate_expr_with_dfschema( + func, + &physical_args, + args, + &sort_exprs, + &ordering_reqs, + logical_input_schema, + name, + ignore_nulls, + *distinct, + false, + )?; - (agg_expr, filter, physical_sort_exprs) - } + 
(agg_expr, filter, physical_sort_exprs) }; Ok((agg_expr, filter, order_by)) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 708843494814..1a51c181f49f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -627,22 +627,6 @@ impl Sort { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -/// Defines which implementation of an aggregate function DataFusion should call. -pub enum AggregateFunctionDefinition { - /// Resolved to a user defined aggregate function - UDF(Arc), -} - -impl AggregateFunctionDefinition { - /// Function's name for display - pub fn name(&self) -> &str { - match self { - AggregateFunctionDefinition::UDF(udf) => udf.name(), - } - } -} - /// Aggregate function /// /// See also [`ExprFunctionExt`] to set these fields on `Expr` @@ -651,7 +635,7 @@ impl AggregateFunctionDefinition { #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateFunction { /// Name of the function - pub func_def: AggregateFunctionDefinition, + pub func: Arc, /// List of expressions to feed to the functions as arguments pub args: Vec, /// Whether this is a DISTINCT aggregation or not @@ -666,7 +650,7 @@ pub struct AggregateFunction { impl AggregateFunction { /// Create a new AggregateFunction expression with a user-defined function (UDF) pub fn new_udf( - udf: Arc, + func: Arc, args: Vec, distinct: bool, filter: Option>, @@ -674,7 +658,7 @@ impl AggregateFunction { null_treatment: Option, ) -> Self { Self { - func_def: AggregateFunctionDefinition::UDF(udf), + func, args, distinct, filter, @@ -1666,14 +1650,14 @@ impl Expr { func.hash(hasher); } Expr::AggregateFunction(AggregateFunction { - func_def, + func, args: _args, distinct, filter: _filter, order_by: _order_by, null_treatment, }) => { - func_def.hash(hasher); + func.hash(hasher); distinct.hash(hasher); null_treatment.hash(hasher); } @@ -1870,7 +1854,7 @@ impl fmt::Display for Expr { Ok(()) } Expr::AggregateFunction(AggregateFunction { - func_def, + func, 
distinct, ref args, filter, @@ -1878,7 +1862,7 @@ impl fmt::Display for Expr { null_treatment, .. }) => { - fmt_function(f, func_def.name(), *distinct, args, true)?; + fmt_function(f, func.name(), *distinct, args, true)?; if let Some(nt) = null_treatment { write!(f, " {}", nt)?; } @@ -2190,14 +2174,14 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { write!(w, "{window_frame}")?; } Expr::AggregateFunction(AggregateFunction { - func_def, + func, distinct, args, filter, order_by, null_treatment, }) => { - write_function_name(w, func_def.name(), *distinct, args)?; + write_function_name(w, func.name(), *distinct, args)?; if let Some(fe) = filter { write!(w, " FILTER (WHERE {fe})")?; }; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 6344b892adb7..676903d59a07 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,8 +17,8 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast, InList, - InSubquery, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, + AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, + ScalarFunction, Sort, TryCast, Unnest, WindowFunction, }; use crate::type_coercion::binary::get_result_type; use crate::type_coercion::functions::{ @@ -193,28 +193,24 @@ impl ExprSchemable for Expr { _ => fun.return_type(&data_types, &nullability), } } - Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => { + Expr::AggregateFunction(AggregateFunction { func, args, .. 
}) => { let data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - match func_def { - AggregateFunctionDefinition::UDF(fun) => { - let new_types = data_types_with_aggregate_udf(&data_types, fun) - .map_err(|err| { - plan_datafusion_err!( - "{} {}", - err, - utils::generate_signature_error_msg( - fun.name(), - fun.signature().clone(), - &data_types - ) + let new_types = data_types_with_aggregate_udf(&data_types, func) + .map_err(|err| { + plan_datafusion_err!( + "{} {}", + err, + utils::generate_signature_error_msg( + func.name(), + func.signature().clone(), + &data_types ) - })?; - Ok(fun.return_type(&new_types)?) - } - } + ) + })?; + Ok(func.return_type(&new_types)?) } Expr::Not(_) | Expr::IsNull(_) @@ -329,13 +325,12 @@ impl ExprSchemable for Expr { } } Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), - Expr::AggregateFunction(AggregateFunction { func_def, .. }) => { - match func_def { - // TODO: UDF should be able to customize nullability - AggregateFunctionDefinition::UDF(udf) if udf.name() == "count" => { - Ok(false) - } - AggregateFunctionDefinition::UDF(_) => Ok(true), + Expr::AggregateFunction(AggregateFunction { func, .. }) => { + // TODO: UDF should be able to customize nullability + if func.name() == "count" { + Ok(false) + } else { + Ok(true) } } Expr::ScalarVariable(_, _) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index a8062c0c07ee..450ebb6c2275 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -18,9 +18,8 @@ //! 
Tree node implementation for logical expr use crate::expr::{ - AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case, - Cast, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, Sort, - TryCast, Unnest, WindowFunction, + AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, + InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, }; use crate::{Expr, ExprFunctionExt}; @@ -304,7 +303,7 @@ impl TreeNode for Expr { }), Expr::AggregateFunction(AggregateFunction { args, - func_def, + func, distinct, filter, order_by, @@ -316,20 +315,16 @@ impl TreeNode for Expr { order_by, transform_option_vec(order_by, &mut f) )? - .map_data( - |(new_args, new_filter, new_order_by)| match func_def { - AggregateFunctionDefinition::UDF(fun) => { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( - fun, - new_args, - distinct, - new_filter, - new_order_by, - null_treatment, - ))) - } - }, - )?, + .map_data(|(new_args, new_filter, new_order_by)| { + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + func, + new_args, + distinct, + new_filter, + new_order_by, + null_treatment, + ))) + })?, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)? 
.update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index fee3e83a0d65..f980362105a1 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -20,7 +20,6 @@ use datafusion_common::{exec_err, utils::list_ndims, DFSchema, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - expr::AggregateFunctionDefinition, planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, sqlparser, Expr, ExprSchemable, GetFieldAccess, }; @@ -171,6 +170,5 @@ impl ExprPlanner for FieldAccessPlanner { } fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def; - return udf.name() == "array_agg"; + return agg_func.func.name() == "array_agg"; } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 6f832966671c..e2da6c66abc4 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -21,9 +21,7 @@ use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; -use datafusion_expr::expr::{ - AggregateFunction, AggregateFunctionDefinition, WindowFunction, -}; +use datafusion_expr::expr::{AggregateFunction, WindowFunction}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; @@ -56,10 +54,10 @@ fn is_wildcard(expr: &Expr) -> bool { fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { matches!(aggregate_function, AggregateFunction { - func_def: AggregateFunctionDefinition::UDF(udf), + func, args, .. 
- } if udf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) + } if func.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index bcd1cbcce23e..2823b0fca2d1 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -28,8 +28,8 @@ use datafusion_common::{ DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::{ - self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, - InSubquery, Like, ScalarFunction, WindowFunction, + self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, + WindowFunction, }; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::tree_node::unwrap_arc; @@ -393,31 +393,29 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { ))) } Expr::AggregateFunction(expr::AggregateFunction { - func_def, + func, args, distinct, filter, order_by, null_treatment, - }) => match func_def { - AggregateFunctionDefinition::UDF(fun) => { - let new_expr = coerce_arguments_for_signature_with_aggregate_udf( - args, - self.schema, - &fun, - )?; - Ok(Transformed::yes(Expr::AggregateFunction( - expr::AggregateFunction::new_udf( - fun, - new_expr, - distinct, - filter, - order_by, - null_treatment, - ), - ))) - } - }, + }) => { + let new_expr = coerce_arguments_for_signature_with_aggregate_udf( + args, + self.schema, + &func, + )?; + Ok(Transformed::yes(Expr::AggregateFunction( + expr::AggregateFunction::new_udf( + func, + new_expr, + distinct, + filter, + order_by, + null_treatment, + ), + ))) + } Expr::WindowFunction(WindowFunction { fun, args, diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 
6dbf1641bd7c..fdd9ef8a8b0b 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -28,7 +28,7 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; -use datafusion_expr::expr::{AggregateFunctionDefinition, Alias}; +use datafusion_expr::expr::Alias; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; @@ -433,19 +433,13 @@ fn agg_exprs_evaluation_result_on_empty_batch( .clone() .transform_up(|expr| { let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { - func_def, .. - }) => match func_def { - AggregateFunctionDefinition::UDF(fun) => { - if fun.name() == "count" { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( - 0, - )))) - } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } + Expr::AggregateFunction(expr::AggregateFunction { func, .. 
}) => { + if func.name() == "count" { + Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(0)))) + } else { + Transformed::yes(Expr::Literal(ScalarValue::Null)) } - }, + } _ => Transformed::no(expr), }; Ok(new_expr) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 1e1418744fb8..979a1499d0de 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -32,9 +32,7 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{ - AggregateFunctionDefinition, InList, InSubquery, WindowFunction, -}; +use datafusion_expr::expr::{InList, InSubquery, WindowFunction}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, @@ -1408,9 +1406,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction { - func_def: AggregateFunctionDefinition::UDF(ref udaf), + ref func, .. - }) => match (udaf.simplify(), expr) { + }) => match (func.simplify(), expr) { (Some(simplify_function), Expr::AggregateFunction(af)) => { Transformed::yes(simplify_function(af, info)?) 
} diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index fba42d7f880b..9a0fab14d3e0 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -26,7 +26,6 @@ use datafusion_common::{ internal_err, qualified_name, tree_node::Transformed, DataFusionError, Result, }; use datafusion_expr::builder::project; -use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::{ col, expr::AggregateFunction, @@ -70,7 +69,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::UDF(fun), + func, distinct, args, filter, @@ -86,9 +85,9 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { for e in args { fields_set.insert(e); } - } else if fun.name() != "sum" - && fun.name().to_lowercase() != "min" - && fun.name().to_lowercase() != "max" + } else if func.name() != "sum" + && func.name().to_lowercase() != "min" + && func.name().to_lowercase() != "max" { return Ok(false); } @@ -184,7 +183,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .into_iter() .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::UDF(udf), + func, mut args, distinct, .. 
@@ -200,7 +199,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .push(arg.alias(SINGLE_DISTINCT_ALIAS)); } Ok(Expr::AggregateFunction(AggregateFunction::new_udf( - udf, + func, vec![col(SINGLE_DISTINCT_ALIAS)], false, // intentional to remove distinct here None, @@ -213,7 +212,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { let alias_str = format!("alias{}", index); inner_aggr_exprs.push( Expr::AggregateFunction(AggregateFunction::new_udf( - Arc::clone(&udf), + Arc::clone(&func), args, false, None, @@ -223,7 +222,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .alias(&alias_str), ); Ok(Expr::AggregateFunction(AggregateFunction::new_udf( - udf, + func, vec![col(&alias_str)], false, None, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 74d9d61b3a7f..ab81ce8af9cb 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -21,8 +21,8 @@ use datafusion_common::{TableReference, UnnestOptions}; use datafusion_expr::expr::{ - self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast, GroupingSet, - InList, Like, Placeholder, ScalarFunction, Sort, Unnest, + self, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, Placeholder, + ScalarFunction, Sort, Unnest, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, BuiltInWindowFunction, Expr, @@ -361,38 +361,34 @@ pub fn serialize_expr( } } Expr::AggregateFunction(expr::AggregateFunction { - ref func_def, + ref func, ref args, ref distinct, ref filter, ref order_by, null_treatment: _, - }) => match func_def { - AggregateFunctionDefinition::UDF(fun) => { - let mut buf = Vec::new(); - let _ = codec.try_encode_udaf(fun, &mut buf); - protobuf::LogicalExprNode { - expr_type: Some(ExprType::AggregateUdfExpr(Box::new( - protobuf::AggregateUdfExprNode { - fun_name: fun.name().to_string(), - args: serialize_exprs(args, codec)?, - distinct: *distinct, - 
filter: match filter { - Some(e) => { - Some(Box::new(serialize_expr(e.as_ref(), codec)?)) - } - None => None, - }, - order_by: match order_by { - Some(e) => serialize_exprs(e, codec)?, - None => vec![], - }, - fun_definition: (!buf.is_empty()).then_some(buf), + }) => { + let mut buf = Vec::new(); + let _ = codec.try_encode_udaf(func, &mut buf); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::AggregateUdfExpr(Box::new( + protobuf::AggregateUdfExprNode { + fun_name: func.name().to_string(), + args: serialize_exprs(args, codec)?, + distinct: *distinct, + filter: match filter { + Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), + None => None, }, - ))), - } + order_by: match order_by { + Some(e) => serialize_exprs(e, codec)?, + None => vec![], + }, + fun_definition: (!buf.is_empty()).then_some(buf), + }, + ))), } - }, + } Expr::ScalarVariable(_, _) => { return Err(Error::General( diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 9b44848a91a8..de130754ab1a 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -337,7 +337,7 @@ impl Unparser<'_> { escape_char: escape_char.map(|c| c.to_string()), }), Expr::AggregateFunction(agg) => { - let func_name = agg.func_def.name(); + let func_name = agg.func.name(); let args = self.function_args_to_sql(&agg.args)?; let filter = match &agg.filter { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index a782af8eb247..ee04749f5e6b 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -49,8 +49,7 @@ use datafusion::common::{ use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] use datafusion::logical_expr::expr::{ - AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - InSubquery, Sort, WindowFunction, + Alias, BinaryExpr, Case, Cast, 
GroupingSet, InList, InSubquery, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -764,9 +763,7 @@ pub fn to_substrait_agg_measure( extensions: &mut Extensions, ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by, null_treatment: _, }) => { - match func_def { - AggregateFunctionDefinition::UDF(fun) => { + Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }) => { let sorts = if let Some(order_by) = order_by { order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? } else { @@ -776,7 +773,7 @@ pub fn to_substrait_agg_measure( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); } - let function_anchor = extensions.register_function(fun.name().to_string()); + let function_anchor = extensions.register_function(func.name().to_string()); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -796,8 +793,6 @@ pub fn to_substrait_agg_measure( None => None } }) - } - } } Expr::Alias(Alias{expr,..})=> {