diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs index 8bbcf756c37c..80b558fb8550 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs @@ -15,17 +15,13 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}; -use arrow::buffer::NullBuffer; -use arrow::compute; -use arrow::datatypes::ArrowPrimitiveType; +use crate::aggregate::groups_accumulator::accumulate::NullState; +use crate::aggregate::groups_accumulator::nulls::{filtered_null_mask, set_nulls}; +use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, PrimitiveArray}; use arrow::datatypes::DataType; -use datafusion_common::{internal_datafusion_err, DataFusionError, Result}; +use datafusion_common::Result; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; - -use super::accumulate::NullState; +use std::sync::Arc; /// An accumulator that implements a single operation over /// [`ArrowPrimitiveType`] where the accumulated state is the same as @@ -147,44 +143,31 @@ where values: &[ArrayRef], opt_filter: Option<&BooleanArray>, ) -> Result> { - let values = values[0].as_primitive::().clone(); - - // Initializing state with starting values - let initial_state = - PrimitiveArray::::from_value(self.starting_value, values.len()); - - // Recalculating values in case there is filter - let values = match opt_filter { - None => values, - Some(filter) => { - let (filter_values, filter_nulls) = filter.clone().into_parts(); - // Calculating filter mask as a result of bitand of filter, and converting it to null buffer - let filter_bool = match filter_nulls { - Some(filter_nulls) => filter_nulls.inner() & &filter_values, - None => filter_values, - }; - let filter_nulls = NullBuffer::from(filter_bool); - - // Rebuilding input values with a new nulls mask, which is equal to - // the union of original nulls and filter mask - let (dt, values_buf, original_nulls) = values.into_parts(); - let nulls_buf = - NullBuffer::union(original_nulls.as_ref(), Some(&filter_nulls)); - PrimitiveArray::::new(values_buf, nulls_buf).with_data_type(dt) + let values = values[0].as_primitive::(); + + // Figure out which values will be non null in the output + let nulls = filtered_null_mask(opt_filter, values); + + // Initializing state with starting value + let mut state = vec![self.starting_value; values.len()]; + + // update state with any non-filtered input + if nulls.is_some() { + // mask out any filtered / null input values + let values = set_nulls(values.clone(), nulls.clone()); + for (state, value) in state.iter_mut().zip(values.iter()) { + if let Some(value) = value { + (self.prim_fn)(state, value); + } + } + } else { + // no nulls in input, so iterate over all values + let all_values = values.values().iter(); + for (state, value) in state.iter_mut().zip(all_values) { + (self.prim_fn)(state, *value) } }; - - let state_values = compute::binary_mut(initial_state, &values, |mut x, y| { - (self.prim_fn)(&mut x, y); - x - }); - let state_values = state_values - .map_err(|_| { - internal_datafusion_err!( - "initial_values underlying buffer must not be shared" - ) - })? - .map_err(DataFusionError::from)? + let state_values = PrimitiveArray::::new(state.into(), nulls) .with_data_type(self.data_type.clone()); Ok(vec![Arc::new(state_values)]) diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 417e28e72a71..8281449d29cd 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -16,7 +16,10 @@ // under the License. use ahash::RandomState; -use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; +use datafusion_functions_aggregate_common::aggregate::count_distinct::{ + BytesDistinctCountAccumulator, BytesViewDistinctCountAccumulator, + FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, +}; use std::collections::HashSet; use std::ops::BitAnd; use std::{fmt::Debug, sync::Arc}; @@ -47,11 +50,8 @@ use datafusion_expr::{ EmitTo, GroupsAccumulator, Signature, Volatility, }; use datafusion_expr::{Expr, ReversedUDAF, TypeSignature}; -use datafusion_functions_aggregate_common::aggregate::count_distinct::{ - BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, - PrimitiveDistinctCountAccumulator, -}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; use datafusion_physical_expr_common::binary_map::OutputType; make_udaf_expr_and_func!( @@ -450,59 +450,29 @@ impl GroupsAccumulator for CountGroupsAccumulator { /// Converts an input batch directly to a state batch /// /// The state of `COUNT` is always a single Int64Array: - /// * `1` (for non-null, non filtered values) - /// * `0` (for null values) + /// * `1` (for non null, non filtered values) + /// * `0` (for filtered or null values) fn convert_to_state( &self, values: &[ArrayRef], opt_filter: Option<&BooleanArray>, ) -> Result> { let values = &values[0]; - - let state_array = match (values.logical_nulls(), opt_filter) { - (None, None) => { - // In case there is no nulls in input and no filter, returning array of 1 - Arc::new(Int64Array::from_value(1, values.len())) - } - (Some(nulls), None) => { - // If there are any nulls in input values -- casting `nulls` (true for values, false for nulls) - // of input array to Int64 - let nulls = BooleanArray::new(nulls.into_inner(), None); - compute::cast(&nulls, &DataType::Int64)? - } - (None, Some(filter)) => { - // If there is only filter - // - applying filter null mask to filter values by bitand filter values and nulls buffers - // (using buffers guarantees absence of nulls in result) - // - casting result of bitand to Int64 array - let (filter_values, filter_nulls) = filter.clone().into_parts(); - - let state_buf = match filter_nulls { - Some(filter_nulls) => &filter_values & filter_nulls.inner(), - None => filter_values, - }; - - let boolean_state = BooleanArray::new(state_buf, None); - compute::cast(&boolean_state, &DataType::Int64)? - } - (Some(nulls), Some(filter)) => { - // For both input nulls and filter - // - applying filter null mask to filter values by bitand filter values and nulls buffers - // (using buffers guarantees absence of nulls in result) - // - applying values null mask to filter buffer by another bitand on filter result and - // nulls from input values - // - casting result to Int64 array - let (filter_values, filter_nulls) = filter.clone().into_parts(); - - let filter_buf = match filter_nulls { - Some(filter_nulls) => &filter_values & filter_nulls.inner(), - None => filter_values, - }; - let state_buf = &filter_buf & nulls.inner(); - - let boolean_state = BooleanArray::new(state_buf, None); - compute::cast(&boolean_state, &DataType::Int64)? - } + let nulls = filtered_null_mask(opt_filter, values); + + let state_array: ArrayRef = if let Some(nulls) = nulls { + // nulls (false) in the filtered mask means we should output 0 + // counts for those values. + // + // cast kernel does the following conversion: + + // * `true` -> `1` + // * `false` -> `0` + let nulls = BooleanArray::new(nulls.into_inner(), None); + compute::cast(&nulls, &DataType::Int64)? + } else { + // all input values contribute a 1 + Arc::new(Int64Array::from_value(1, values.len())) }; Ok(vec![state_array])