From a8b5a05a7a3d96e8cb81f9f8ce4d6274a973a5c2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 30 Jul 2024 17:46:48 -0400 Subject: [PATCH] Support `convert_to_state` for `AVG` accumulator --- datafusion/functions-aggregate/src/average.rs | 46 ++++++++++++++++++- .../test_files/aggregate_skip_partial.slt | 10 ++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 18642fb843293..e0d543f4a366f 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -19,8 +19,9 @@ use arrow::array::{ self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, - AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, + AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, }; +use arrow::buffer::NullBuffer; use arrow::compute::sum; use arrow::datatypes::{ i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, @@ -554,8 +555,51 @@ where Ok(()) } + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let sums = values[0].as_primitive::(); + let counts = Arc::new(UInt64Array::from_value(1, sums.len())); + + let nulls = filtered_null_mask(opt_filter, sums); + let sums = PrimitiveArray::::new(sums.values().clone(), nulls) + .with_data_type(self.sum_data_type.clone()); + + Ok(vec![counts, Arc::new(sums)]) + } + + fn convert_to_state_supported(&self) -> bool { + true + } + fn size(&self) -> usize { self.counts.capacity() * std::mem::size_of::() + self.sums.capacity() * std::mem::size_of::() } } + +/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer` +/// where the NullBuffer is true for all values that were true +/// in the filter and `null` for any values that were false or null +fn filter_to_nulls(filter: &BooleanArray) -> Option { + let (filter_bools, filter_nulls) = filter.clone().into_parts(); + // Only keep values where the filter was true + // convert all false to null + let filter_bools = NullBuffer::from(filter_bools); + NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref()) +} + +/// Compute the final null mask for an array +/// +/// The output null mask : +/// * is true (non null) for all values that were true in the filter and non null in the input +/// * is false (null) for all values that were false in the filter or null in the input +fn filtered_null_mask( + opt_filter: Option<&BooleanArray>, + input: &dyn Array, +) -> Option { + let opt_filter = opt_filter.and_then(filter_to_nulls); + NullBuffer::union(opt_filter.as_ref(), input.nulls()) +} diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt index 1d152bf477bc7..35f96637a1293 100644 --- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt +++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt @@ -101,6 +101,16 @@ SELECT c2, sum(c5), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; 4 16155718643 9.531112968922 5 6449337880 7.074412226677 +# Test avg for bigint / float +query RR +select avg(c10), avg(c11) from aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +9803675241365404000 0.552420626987 +6843194947657417000 0.435355657881 +10700987547561746000 0.504783755855 +7199224282513317000 0.41439621604 +9295051061697067000 0.505315159048 + # Enabling PG dialect for filtered aggregates tests statement ok set datafusion.sql_parser.dialect = 'Postgres';