From 00ef8204707b158c0e086506bd7d9d9dd3be5a6f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 12 Aug 2024 17:34:11 -0400 Subject: [PATCH] Support `convert_to_state` for `AVG` accumulator (#11734) * Support `convert_to_state` for `AVG` accumulator * Update datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs * fix documentation * Fix after merge * fix for change in location --- .../src/aggregate/groups_accumulator.rs | 1 + .../src/aggregate/groups_accumulator/nulls.rs | 93 +++++++++++++++++++ datafusion/functions-aggregate/src/average.rs | 32 ++++++- .../test_files/aggregate_skip_partial.slt | 29 ++++++ 4 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index 644221edd04d..3984b02c5fbb 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -20,6 +20,7 @@ pub mod accumulate; pub mod bool_op; +pub mod nulls; pub mod prim_op; use arrow::{ diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs new file mode 100644 index 000000000000..25212f7f0f5f --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`set_nulls`], and [`filtered_null_mask`], utilities for working with nulls + +use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray}; +use arrow::buffer::NullBuffer; + +/// Sets the validity mask for a `PrimitiveArray` to `nulls` +/// replacing any existing null mask +pub fn set_nulls( + array: PrimitiveArray, + nulls: Option, +) -> PrimitiveArray { + let (dt, values, _old_nulls) = array.into_parts(); + PrimitiveArray::::new(values, nulls).with_data_type(dt) +} + +/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer. +/// +/// The `NullBuffer` is +/// * `true` (representing valid) for values that were `true` in filter +/// * `false` (representing null) for values that were `false` or `null` in filter +fn filter_to_nulls(filter: &BooleanArray) -> Option { + let (filter_bools, filter_nulls) = filter.clone().into_parts(); + let filter_bools = NullBuffer::from(filter_bools); + NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref()) +} + +/// Compute an output validity mask for an array that has been filtered +/// +/// This can be used to compute nulls for the output of +/// [`GroupsAccumulator::convert_to_state`], which quickly applies an optional +/// filter to the input rows by setting any filtered rows to NULL in the output. +/// Subsequent applications of aggregate functions that ignore NULLs (most of +/// them) will thus ignore the filtered rows as well. +/// +/// # Output element is `true` (and thus output is non-null) +/// +/// A `true` in the output represents non null output for all values that were *both*: +/// +/// * `true` in any `opt_filter` (aka values that passed the filter) +/// +/// * `non null` in `input` +/// +/// # Output element is `false` (and thus output is null) +/// +/// A `false` in the output represents an input that was *either*: +/// +/// * `null` +/// +/// * filtered (aka the value was `false` or `null` in the filter) +/// +/// # Example +/// +/// ```text +/// ┌─────┐ ┌─────┐ ┌─────┐ +/// │true │ │NULL │ │false│ +/// │true │ │ │true │ │true │ +/// │true │ ───┼─── │false│ ────────▶ │false│ filtered_nulls +/// │false│ │ │NULL │ │false│ +/// │false│ │true │ │false│ +/// └─────┘ └─────┘ └─────┘ +/// array opt_filter output +/// .nulls() +/// +/// false = NULL true = pass false = NULL Meanings +/// true = valid false = filter true = valid +/// NULL = filter +/// ``` +/// +/// [`GroupsAccumulator::convert_to_state`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator +pub fn filtered_null_mask( + opt_filter: Option<&BooleanArray>, + input: &dyn Array, +) -> Option { + let opt_filter = opt_filter.and_then(filter_to_nulls); + NullBuffer::union(opt_filter.as_ref(), input.nulls()) +} diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 1be3cd6b0714..ddad76a8734b 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -19,8 +19,9 @@ use arrow::array::{ self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, - AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, + AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, }; + use arrow::compute::sum; use arrow::datatypes::{ i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, @@ -34,7 +35,12 @@ use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, }; + use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ + filtered_null_mask, set_nulls, +}; + use datafusion_functions_aggregate_common::utils::DecimalAverager; use log::debug; use std::any::Any; @@ -551,6 +557,30 @@ where Ok(()) } + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let sums = values[0] + .as_primitive::() + .clone() + .with_data_type(self.sum_data_type.clone()); + let counts = UInt64Array::from_value(1, sums.len()); + + let nulls = filtered_null_mask(opt_filter, &sums); + + // set nulls on the arrays + let counts = set_nulls(counts, nulls.clone()); + let sums = set_nulls(sums, nulls); + + Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + fn size(&self) -> usize { self.counts.capacity() * std::mem::size_of::() + self.sums.capacity() * std::mem::size_of::() diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt index 6c0cf5f800d8..ba378f4230f8 100644 --- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt +++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt @@ -209,6 +209,21 @@ SELECT c2, sum(c3), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; 4 29 9.531112968922 5 -194 7.074412226677 +# Test avg for tinyint / float +query TRR +SELECT + c1, + avg(c2), + avg(c11) +FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +a 2.857142857143 0.438223421574 +b 3.263157894737 0.496481208425 +c 2.666666666667 0.425241138254 +d 2.444444444444 0.541519476308 +e 3 0.505440263521 + + # Enabling PG dialect for filtered aggregates tests statement ok set datafusion.sql_parser.dialect = 'Postgres'; @@ -267,6 +282,20 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; 4 11 14 5 8 7 +# Test avg for tinyint / float +query TRR +SELECT + c1, + avg(c2) FILTER (WHERE c2 != 5), + avg(c11) FILTER (WHERE c2 != 5) +FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +a 2.5 0.449071887467 +b 2.642857142857 0.445486298629 +c 2.421052631579 0.422882117723 +d 2.125 0.518706191331 +e 2.789473684211 0.536785323369 + # Test count with nullable fields and nullable filter query III SELECT c2,