Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for AggregateUDFImpl with ORDER BY and IGNORE NULLS #9953

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 118 additions & 12 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
use arrow_schema::Schema;
use arrow_schema::{Schema, SchemaRef};
use std::any::Any;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
Expand Down Expand Up @@ -48,7 +49,7 @@ use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;
use datafusion_physical_expr::expressions::{AvgAccumulator, MinAccumulator};

/// Test to show the contents of the setup
#[tokio::test]
Expand Down Expand Up @@ -210,25 +211,33 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
}

/// tests the creation, registration and usage of a UDAF
#[tokio::test]
async fn simple_udaf() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
/// Returns a [`SessionContext`] with a basic table "t" registered.
///
/// Table "t" has a single `Int32` column named "a" and is backed by a
/// `MemTable` built from two record batches holding `[1, 2, 3]` and `[4, 5]`.
/// Each batch is passed in its own inner `Vec` (presumably one per
/// partition -- TODO confirm against `MemTable::try_new`).
///
/// # Errors
///
/// Propagates (via `?`) any error from creating the record batches, creating
/// the `MemTable`, or registering the table.
fn simple_udf_context() -> Result<SessionContext> {
let schema: SchemaRef =
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

// NOTE(review): this listing looks like a rendered diff, so the schema
// argument below appears twice (`Arc::new(schema.clone())` is presumably the
// removed line, `schema.clone()` the added one) -- confirm against the file.
let batch1 = RecordBatch::try_new(
Arc::new(schema.clone()),
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)?;
// NOTE(review): same removed/added pair as for `batch1` -- confirm.
let batch2 = RecordBatch::try_new(
Arc::new(schema.clone()),
schema.clone(),
vec![Arc::new(Int32Array::from(vec![4, 5]))],
)?;

let ctx = SessionContext::new();

// NOTE(review): the two `provider` lines are presumably the removed/added
// pair from the diff; `schema` is already a `SchemaRef`, so only the second
// form should remain -- confirm.
let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?;
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
// Register the in-memory table under the name the test queries use
ctx.register_table("t", Arc::new(provider))?;

Ok(ctx)
}

/// tests the creation, registration and usage of a UDAF
#[tokio::test]
async fn simple_udaf() -> Result<()> {
let ctx = simple_udf_context()?;

// define a udaf, using a DataFusion's accumulator
let my_avg = create_udaf(
"my_avg",
Expand All @@ -255,6 +264,106 @@ async fn simple_udaf() -> Result<()> {
Ok(())
}

/// Tests the creation, registration and usage of an [`AggregateUDFImpl`]
/// based aggregate (`MyMin`) against the basic table "t".
#[tokio::test]
async fn simple_udaf_trait() -> Result<()> {
    let ctx = simple_udf_context()?;

    // Register a UDAF that is implemented via the `AggregateUDFImpl` trait
    ctx.register_udaf(AggregateUDF::from(MyMin::new()));

    let result = ctx.sql("SELECT MY_MIN(a) FROM t").await?.collect().await?;

    // The expected lines are compared against the pretty-printed batches, so
    // the value row must be padded to the same width as the header row.
    let expected = [
        "+-------------+",
        "| my_min(t.a) |",
        "+-------------+",
        "| 1.0         |",
        "+-------------+",
    ];
    assert_batches_eq!(expected, &result);

    Ok(())
}

/// Tests the `IGNORE NULLS`, `RESPECT NULLS` and `ORDER BY` clauses on an
/// `AggregateUDFImpl` based aggregate (`MyMin`).
///
/// Expects `IGNORE NULLS` and `ORDER BY` to fail with a "not implemented"
/// error message and `RESPECT NULLS` to run successfully.
#[tokio::test]
async fn simple_udaf_trait_ignore_nulls() -> Result<()> {
let ctx = simple_udf_context()?;
ctx.register_udaf(AggregateUDF::from(MyMin::new()));

/// Runs a query that is expected to fail in `collect` and returns the
/// error message with any backtrace stripped.
///
/// Panics if `ctx.sql` itself fails or if the query succeeds.
async fn run_err(ctx: &SessionContext, sql: &str) -> String {
ctx.sql(sql)
.await
.unwrap()
.collect()
.await
.unwrap_err()
.strip_backtrace()
.to_string()
}

/// Runs a query that should succeed, discarding its results.
///
/// Panics if `ctx.sql` or `collect` returns an error.
async fn run_ok(ctx: &SessionContext, sql: &str) {
ctx.sql(sql).await.unwrap().collect().await.unwrap();
}

// IGNORE NULLS can be passed to the UDAF, which should error if it isn't supported
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here are the tests

assert_eq!(
run_err(&ctx, "SELECT MY_MIN(a) IGNORE NULLS FROM t").await,
"This feature is not implemented: IGNORE NULLS not implemented for my_min"
);
// RESPECT NULLS should work (the default)
run_ok(&ctx, "SELECT MY_MIN(a) RESPECT NULLS FROM t").await;

// ORDER BY can be passed to the UDAF as well, which should error if it isn't supported
assert_eq!(
run_err(&ctx, "SELECT MY_MIN(a ORDER BY a) FROM t").await,
"This feature is not implemented: ORDER BY not implemented for my_min"
);

Ok(())
}

/// Minimal user defined aggregate used to exercise the `AggregateUDFImpl`
/// trait in the tests of this file.
#[derive(Debug)]
struct MyMin {
    /// Signature handed out by `AggregateUDFImpl::signature`
    signature: Signature,
}

impl MyMin {
    /// Creates a `MyMin` whose signature is an exact match on a single
    /// `Float64` argument with `Volatility::Immutable`.
    fn new() -> Self {
        let arg_types = vec![DataType::Float64];
        let signature = Signature::exact(arg_types, Volatility::Immutable);
        Self { signature }
    }
}
/// `AggregateUDFImpl` for `MyMin`: an aggregate named "my_min" that returns
/// `Float64` and accumulates with a `MinAccumulator`.
impl AggregateUDFImpl for MyMin {
    /// Exposes `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The function name: "my_min".
    fn name(&self) -> &str {
        "my_min"
    }

    /// The signature stored at construction time.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always `Float64`; the argument types are ignored.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    // TODO: error if IGNORE NULLS and ORDER BY are specified

    /// Builds a boxed `MinAccumulator` over `Float64`; the accumulator
    /// arguments are currently ignored.
    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let min_acc = MinAccumulator::try_new(&DataType::Float64)?;
        Ok(Box::new(min_acc))
    }
}

#[tokio::test]
async fn deregister_udaf() -> Result<()> {
let ctx = SessionContext::new();
Expand Down Expand Up @@ -526,7 +635,6 @@ impl Accumulator for TimeSum {
let arr = arr.as_primitive::<TimestampNanosecondType>();

for v in arr.values().iter() {
println!("Adding {v}");
self.sum += v;
}
Ok(())
Expand All @@ -538,7 +646,6 @@ impl Accumulator for TimeSum {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
println!("Evaluating to {}", self.sum);
Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None))
}

Expand All @@ -558,7 +665,6 @@ impl Accumulator for TimeSum {
let arr = arr.as_primitive::<TimestampNanosecondType>();

for v in arr.values().iter() {
println!("Retracting {v}");
self.sum -= v;
}
Ok(())
Expand Down
Loading