From f3bedc004e96c331ef7a1e073276aaa3a9b26d19 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 5 Aug 2024 09:40:10 -0400 Subject: [PATCH 1/5] Support `convert_to_state` for `AVG` accumulator --- datafusion/functions-aggregate/src/average.rs | 30 ++++++- .../src/aggregate/groups_accumulator/mod.rs | 1 + .../src/aggregate/groups_accumulator/nulls.rs | 87 +++++++++++++++++++ .../test_files/aggregate_skip_partial.slt | 29 +++++++ 4 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 288e0b09f809..a671a9a3d81a 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -19,8 +19,9 @@ use arrow::array::{ self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, - AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, + AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, }; + use arrow::compute::sum; use arrow::datatypes::{ i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, @@ -35,6 +36,9 @@ use datafusion_expr::{ Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, }; use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; +use datafusion_physical_expr_common::aggregate::groups_accumulator::nulls::{ + filtered_null_mask, set_nulls, +}; use datafusion_physical_expr_common::aggregate::utils::DecimalAverager; use log::debug; use std::any::Any; @@ -547,6 +551,30 @@ where Ok(()) } + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let sums = values[0] + .as_primitive::() + .clone() + .with_data_type(self.sum_data_type.clone()); + let counts = UInt64Array::from_value(1, sums.len()); + + let nulls = filtered_null_mask(opt_filter, &sums); + + // set nulls on the arrays + let counts = set_nulls(counts, nulls.clone()); + let sums = set_nulls(sums, nulls); + + Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + fn size(&self) -> usize { self.counts.capacity() * std::mem::size_of::() + self.sums.capacity() * std::mem::size_of::() diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/mod.rs index 5b0182c5db8a..220f142927b3 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/mod.rs @@ -19,4 +19,5 @@ pub mod accumulate; pub mod bool_op; +pub mod nulls; pub mod prim_op; diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs new file mode 100644 index 000000000000..d4ca238ac9d8 --- /dev/null +++ b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! XX utlities for working with nulls + +use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray}; +use arrow::buffer::NullBuffer; + +/// Sets the validity mask for a `PrimitiveArray` to `nulls` +/// replacing any existing null mask +pub fn set_nulls( + array: PrimitiveArray, + nulls: Option, +) -> PrimitiveArray { + let (dt, values, _old_nulls) = array.into_parts(); + PrimitiveArray::::new(values, nulls).with_data_type(dt) +} + +/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer. +/// +/// The `NullBuffer` is +/// * `true` (representing valid) for values that were `true` in filter +/// * `false` (representing null) for values that were `false` or `null` in filter +fn filter_to_nulls(filter: &BooleanArray) -> Option { + let (filter_bools, filter_nulls) = filter.clone().into_parts(); + let filter_bools = NullBuffer::from(filter_bools); + NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref()) +} + +/// Compute an output validity mask for an array that has been filtered +/// +/// This can be used to compute nulls for the output of +/// [`GroupsAccumulator::convert_to_state`], which quickly applies an optional +/// filter to the input rows by setting any filtered rows to NULL in the output. +/// Subsequent applications of aggregate functions that ignore NULLs (most of +/// them) will thus ignore the filtered rows as well. +/// +/// # Output element is `true` +/// * A `true` in the output represents non null output for all values that were both: +/// +/// * `true` in any `opt_filter` (aka values that passed the filter) +/// +/// * `non null` in `input` +/// +/// # Output element is `false` +/// * is false (null) for all values that were false in the filter or null in the input +/// +/// # Example +/// +/// ```text +/// ┌─────┐ ┌─────┐ ┌─────┐ +/// │true │ │NULL │ │NULL │ +/// │true │ │ │true │ │true │ +/// │true │ ───┼─── │false│ ────────▶ │false│ filtered_nulls +/// │false│ │ │NULL │ │NULL │ +/// │false│ │true │ │true │ +/// └─────┘ └─────┘ └─────┘ +/// array opt_filter output nulls +/// .nulls() +/// +/// false = NULL true = pass false = NULL Meanings +/// true = valid false = filter true = valid +/// NULL = filter +/// ``` +/// +/// [`GroupsAccumulator::convert_to_state`]: datafusion_expr::groups_accumulator::GroupsAccumulator::convert_to_state +pub fn filtered_null_mask( + opt_filter: Option<&BooleanArray>, + input: &dyn Array, +) -> Option { + let opt_filter = opt_filter.and_then(filter_to_nulls); + NullBuffer::union(opt_filter.as_ref(), input.nulls()) +} diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt index 6c0cf5f800d8..ba378f4230f8 100644 --- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt +++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt @@ -209,6 +209,21 @@ SELECT c2, sum(c3), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; 4 29 9.531112968922 5 -194 7.074412226677 +# Test avg for tinyint / float +query TRR +SELECT + c1, + avg(c2), + avg(c11) +FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +a 2.857142857143 0.438223421574 +b 3.263157894737 0.496481208425 +c 2.666666666667 0.425241138254 +d 2.444444444444 0.541519476308 +e 3 0.505440263521 + + # Enabling PG dialect for filtered aggregates tests statement ok set datafusion.sql_parser.dialect = 'Postgres'; @@ -267,6 +282,20 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; 4 11 14 5 8 7 +# Test avg for tinyint / float +query TRR +SELECT + c1, + avg(c2) FILTER (WHERE c2 != 5), + avg(c11) FILTER (WHERE c2 != 5) +FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +a 2.5 0.449071887467 +b 2.642857142857 0.445486298629 +c 2.421052631579 0.422882117723 +d 2.125 0.518706191331 +e 2.789473684211 0.536785323369 + # Test count with nullable fields and nullable filter query III SELECT c2, From 92decac8d33a9763ce046519acd1cdb9f95433b2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 11 Aug 2024 10:43:07 -0400 Subject: [PATCH 2/5] Update datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs --- .../src/aggregate/groups_accumulator/nulls.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs index d4ca238ac9d8..e9a434b231bc 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! XX utlities for working with nulls +//! [`set_nulls`], and [`filtered_null_mask`], utilities for working with nulls use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray}; use arrow::buffer::NullBuffer; From 149406bb6f45a79112b844d8f69f5c0937a3a6d7 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 11 Aug 2024 10:50:40 -0400 Subject: [PATCH 3/5] fix documentation --- .../src/aggregate/groups_accumulator/nulls.rs | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs index e9a434b231bc..5500eb2b65c2 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs @@ -49,27 +49,33 @@ fn filter_to_nulls(filter: &BooleanArray) -> Option { /// Subsequent applications of aggregate functions that ignore NULLs (most of /// them) will thus ignore the filtered rows as well. /// -/// # Output element is `true` -/// * A `true` in the output represents non null output for all values that were both: +/// # Output element is `true` (and thus output is non-null) +/// +/// A `true` in the output represents non null output for all values that were *both*: /// /// * `true` in any `opt_filter` (aka values that passed the filter) /// /// * `non null` in `input` /// -/// # Output element is `false` -/// * is false (null) for all values that were false in the filter or null in the input +/// # Output element is `false` (and thus output is null) +/// +/// A `false` in the output represents an input that was *either*: +/// +/// * `null` +/// +/// * filtered (aka the value was `false` or `null` in the filter) /// /// # Example /// /// ```text /// ┌─────┐ ┌─────┐ ┌─────┐ -/// │true │ │NULL │ │NULL │ +/// │true │ │NULL │ │false│ /// │true │ │ │true │ │true │ /// │true │ ───┼─── │false│ ────────▶ │false│ filtered_nulls -/// │false│ │ │NULL │ │NULL │ -/// │false│ │true │ │true │ +/// │false│ │ │NULL │ │false│ +/// │false│ │true │ │false│ /// └─────┘ └─────┘ └─────┘ -/// array opt_filter output nulls +/// array opt_filter output /// .nulls() /// /// false = NULL true = pass false = NULL Meanings From 6186171a791756055ad2a5a913c3d4483edb691f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 11 Aug 2024 11:00:18 -0400 Subject: [PATCH 4/5] Fix after merge --- .../src/aggregate/groups_accumulator.rs | 1 + datafusion/functions-aggregate/src/average.rs | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index 644221edd04d..3984b02c5fbb 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -20,6 +20,7 @@ pub mod accumulate; pub mod bool_op; +pub mod nulls; pub mod prim_op; use arrow::{ diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 80fe3fb03152..ddad76a8734b 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -37,6 +37,10 @@ use datafusion_expr::{ }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ + filtered_null_mask, set_nulls, +}; + use datafusion_functions_aggregate_common::utils::DecimalAverager; use log::debug; use std::any::Any; From 8d4ea5b7a50ad7dc543a7e3e427226de6fdb7254 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 12 Aug 2024 13:02:41 -0400 Subject: [PATCH 5/5] fix for change in location --- .../src/aggregate/groups_accumulator/nulls.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 5500eb2b65c2..25212f7f0f5f 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -83,7 +83,7 @@ fn filter_to_nulls(filter: &BooleanArray) -> Option { /// NULL = filter /// ``` /// -/// [`GroupsAccumulator::convert_to_state`]: datafusion_expr::groups_accumulator::GroupsAccumulator::convert_to_state +/// [`GroupsAccumulator::convert_to_state`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator pub fn filtered_null_mask( opt_filter: Option<&BooleanArray>, input: &dyn Array,