Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor: Improve documentation for AggregateUDFImpl::accumulator and AccumulatorArgs #9920

Merged
merged 7 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,18 @@ impl AggregateUDFImpl for GeoMeanUdaf {
/// is supported, DataFusion will use this row oriented
/// accumulator when the aggregate function is used as a window function
/// or when there are only aggregates (no GROUP BY columns) in the plan.
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
// Error if IGNORE NULLs and ORDER BY are specified in the query as this
// UDAF does not support them.
//
// For example `SELECT geo_mean(a) IGNORE NULLS` and `SELECT geo_mean(a)
// ORDER BY b` would fail.
//
// If your Accumulator supports different behavior for these options,
// you can implement it here.
acc_args.check_ignore_nulls(self.name())?;
acc_args.check_order_by(self.name())?;

Ok(Box::new(GeometricMean::new()))
}

Expand Down
131 changes: 119 additions & 12 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
use arrow_schema::Schema;
use arrow_schema::{Schema, SchemaRef};
use std::any::Any;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
Expand Down Expand Up @@ -48,7 +49,7 @@ use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;
use datafusion_physical_expr::expressions::{AvgAccumulator, MinAccumulator};

/// Test to show the contents of the setup
#[tokio::test]
Expand Down Expand Up @@ -210,25 +211,33 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
}

/// tests the creation, registration and usage of a UDAF
#[tokio::test]
async fn simple_udaf() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
/// Return a SessionContext with a basic table "t"
fn simple_udf_context() -> Result<SessionContext> {
let schema: SchemaRef =
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

let batch1 = RecordBatch::try_new(
Arc::new(schema.clone()),
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)?;
let batch2 = RecordBatch::try_new(
Arc::new(schema.clone()),
schema.clone(),
vec![Arc::new(Int32Array::from(vec![4, 5]))],
)?;

let ctx = SessionContext::new();

let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?;
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
ctx.register_table("t", Arc::new(provider))?;

Ok(ctx)
}

/// tests the creation, registration and usage of a UDAF
#[tokio::test]
async fn simple_udaf() -> Result<()> {
let ctx = simple_udf_context()?;

// define a udaf, using a DataFusion's accumulator
let my_avg = create_udaf(
"my_avg",
Expand All @@ -255,6 +264,107 @@ async fn simple_udaf() -> Result<()> {
Ok(())
}

/// Tests creating, registering and invoking an aggregate built on the
/// [`AggregateUDFImpl`] trait.
#[tokio::test]
async fn simple_udaf_trait() -> Result<()> {
    let ctx = simple_udf_context()?;

    // Register an AggregateUDF backed by the MyMin trait implementation
    let udaf = AggregateUDF::from(MyMin::new());
    ctx.register_udaf(udaf);

    // Invoke the UDAF via SQL and gather the output batches
    let df = ctx.sql("SELECT MY_MIN(a) FROM t").await?;
    let batches = df.collect().await?;

    assert_batches_eq!(
        [
            "+-------------+",
            "| my_min(t.a) |",
            "+-------------+",
            "| 1.0 |",
            "+-------------+",
        ],
        &batches
    );

    Ok(())
}

/// Tests that the UDAF rejects `IGNORE NULLS` and `ORDER BY`, which it does
/// not support, while accepting the default `RESPECT NULLS`.
#[tokio::test]
async fn simple_udaf_trait_ignore_nulls() -> Result<()> {
    let ctx = simple_udf_context()?;
    ctx.register_udaf(AggregateUDF::from(MyMin::new()));

    // MyMin does not support IGNORE NULLS, so specifying it should error
    let err = ctx
        .sql("SELECT MY_MIN(a) IGNORE NULLS FROM t")
        .await?
        .collect()
        .await
        .unwrap_err();
    assert_eq!(
        err.to_string(),
        "This feature is not implemented: IGNORE NULLS not implemented for my_min"
    );
    // RESPECT NULLS is the default, so it should work
    ctx.sql("SELECT MY_MIN(a) RESPECT NULLS FROM t")
        .await?
        .collect()
        .await
        .unwrap();

    // MyMin does not support ORDER BY either, so specifying it should error
    let err = ctx
        .sql("SELECT MY_MIN(a ORDER BY a) FROM t")
        .await?
        .collect()
        .await
        .unwrap_err();
    assert_eq!(
        err.to_string(),
        "This feature is not implemented: ORDER BY not implemented for my_min"
    );

    Ok(())
}

/// Example UDAF implemented via the [`AggregateUDFImpl`] trait; computes the
/// minimum of a `Float64` column by delegating to `MinAccumulator`.
#[derive(Debug)]
struct MyMin {
    // Declares exactly one Float64 argument; Immutable volatility marks the
    // function as deterministic
    signature: Signature,
}

impl MyMin {
    /// Create a new `MyMin` accepting a single `Float64` argument.
    fn new() -> Self {
        Self {
            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
        }
    }
}
impl AggregateUDFImpl for MyMin {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name used to invoke the function from SQL (case-insensitive).
    fn name(&self) -> &str {
        "my_min"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always returns `Float64`, regardless of argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        // Reject IGNORE NULLS / ORDER BY, which this aggregate does not handle
        acc_args.check_ignore_nulls(self.name())?;
        acc_args.check_order_by(self.name())?;

        // Delegate the actual aggregation to the built-in MinAccumulator
        let accumulator = MinAccumulator::try_new(&DataType::Float64)?;
        Ok(Box::new(accumulator))
    }
}

#[tokio::test]
async fn deregister_udaf() -> Result<()> {
let ctx = SessionContext::new();
Expand Down Expand Up @@ -526,7 +636,6 @@ impl Accumulator for TimeSum {
let arr = arr.as_primitive::<TimestampNanosecondType>();

for v in arr.values().iter() {
println!("Adding {v}");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drive by cleanups

self.sum += v;
}
Ok(())
Expand All @@ -538,7 +647,6 @@ impl Accumulator for TimeSum {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
println!("Evaluating to {}", self.sum);
Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None))
}

Expand All @@ -558,7 +666,6 @@ impl Accumulator for TimeSum {
let arr = arr.as_primitive::<TimestampNanosecondType>();

for v in arr.values().iter() {
println!("Retracting {v}");
self.sum -= v;
}
Ok(())
Expand Down
55 changes: 47 additions & 8 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use crate::ColumnarValue;
use crate::{Accumulator, Expr, PartitionEvaluator};
use arrow::datatypes::{DataType, Schema};
use datafusion_common::Result;
use datafusion_common::{not_impl_err, Result};
use std::sync::Arc;

/// Scalar function
Expand All @@ -38,18 +38,39 @@ pub type ScalarFunctionImplementation =
pub type ReturnTypeFunction =
Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;

/// Arguments passed to create an accumulator
/// [`AccumulatorArgs`] contains information about how an aggregate
/// function was called, including the types of its arguments and any optional
/// ordering expressions.
pub struct AccumulatorArgs<'a> {
// default arguments
/// the return type of the function
/// The return type of the aggregate function.
pub data_type: &'a DataType,
/// the schema of the input arguments
/// The schema of the input arguments
pub schema: &'a Schema,
/// whether to ignore nulls
/// Whether to ignore nulls.
///
/// SQL allows the user to specify `IGNORE NULLS`, for example:
///
/// ```sql
/// SELECT FIRST_VALUE(column1) IGNORE NULLS FROM t;
/// ```
///
/// Aggregates that do not support this functionality should return a not
/// implemented error when `ignore_nulls` is true.
pub ignore_nulls: bool,

// ordering arguments
/// the expressions of `order by`, if no ordering is required, this will be an empty slice
/// The expressions in the `ORDER BY` clause passed to this aggregator.
///
/// SQL allows the user to specify the ordering of arguments to the
/// aggregate using an `ORDER BY`. For example:
///
/// ```sql
/// SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t;
/// ```
///
/// If no `ORDER BY` is specified, `sort_exprs` will be empty. Aggregates
/// that do not support this functionality may return a not implemented
/// error when the slice is non-empty, as ordering the arguments is an
/// expensive operation and is wasteful if the aggregate doesn't support it.
pub sort_exprs: &'a [Expr],
}

Expand All @@ -67,6 +88,24 @@ impl<'a> AccumulatorArgs<'a> {
sort_exprs,
}
}

/// Return a not yet implemented error if IGNORE NULLs is true
pub fn check_ignore_nulls(&self, name: &str) -> Result<()> {
if self.ignore_nulls {
Copy link
Contributor

@jayzhan211 jayzhan211 Apr 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check !self.ignore_nulls?

I think checkXXX should be added for the user if they think they need to enable it.
In this case, when the user chooses to enable ignore_nulls, they need to add the check. If ignore_nulls is false, it means they should fix their query to contains ignore nulls

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is confusing for the user to understand whether they need to check or not.
I think renaming it to disable_xxx would help, or change the logic and rename it to enable_xxx

Copy link
Contributor

@jayzhan211 jayzhan211 Apr 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is also a problem because if they forget to add check_ignore_nulls/check_order_by in the accumulator, they can still run the function successfully. This approach does not force the user to check their options because datafusion implements them, not the user.

To enforce they specified the options for their functions, I think we can add the checking function in AggregateUDFImpl. So the user needs to set the true/false for their options

Either

// We will check if the AccumulatorArgs meet the requirement or not.
fn options() -> AccumulatorArgs {
   AccumulatorArgs {
     ignore_nulls: true/ false,
     has_ordering: true / false
    }
}

or

// The same with this one, but separate each option
fn support_ignore_nulls() -> bool
fn support_ordering() -> bool

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check !self.ignore_nulls?

I don't think we can do this, because the ignore_nulls is true for the following queries

SELECT avg(x) FROM ...;
SELECT avg(x) RESPECT NULLS FROM ...;

In other words, it is the default even when the user doesn't explicitly specify the handling

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is confusing for the user to understand whether they need to check or not. I think rename it to disable_xxx help or change the logic an rename it to enable_xxx

I don't quite understand this suggestion -- are you suggesting rename check_ignore_nulls to disable_ignore_nulls?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is also a problem because if they forget to add check_ignore_nulls/check_order_by in the accumulator, they can still run the function successfully. This approach does not force the user to check their options because datafusion implements them, not the user.

That is a good point --- in fact it actually affects built in aggregates too today

select count(*) from (values (1), (null), (2));
+----------+
| COUNT(*) |
+----------+
| 3        |
+----------+
1 row in set. Query took 0.039 seconds.

❯ select count(*) IGNORE NULLS from (values (1), (null), (2));
+----------+
| COUNT(*) |
+----------+
| 3        |
+----------+
1 row in set. Query took 0.001 seconds.

I think this is a separate issue, and not made worse by this PR -- I filed #9924 to track. I suggest we work on improving it as a follow-on PR

Copy link
Contributor

@jayzhan211 jayzhan211 Apr 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is confusing for the user to understand whether they need to check or not. I think rename it to disable_xxx help or change the logic an rename it to enable_xxx

I don't quite understand this suggestion -- are you suggesting rename check_ignore_nulls to disable_ignore_nulls?

yes, that is what I suggest, so we know exactly whether it is disabled or not. But I think the comment here also helps, rename is not necessary

Copy link
Contributor

@jayzhan211 jayzhan211 Apr 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a separate issue, and not made worse by this PR -- I filed #9924 to track. I suggest we work on improving it as a follow-on PR

But I think if we implement these for UDFImpl,

fn support_ignore_nulls() -> bool
fn support_ordering() -> bool

we probably don't need the check_ignore_nulls, because we can check it for them!

create_aggregate_expr is the earliest place we know ignore_nulls and sort_exprs, and we can call fun.support_ignore_nulls() to check for them, so do ordering.

https://github.com/apache/arrow-datafusion/blob/daf182dc789230dbd9cf21ca2e975789213a5365/datafusion/physical-plan/src/udaf.rs#L38-L46

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good idea and I think it should be done in #9924

I will update this PR to remove the check_* functions and only update the docs

not_impl_err!("IGNORE NULLS not implemented for {name}")
} else {
Ok(())
}
}

/// Returns a not-yet-implemented error if an `ORDER BY` clause was
/// specified for this aggregate call.
pub fn check_order_by(&self, name: &str) -> Result<()> {
    // An empty sort_exprs slice means no ORDER BY was specified
    if self.sort_exprs.is_empty() {
        return Ok(());
    }
    not_impl_err!("ORDER BY not implemented for {name}")
}
}

/// Factory that returns an accumulator for the given aggregate function.
Expand Down
19 changes: 16 additions & 3 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ where
/// See [`advanced_udaf.rs`] for a full example with complete implementation and
/// [`AggregateUDF`] for other available options.
///
///
/// [`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
///
/// # Basic Example
/// ```
/// # use std::any::Any;
Expand Down Expand Up @@ -247,7 +247,12 @@ where
/// Ok(DataType::Float64)
/// }
/// // This is the accumulator factory; DataFusion uses it to create new accumulators.
/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
///     // Error if IGNORE NULLs and ORDER BY are specified in the query
///     // _acc_args.check_ignore_nulls(self.name())?;
///     // _acc_args.check_order_by(self.name())?;
///     unimplemented!()
/// }
/// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec<Field>) -> Result<Vec<Field>> {
/// Ok(vec![
/// Field::new("value", value_type, true),
Expand Down Expand Up @@ -280,7 +285,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// Return a new [`Accumulator`] that aggregates values for a specific
/// group during query execution.
///
/// `acc_args`: the arguments to the accumulator. See [`AccumulatorArgs`] for more details.
/// acc_args: [`AccumulatorArgs`] contains information about how the
/// aggregate function was called.
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;

/// Return the fields of the intermediate state.
Expand Down Expand Up @@ -308,6 +314,13 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// If the aggregate expression has a specialized
/// [`GroupsAccumulator`] implementation. If this returns true,
/// `[Self::create_groups_accumulator]` will be called.
///
/// # Notes
///
/// Even if this function returns true, DataFusion will call
/// `Self::accumulator` for certain queries, such as when this aggregate is
/// used as a window function or when there are only aggregates (no GROUP BY
/// columns) in the plan.
fn groups_accumulator_supported(&self) -> bool {
false
}
Expand Down
Loading