diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 6085fca8761f..9b7fac052ca2 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -20,7 +20,8 @@ use arrow::{array::AsArray, datatypes::Fields}; use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; -use arrow_schema::Schema; +use arrow_schema::{Schema, SchemaRef}; +use std::any::Any; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -48,7 +49,7 @@ use datafusion_expr::{ create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF, }; -use datafusion_physical_expr::expressions::AvgAccumulator; +use datafusion_physical_expr::expressions::{AvgAccumulator, MinAccumulator}; /// Test to show the contents of the setup #[tokio::test] @@ -210,25 +211,33 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } -/// tests the creation, registration and usage of a UDAF -#[tokio::test] -async fn simple_udaf() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); +/// Return a SessionContext with a basic table "t" +fn simple_udf_context() -> Result { + let schema: SchemaRef = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), + schema.clone(), vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], )?; let batch2 = RecordBatch::try_new( - Arc::new(schema.clone()), + schema.clone(), vec![Arc::new(Int32Array::from(vec![4, 5]))], )?; let ctx = SessionContext::new(); - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; + let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; ctx.register_table("t", Arc::new(provider))?; + Ok(ctx) +} + +/// tests the creation, registration and usage of a UDAF +#[tokio::test] +async fn simple_udaf() -> Result<()> { + let ctx = simple_udf_context()?; + // define a udaf, using a DataFusion's accumulator let my_avg = create_udaf( "my_avg", @@ -255,6 +264,106 @@ async fn simple_udaf() -> Result<()> { Ok(()) } +/// tests the creation, registration and usage of a AggregateUDFImpl based aggregate +#[tokio::test] +async fn simple_udaf_trait() -> Result<()> { + let ctx = simple_udf_context()?; + + // define a udaf, using a DataFusion's accumulator + ctx.register_udaf(AggregateUDF::from(MyMin::new())); + + let result = ctx.sql("SELECT MY_MIN(a) FROM t").await?.collect().await?; + + let expected = [ + "+-------------+", + "| my_min(t.a) |", + "+-------------+", + "| 1.0 |", + "+-------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) +} + +/// Tests checking for syntax errors +#[tokio::test] +async fn simple_udaf_trait_ignore_nulls() -> Result<()> { + let ctx = simple_udf_context()?; + ctx.register_udaf(AggregateUDF::from(MyMin::new())); + + /// Runs a query that errors and returns the error message + async fn run_err(ctx: &SessionContext, sql: &str) -> String { + ctx.sql(sql) + .await + .unwrap() + .collect() + .await + .unwrap_err() + .strip_backtrace() + .to_string() + } + + /// Run a query that should succeed + async fn run_ok(ctx: &SessionContext, sql: &str) { + ctx.sql(sql).await.unwrap().collect().await.unwrap(); + } + + // You can pass IGNORE NULLs to the UDAF + assert_eq!( + run_err(&ctx, "SELECT MY_MIN(a) IGNORE NULLS FROM t").await, + "This feature is not implemented: IGNORE NULLS not implemented for my_min" + ); + // RESPECT NULLS should work (the default) + run_ok(&ctx, "SELECT MY_MIN(a) RESPECT NULLS FROM t").await; + + // You can pass ORDER BY to the UDAF as well, which should error if it isn't supported + assert_eq!( + run_err(&ctx, "SELECT MY_MIN(a ORDER BY a) FROM t").await, + "This feature is not implemented: ORDER BY not implemented for my_min" + ); + + Ok(()) +} + +#[derive(Debug)] +struct MyMin { + signature: Signature, +} + +impl MyMin { + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), + } + } +} +impl AggregateUDFImpl for MyMin { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_min" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + // TODO error if IGNORE NULLs and ORDER BY are specified + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + // Use MinAccumulator + MinAccumulator::try_new(&DataType::Float64) + .map(|acc| Box::new(acc) as Box) + } +} + #[tokio::test] async fn deregister_udaf() -> Result<()> { let ctx = SessionContext::new(); @@ -526,7 +635,6 @@ impl Accumulator for TimeSum { let arr = arr.as_primitive::(); for v in arr.values().iter() { - println!("Adding {v}"); self.sum += v; } Ok(()) @@ -538,7 +646,6 @@ impl Accumulator for TimeSum { } fn evaluate(&mut self) -> Result { - println!("Evaluating to {}", self.sum); Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None)) } @@ -558,7 +665,6 @@ impl Accumulator for TimeSum { let arr = arr.as_primitive::(); for v in arr.values().iter() { - println!("Retracting {v}"); self.sum -= v; } Ok(())