From 67be22aa86218f619732a9b3fcffb3e00bf47cc3 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 3 Apr 2024 06:50:10 -0400 Subject: [PATCH 1/5] Minor: Improve documentation for AggregateUDFImpl::accumulator and `AccumulatorArgs` --- datafusion/expr/src/function.rs | 35 ++++++++++++++++++++++++++------- datafusion/expr/src/udaf.rs | 15 +++++++++++++- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 7598c805adf6..591531d53ca1 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -38,18 +38,39 @@ pub type ScalarFunctionImplementation = pub type ReturnTypeFunction = Arc Result> + Send + Sync>; -/// Arguments passed to create an accumulator +/// [`AccumulatorArgs`] contains information about how an aggregate +/// function was called, including the types of its arguments and any optional +/// ordering expressions. pub struct AccumulatorArgs<'a> { - // default arguments - /// the return type of the function + /// The return type of the aggregate function. pub data_type: &'a DataType, - /// the schema of the input arguments + /// The schema of the input arguments pub schema: &'a Schema, - /// whether to ignore nulls + /// Whether to ignore nulls. + /// + /// SQL allows the user to specify `IGNORE NULLS`, for example: + /// + /// ```sql + /// SELECT FIRST_VALUE(column1) IGNORE NULLS FROM t; + /// ``` + /// + /// Aggregates that do not support this functionality should return a not + /// implemented error when `ignore_nulls` is true. pub ignore_nulls: bool, - // ordering arguments - /// the expressions of `order by`, if no ordering is required, this will be an empty slice + /// The expressions in the `ORDER BY` clause passed to this aggregator. + /// + /// SQL allows the user to specify the ordering of arguments to the + /// aggregate using an `ORDER BY`. For example: + /// + /// ```sql + /// SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t; + /// ``` + /// + /// If no `ORDER BY` is specified, `sort_exprs`` will be empty. Aggregates + /// that do not support this functionality may return a not implemented + /// error when the slice is non empty, as ordering the arguments is an + /// expensive operation and is wasteful if the aggregate doesn't support it. pub sort_exprs: &'a [Expr], } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index ba80f39dde43..2bd3a0693275 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -280,7 +280,13 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Return a new [`Accumulator`] that aggregates values for a specific /// group during query execution. /// - /// `acc_args`: the arguments to the accumulator. See [`AccumulatorArgs`] for more details. + /// acc_args: [`AccumulatorArgs`] contains information about how the + /// aggregate function was called. + /// + /// # Example + /// ``` + /// todo + /// ``` fn accumulator(&self, acc_args: AccumulatorArgs) -> Result>; /// Return the fields of the intermediate state. @@ -308,6 +314,13 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// If the aggregate expression has a specialized /// [`GroupsAccumulator`] implementation. If this returns true, /// `[Self::create_groups_accumulator]` will be called. + /// + /// # Notes + /// + /// Even if this function returns true, DataFusion will call + /// `Self::accumulator` for certain queries, such as when this aggregate is + /// used as a window function or when there are only aggregates (no GROUP BY + /// columns) in the plan. fn groups_accumulator_supported(&self) -> bool { false } From 315a9a4fba69a07436ba67a0bf51602c2366ef88 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 3 Apr 2024 07:13:40 -0400 Subject: [PATCH 2/5] Add test and helper functions --- .../user_defined/user_defined_aggregates.rs | 131 ++++++++++++++++-- datafusion/expr/src/function.rs | 20 ++- datafusion/expr/src/udaf.rs | 3 +- 3 files changed, 140 insertions(+), 14 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 6085fca8761f..293255518251 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -20,7 +20,8 @@ use arrow::{array::AsArray, datatypes::Fields}; use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; -use arrow_schema::Schema; +use arrow_schema::{Schema, SchemaRef}; +use std::any::Any; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -48,7 +49,7 @@ use datafusion_expr::{ create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF, }; -use datafusion_physical_expr::expressions::AvgAccumulator; +use datafusion_physical_expr::expressions::{AvgAccumulator, MinAccumulator}; /// Test to show the contents of the setup #[tokio::test] @@ -210,25 +211,33 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } -/// tests the creation, registration and usage of a UDAF -#[tokio::test] -async fn simple_udaf() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); +/// Return a SessionContext with a basic table "t" +fn simple_udf_context() -> Result { + let schema: SchemaRef = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), + schema.clone(), vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], )?; let batch2 = RecordBatch::try_new( - Arc::new(schema.clone()), + schema.clone(), vec![Arc::new(Int32Array::from(vec![4, 5]))], )?; let ctx = SessionContext::new(); - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; + let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; ctx.register_table("t", Arc::new(provider))?; + Ok(ctx) +} + +/// tests the creation, registration and usage of a UDAF +#[tokio::test] +async fn simple_udaf() -> Result<()> { + let ctx = simple_udf_context()?; + // define a udaf, using a DataFusion's accumulator let my_avg = create_udaf( "my_avg", @@ -255,6 +264,107 @@ async fn simple_udaf() -> Result<()> { Ok(()) } +/// tests the creation, registration and usage of a AggregateUDFImpl based aggregate +#[tokio::test] +async fn simple_udaf_trait() -> Result<()> { + let ctx = simple_udf_context()?; + + // define a udaf, using a DataFusion's accumulator + ctx.register_udaf(AggregateUDF::from(MyMin::new())); + + let result = ctx.sql("SELECT MY_MIN(a) FROM t").await?.collect().await?; + + let expected = [ + "+-------------+", + "| my_min(t.a) |", + "+-------------+", + "| 1.0 |", + "+-------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) +} + +/// Tests checking for syntax errors +#[tokio::test] +async fn simple_udaf_trait_ignore_nulls() -> Result<()> { + let ctx = simple_udf_context()?; + ctx.register_udaf(AggregateUDF::from(MyMin::new())); + + // You can pass IGNORE NULLs to the UDAF + let err = ctx + .sql("SELECT MY_MIN(a) IGNORE NULLS FROM t") + .await? + .collect() + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "This feature is not implemented: IGNORE NULLS not implemented for my_min" + ); + // RESPECT NULLS should work (the default) + ctx.sql("SELECT MY_MIN(a) RESPECT NULLS FROM t") + .await? + .collect() + .await + .unwrap(); + + // You can pass ORDER BY to the UDAF as well, which should error if it isn't supported + let err = ctx + .sql("SELECT MY_MIN(a ORDER BY a) FROM t") + .await? + .collect() + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "This feature is not implemented: ORDER BY not implemented for my_min" + ); + + Ok(()) +} + +#[derive(Debug)] +struct MyMin { + signature: Signature, +} + +impl MyMin { + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), + } + } +} +impl AggregateUDFImpl for MyMin { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_min" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + // Error if IGNORE NULLs and ORDER BY are specified + acc_args.check_ignore_nulls(self.name())?; + acc_args.check_order_by(self.name())?; + + // Use AvgAccumulator + MinAccumulator::try_new(&DataType::Float64) + .map(|acc| Box::new(acc) as Box) + } +} + #[tokio::test] async fn deregister_udaf() -> Result<()> { let ctx = SessionContext::new(); @@ -526,7 +636,6 @@ impl Accumulator for TimeSum { let arr = arr.as_primitive::(); for v in arr.values().iter() { - println!("Adding {v}"); self.sum += v; } Ok(()) @@ -538,7 +647,6 @@ impl Accumulator for TimeSum { } fn evaluate(&mut self) -> Result { - println!("Evaluating to {}", self.sum); Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None)) } @@ -558,7 +666,6 @@ impl Accumulator for TimeSum { let arr = arr.as_primitive::(); for v in arr.values().iter() { - println!("Retracting {v}"); self.sum -= v; } Ok(()) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 591531d53ca1..84c01ea7c86a 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -20,7 +20,7 @@ use crate::ColumnarValue; use crate::{Accumulator, Expr, PartitionEvaluator}; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::Result; +use datafusion_common::{not_impl_err, Result}; use std::sync::Arc; /// Scalar function @@ -88,6 +88,24 @@ impl<'a> AccumulatorArgs<'a> { sort_exprs, } } + + /// Return a not yet implemented error if IGNORE NULLs is true + pub fn check_ignore_nulls(&self, name: &str) -> Result<()> { + if self.ignore_nulls { + not_impl_err!("IGNORE NULLS not implemented for {name}") + } else { + Ok(()) + } + } + + /// Return a not yet implemented error if `ORDER BY` is non empty + pub fn check_order_by(&self, name: &str) -> Result<()> { + if !self.sort_exprs.is_empty() { + not_impl_err!("ORDER BY not implemented for {name}") + } else { + Ok(()) + } + } } /// Factory that returns an accumulator for the given aggregate function. diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 2bd3a0693275..e4242e4cdabb 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -285,7 +285,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// /// # Example /// ``` - /// todo + /// struct Aggregate { + /// } /// ``` fn accumulator(&self, acc_args: AccumulatorArgs) -> Result>; From 203c83fbbb9eeea67822588651d1aff2d342c685 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 3 Apr 2024 07:19:01 -0400 Subject: [PATCH 3/5] Improve docs and examples --- datafusion-examples/examples/advanced_udaf.rs | 13 ++++++++++++- .../tests/user_defined/user_defined_aggregates.rs | 2 +- datafusion/expr/src/udaf.rs | 15 +++++++-------- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 342a23b6e73d..c8cbf28d68e8 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -87,7 +87,18 @@ impl AggregateUDFImpl for GeoMeanUdaf { /// is supported, DataFusion will use this row oriented /// accumulator when the aggregate function is used as a window function /// or when there are only aggregates (no GROUP BY columns) in the plan. - fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + // Error if IGNORE NULLs and ORDER BY are specified in the query as this + // UDAF does not support them. + // + // For example `SELECT geo_mean(a) IGNORE NULLS` and `SELECT geo_mean(a) + // ORDER BY b` would fail. + // + // If your Accumulator supports different behavior for these options, + // you can implement it here. + acc_args.check_ignore_nulls(self.name())?; + acc_args.check_order_by(self.name())?; + Ok(Box::new(GeometricMean::new())) } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 293255518251..f02c91dd236b 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -359,7 +359,7 @@ impl AggregateUDFImpl for MyMin { acc_args.check_ignore_nulls(self.name())?; acc_args.check_order_by(self.name())?; - // Use AvgAccumulator + // Use MinAccumulator MinAccumulator::try_new(&DataType::Float64) .map(|acc| Box::new(acc) as Box) } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index e4242e4cdabb..f239f88179a9 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -211,8 +211,8 @@ where /// See [`advanced_udaf.rs`] for a full example with complete implementation and /// [`AggregateUDF`] for other available options. /// -/// /// [`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs +/// /// # Basic Example /// ``` /// # use std::any::Any; @@ -247,7 +247,12 @@ where /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. -/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } +/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { +/// // Error if IGNORE NULLs and ORDER BY are specified in the query +// acc_args.check_ignore_nulls(self.name())?; +// acc_args.check_order_by(self.name())?; +/// unimplemented!() +/// } /// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec) -> Result> { /// Ok(vec![ /// Field::new("value", value_type, true), @@ -282,12 +287,6 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// /// acc_args: [`AccumulatorArgs`] contains information about how the /// aggregate function was called. - /// - /// # Example - /// ``` - /// struct Aggregate { - /// } - /// ``` fn accumulator(&self, acc_args: AccumulatorArgs) -> Result>; /// Return the fields of the intermediate state. From 0f5956503c77cb9468acb7de4dc21dee675fe460 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 3 Apr 2024 08:57:50 -0400 Subject: [PATCH 4/5] Fix CI --- .../user_defined/user_defined_aggregates.rs | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index f02c91dd236b..ec009ef35d11 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -292,33 +292,34 @@ async fn simple_udaf_trait_ignore_nulls() -> Result<()> { let ctx = simple_udf_context()?; ctx.register_udaf(AggregateUDF::from(MyMin::new())); + /// Runs a query that errors and returns the error message + async fn run_err(ctx: &SessionContext, sql: &str) -> String { + ctx.sql(sql) + .await + .unwrap() + .collect() + .await + .unwrap_err() + .strip_backtrace() + .to_string() + } + + /// Run a query that should succeed + async fn run_ok(ctx: &SessionContext, sql: &str) { + ctx.sql(sql).await.unwrap().collect().await.unwrap(); + } + // You can pass IGNORE NULLs to the UDAF - let err = ctx - .sql("SELECT MY_MIN(a) IGNORE NULLS FROM t") - .await? - .collect() - .await - .unwrap_err(); assert_eq!( - err.to_string(), + run_err(&ctx, "SELECT MY_MIN(a) IGNORE NULLS FROM t").await, "This feature is not implemented: IGNORE NULLS not implemented for my_min" ); // RESPECT NULLS should work (the default) - ctx.sql("SELECT MY_MIN(a) RESPECT NULLS FROM t") - .await? - .collect() - .await - .unwrap(); + run_ok(&ctx, "SELECT MY_MIN(a) RESPECT NULLS FROM t").await; // You can pass ORDER BY to the UDAF as well, which should error if it isn't supported - let err = ctx - .sql("SELECT MY_MIN(a ORDER BY a) FROM t") - .await? - .collect() - .await - .unwrap_err(); assert_eq!( - err.to_string(), + run_err(&ctx, "SELECT MY_MIN(a ORDER BY a) FROM t").await, "This feature is not implemented: ORDER BY not implemented for my_min" ); From b3add2d2061dde732db4ad7aa275a23a66429241 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 4 Apr 2024 14:59:08 -0400 Subject: [PATCH 5/5] Remove checks for ORDER BY and IGNORE NULLS --- datafusion-examples/examples/advanced_udaf.rs | 13 +- .../user_defined/user_defined_aggregates.rs | 129 ++---------------- datafusion/expr/src/function.rs | 28 +--- datafusion/expr/src/udaf.rs | 13 +- 4 files changed, 16 insertions(+), 167 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index c8cbf28d68e8..342a23b6e73d 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -87,18 +87,7 @@ impl AggregateUDFImpl for GeoMeanUdaf { /// is supported, DataFusion will use this row oriented /// accumulator when the aggregate function is used as a window function /// or when there are only aggregates (no GROUP BY columns) in the plan. - fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - // Error if IGNORE NULLs and ORDER BY are specified in the query as this - // UDAF does not support them. - // - // For example `SELECT geo_mean(a) IGNORE NULLS` and `SELECT geo_mean(a) - // ORDER BY b` would fail. - // - // If your Accumulator supports different behavior for these options, - // you can implement it here. - acc_args.check_ignore_nulls(self.name())?; - acc_args.check_order_by(self.name())?; - + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(GeometricMean::new())) } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index ec009ef35d11..8f02fb30b013 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -20,8 +20,7 @@ use arrow::{array::AsArray, datatypes::Fields}; use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; -use arrow_schema::{Schema, SchemaRef}; -use std::any::Any; +use arrow_schema::Schema; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -49,7 +48,7 @@ use datafusion_expr::{ create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF, }; -use datafusion_physical_expr::expressions::{AvgAccumulator, MinAccumulator}; +use datafusion_physical_expr::expressions::AvgAccumulator; /// Test to show the contents of the setup #[tokio::test] @@ -211,33 +210,25 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } -/// Return a SessionContext with a basic table "t" -fn simple_udf_context() -> Result { - let schema: SchemaRef = - Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); +/// tests the creation, registration and usage of a UDAF +#[tokio::test] +async fn simple_udaf() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let batch1 = RecordBatch::try_new( - schema.clone(), + Arc::new(schema.clone()), vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], )?; let batch2 = RecordBatch::try_new( - schema.clone(), + Arc::new(schema.clone()), vec![Arc::new(Int32Array::from(vec![4, 5]))], )?; let ctx = SessionContext::new(); - let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; + let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; ctx.register_table("t", Arc::new(provider))?; - Ok(ctx) -} - -/// tests the creation, registration and usage of a UDAF -#[tokio::test] -async fn simple_udaf() -> Result<()> { - let ctx = simple_udf_context()?; - // define a udaf, using a DataFusion's accumulator let my_avg = create_udaf( "my_avg", @@ -264,108 +255,6 @@ async fn simple_udaf() -> Result<()> { Ok(()) } -/// tests the creation, registration and usage of a AggregateUDFImpl based aggregate -#[tokio::test] -async fn simple_udaf_trait() -> Result<()> { - let ctx = simple_udf_context()?; - - // define a udaf, using a DataFusion's accumulator - ctx.register_udaf(AggregateUDF::from(MyMin::new())); - - let result = ctx.sql("SELECT MY_MIN(a) FROM t").await?.collect().await?; - - let expected = [ - "+-------------+", - "| my_min(t.a) |", - "+-------------+", - "| 1.0 |", - "+-------------+", - ]; - assert_batches_eq!(expected, &result); - - Ok(()) -} - -/// Tests checking for syntax errors -#[tokio::test] -async fn simple_udaf_trait_ignore_nulls() -> Result<()> { - let ctx = simple_udf_context()?; - ctx.register_udaf(AggregateUDF::from(MyMin::new())); - - /// Runs a query that errors and returns the error message - async fn run_err(ctx: &SessionContext, sql: &str) -> String { - ctx.sql(sql) - .await - .unwrap() - .collect() - .await - .unwrap_err() - .strip_backtrace() - .to_string() - } - - /// Run a query that should succeed - async fn run_ok(ctx: &SessionContext, sql: &str) { - ctx.sql(sql).await.unwrap().collect().await.unwrap(); - } - - // You can pass IGNORE NULLs to the UDAF - assert_eq!( - run_err(&ctx, "SELECT MY_MIN(a) IGNORE NULLS FROM t").await, - "This feature is not implemented: IGNORE NULLS not implemented for my_min" - ); - // RESPECT NULLS should work (the default) - run_ok(&ctx, "SELECT MY_MIN(a) RESPECT NULLS FROM t").await; - - // You can pass ORDER BY to the UDAF as well, which should error if it isn't supported - assert_eq!( - run_err(&ctx, "SELECT MY_MIN(a ORDER BY a) FROM t").await, - "This feature is not implemented: ORDER BY not implemented for my_min" - ); - - Ok(()) -} - -#[derive(Debug)] -struct MyMin { - signature: Signature, -} - -impl MyMin { - fn new() -> Self { - Self { - signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), - } - } -} -impl AggregateUDFImpl for MyMin { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "my_min" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Float64) - } - - fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - // Error if IGNORE NULLs and ORDER BY are specified - acc_args.check_ignore_nulls(self.name())?; - acc_args.check_order_by(self.name())?; - - // Use MinAccumulator - MinAccumulator::try_new(&DataType::Float64) - .map(|acc| Box::new(acc) as Box) - } -} - #[tokio::test] async fn deregister_udaf() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 84c01ea7c86a..7a92a50ae15d 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -20,7 +20,7 @@ use crate::ColumnarValue; use crate::{Accumulator, Expr, PartitionEvaluator}; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::Result; use std::sync::Arc; /// Scalar function @@ -53,9 +53,6 @@ pub struct AccumulatorArgs<'a> { /// ```sql /// SELECT FIRST_VALUE(column1) IGNORE NULLS FROM t; /// ``` - /// - /// Aggregates that do not support this functionality should return a not - /// implemented error when `ignore_nulls` is true. pub ignore_nulls: bool, /// The expressions in the `ORDER BY` clause passed to this aggregator. @@ -67,10 +64,7 @@ pub struct AccumulatorArgs<'a> { /// SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t; /// ``` /// - /// If no `ORDER BY` is specified, `sort_exprs`` will be empty. Aggregates - /// that do not support this functionality may return a not implemented - /// error when the slice is non empty, as ordering the arguments is an - /// expensive operation and is wasteful if the aggregate doesn't support it. + /// If no `ORDER BY` is specified, `sort_exprs`` will be empty. pub sort_exprs: &'a [Expr], } @@ -88,24 +82,6 @@ impl<'a> AccumulatorArgs<'a> { sort_exprs, } } - - /// Return a not yet implemented error if IGNORE NULLs is true - pub fn check_ignore_nulls(&self, name: &str) -> Result<()> { - if self.ignore_nulls { - not_impl_err!("IGNORE NULLS not implemented for {name}") - } else { - Ok(()) - } - } - - /// Return a not yet implemented error if `ORDER BY` is non empty - pub fn check_order_by(&self, name: &str) -> Result<()> { - if !self.sort_exprs.is_empty() { - not_impl_err!("ORDER BY not implemented for {name}") - } else { - Ok(()) - } - } } /// Factory that returns an accumulator for the given aggregate function. diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index bf639aa2e06b..3cf1845aacd6 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -249,12 +249,7 @@ where /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. -/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { -/// // Error if IGNORE NULLs and ORDER BY are specified in the query -// acc_args.check_ignore_nulls(self.name())?; -// acc_args.check_order_by(self.name())?; -/// unimplemented!() -/// } +/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } /// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec) -> Result> { /// Ok(vec![ /// Field::new("value", value_type, true), @@ -334,10 +329,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// /// # Notes /// - /// Even if this function returns true, DataFusion will call + /// Even if this function returns true, DataFusion will still use /// `Self::accumulator` for certain queries, such as when this aggregate is - /// used as a window function or when there are only aggregates (no GROUP BY - /// columns) in the plan. + /// used as a window function or when there no GROUP BY columns in the + /// query. fn groups_accumulator_supported(&self) -> bool { false }