Skip to content

Commit

Permalink
Support convert_to_state for AVG accumulator
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Aug 5, 2024
1 parent 0417e54 commit f05c4cd
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 1 deletion.
30 changes: 29 additions & 1 deletion datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
use arrow::array::{
self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType,
AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
};

use arrow::compute::sum;
use arrow::datatypes::{
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
Expand All @@ -35,6 +36,9 @@ use datafusion_expr::{
Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
};
use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState;
use datafusion_physical_expr_common::aggregate::groups_accumulator::nulls::{
filtered_null_mask, set_nulls,
};
use datafusion_physical_expr_common::aggregate::utils::DecimalAverager;
use log::debug;
use std::any::Any;
Expand Down Expand Up @@ -547,6 +551,30 @@ where
Ok(())
}

fn convert_to_state(
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
let sums = values[0]
.as_primitive::<T>()
.clone()
.with_data_type(self.sum_data_type.clone());
let counts = UInt64Array::from_value(1, sums.len());

let nulls = filtered_null_mask(opt_filter, &sums);

// set nulls on the arrays
let counts = set_nulls(counts, nulls.clone());
let sums = set_nulls(sums, nulls);

Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
}

fn supports_convert_to_state(&self) -> bool {
true
}

fn size(&self) -> usize {
self.counts.capacity() * std::mem::size_of::<u64>()
+ self.sums.capacity() * std::mem::size_of::<T>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@
pub mod accumulate;
pub mod bool_op;
pub mod nulls;
pub mod prim_op;
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! XX utlities for working with nulls
use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray};
use arrow::buffer::NullBuffer;

/// Sets the validity mask for a `PrimitiveArray` to `nulls`
/// replacing any existing null mask
pub fn set_nulls<T: ArrowNumericType + Send>(
array: PrimitiveArray<T>,
nulls: Option<NullBuffer>,
) -> PrimitiveArray<T> {
let (dt, values, _old_nulls) = array.into_parts();
PrimitiveArray::<T>::new(values, nulls).with_data_type(dt)
}

/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer.
///
/// The `NullBuffer` is
/// * `true` (representing valid) for values that were `true` in filter
/// * `false` (representing null) for values that were `false` or `null` in filter
fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
let (filter_bools, filter_nulls) = filter.clone().into_parts();
let filter_bools = NullBuffer::from(filter_bools);
NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref())
}

/// Compute an output validity mask for an array that has been filtered
///
/// This can be used to compute nulls for the output of
/// [`GroupsAccumulator::convert_to_state`], which quickly applies an optional
/// filter to the input rows by setting any filtered rows to NULL in the output.
/// Subsequent applications of aggregate functions that ignore NULLs (most of
/// them) will thus ignore the filtered rows as well.
///
/// # Output element is `true`
/// * A `true` in the output represents non null output for all values that were both:
///
/// * `true` in any `opt_filter` (aka values that passed the filter)
///
/// * `non null` in `input`
///
/// # Output element is `false`
/// * is false (null) for all values that were false in the filter or null in the input
///
/// # Example
///
/// ```text
/// ┌─────┐ ┌─────┐ ┌─────┐
/// │true │ │NULL │ │NULL │
/// │true │ │ │true │ │true │
/// │true │ ───┼─── │false│ ────────▶ │false│ filtered_nulls
/// │false│ │ │NULL │ │NULL │
/// │false│ │true │ │true │
/// └─────┘ └─────┘ └─────┘
/// array opt_filter output nulls
/// .nulls()
///
/// false = NULL true = pass false = NULL Meanings
/// true = valid false = filter true = valid
/// NULL = filter
/// ```
///
/// [`GroupsAccumulator::convert_to_state`]: datafusion_expr::groups_accumulator::GroupsAccumulator::convert_to_state
pub fn filtered_null_mask(
opt_filter: Option<&BooleanArray>,
input: &dyn Array,
) -> Option<NullBuffer> {
let opt_filter = opt_filter.and_then(filter_to_nulls);
NullBuffer::union(opt_filter.as_ref(), input.nulls())
}
29 changes: 29 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,21 @@ SELECT c2, sum(c3), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
4 29 9.531112968922
5 -194 7.074412226677

# Test avg for tinyint / float
query TRR
SELECT
c1,
avg(c2),
avg(c11)
FROM aggregate_test_100 GROUP BY c1 ORDER BY c1;
----
a 2.857142857143 0.438223421574
b 3.263157894737 0.496481208425
c 2.666666666667 0.425241138254
d 2.444444444444 0.541519476308
e 3 0.505440263521


# Enabling PG dialect for filtered aggregates tests
statement ok
set datafusion.sql_parser.dialect = 'Postgres';
Expand Down Expand Up @@ -267,6 +282,20 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
4 11 14
5 8 7

# Test avg for tinyint / float
query TRR
SELECT
c1,
avg(c2) FILTER (WHERE c2 != 5),
avg(c11) FILTER (WHERE c2 != 5)
FROM aggregate_test_100 GROUP BY c1 ORDER BY c1;
----
a 2.5 0.449071887467
b 2.642857142857 0.445486298629
c 2.421052631579 0.422882117723
d 2.125 0.518706191331
e 2.789473684211 0.536785323369

# Test count with nullable fields and nullable filter
query III
SELECT c2,
Expand Down

0 comments on commit f05c4cd

Please sign in to comment.