Skip to content

Commit

Permalink
Support convert_to_state for AVG accumulator
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 31, 2024
1 parent 0d994a6 commit a8b5a05
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
46 changes: 45 additions & 1 deletion datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
use arrow::array::{
self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType,
AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
};
use arrow::buffer::NullBuffer;
use arrow::compute::sum;
use arrow::datatypes::{
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
Expand Down Expand Up @@ -554,8 +555,51 @@ where
Ok(())
}

fn convert_to_state(
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
let sums = values[0].as_primitive::<T>();
let counts = Arc::new(UInt64Array::from_value(1, sums.len()));

let nulls = filtered_null_mask(opt_filter, sums);
let sums = PrimitiveArray::<T>::new(sums.values().clone(), nulls)
.with_data_type(self.sum_data_type.clone());

Ok(vec![counts, Arc::new(sums)])
}

fn convert_to_state_supported(&self) -> bool {
true
}

fn size(&self) -> usize {
self.counts.capacity() * std::mem::size_of::<u64>()
+ self.sums.capacity() * std::mem::size_of::<T>()
}
}

/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer`
/// where the NullBuffer is true for all values that were true
/// in the filter and `null` for any values that were false or null
fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
let (filter_bools, filter_nulls) = filter.clone().into_parts();
// Only keep values where the filter was true
// convert all false to null
let filter_bools = NullBuffer::from(filter_bools);
NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref())
}

/// Compute the final null mask for an array
///
/// The output null mask :
/// * is true (non null) for all values that were true in the filter and non null in the input
/// * is false (null) for all values that were false in the filter or null in the input
fn filtered_null_mask(
opt_filter: Option<&BooleanArray>,
input: &dyn Array,
) -> Option<NullBuffer> {
let opt_filter = opt_filter.and_then(filter_to_nulls);
NullBuffer::union(opt_filter.as_ref(), input.nulls())
}
10 changes: 10 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ SELECT c2, sum(c5), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
4 16155718643 9.531112968922
5 6449337880 7.074412226677

# Test avg for bigint / float
query RR
select avg(c10), avg(c11) from aggregate_test_100 GROUP BY c2 ORDER BY c2;
----
9803675241365404000 0.552420626987
6843194947657417000 0.435355657881
10700987547561746000 0.504783755855
7199224282513317000 0.41439621604
9295051061697067000 0.505315159048

# Enabling PG dialect for filtered aggregates tests
statement ok
set datafusion.sql_parser.dialect = 'Postgres';
Expand Down

0 comments on commit a8b5a05

Please sign in to comment.