Skip to content

Commit

Permalink
Support convert_to_state for AVG accumulator
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 31, 2024
1 parent 0d994a6 commit 7020dcf
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
62 changes: 61 additions & 1 deletion datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
use arrow::array::{
self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType,
AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
};
use arrow::buffer::NullBuffer;
use arrow::compute::sum;
use arrow::datatypes::{
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
Expand Down Expand Up @@ -554,8 +555,67 @@ where
Ok(())
}

fn convert_to_state(
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
let sums = values[0]
.as_primitive::<T>()
.clone()
.with_data_type(self.sum_data_type.clone());
let counts = UInt64Array::from_value(1, sums.len());

let nulls = filtered_null_mask(opt_filter, &sums);

// set nulls on the arrays
let counts = set_nulls(counts, nulls.clone());
let sums = set_nulls(sums, nulls);

Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
}

fn convert_to_state_supported(&self) -> bool {
true
}

fn size(&self) -> usize {
self.counts.capacity() * std::mem::size_of::<u64>()
+ self.sums.capacity() * std::mem::size_of::<T>()
}
}

/// Sets the null mask for the specified primitive array to be `nulls`
/// replacing any existing null mask
fn set_nulls<T: ArrowNumericType + Send>(
array: PrimitiveArray<T>,
nulls: Option<NullBuffer>,
) -> PrimitiveArray<T> {
// replace null values
let (dt, values, _old_nulls) = array.into_parts();
PrimitiveArray::<T>::new(values, nulls).with_data_type(dt)
}

/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer`
/// where the NullBuffer is true for all values that were true
/// in the filter and `null` for any values that were false or null
fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
let (filter_bools, filter_nulls) = filter.clone().into_parts();
// Only keep values where the filter was true
// convert all false to null
let filter_bools = NullBuffer::from(filter_bools);
NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref())
}

/// Compute the final null mask for an array
///
/// The output null mask :
/// * is true (non null) for all values that were true in the filter and non null in the input
/// * is false (null) for all values that were false in the filter or null in the input
fn filtered_null_mask(
opt_filter: Option<&BooleanArray>,
input: &dyn Array,
) -> Option<NullBuffer> {
let opt_filter = opt_filter.and_then(filter_to_nulls);
NullBuffer::union(opt_filter.as_ref(), input.nulls())
}
28 changes: 28 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,20 @@ SELECT c2, sum(c5), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
4 16155718643 9.531112968922
5 6449337880 7.074412226677

# Test avg for bigint / float
query IRR
SELECT
c2,
avg(c10),
avg(c11)
FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
----
1 9803675241365398000 0.552420626987
2 6843194947657418000 0.435355657881
3 10700987547561746000 0.504783755855
4 7199224282513318000 0.41439621604
5 9295051061697067000 0.505315159048

# Enabling PG dialect for filtered aggregates tests
statement ok
set datafusion.sql_parser.dialect = 'Postgres';
Expand Down Expand Up @@ -146,6 +160,20 @@ FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
4 417
5 284

# Test avg for bigint / float with filter
query IRR
SELECT
c2,
avg(c10) FILTER (WHERE c2 != 'e'),
avg(c11) FILTER (WHERE c2 != 'e')
FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
----
1 9803675241365398000 0.552420626987
2 6843194947657418000 0.435355657881
3 10700987547561746000 0.504783755855
4 7199224282513318000 0.41439621604
5 9295051061697067000 0.505315159048

# Test count with nullable fields
query III
SELECT c2, count(c3), count(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
Expand Down

0 comments on commit 7020dcf

Please sign in to comment.