Skip to content

Commit

Permalink
simplify count
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Aug 5, 2024
1 parent 30461ba commit cee9356
Showing 1 changed file with 18 additions and 47 deletions.
65 changes: 18 additions & 47 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use datafusion_expr::{
};
use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
use datafusion_physical_expr_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datafusion_physical_expr_common::{
aggregate::count_distinct::{
BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
Expand Down Expand Up @@ -444,59 +445,29 @@ impl GroupsAccumulator for CountGroupsAccumulator {
/// Converts an input batch directly to a state batch
///
/// The state of `COUNT` is always a single Int64Array:
/// * `1` (for non-null, non filtered values)
/// * `0` (for null values)
/// * `1` (for non-null, non-filtered values)
/// * `0` (for filtered or null values)
fn convert_to_state(
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
let values = &values[0];

let state_array = match (values.logical_nulls(), opt_filter) {
(None, None) => {
// In case there is no nulls in input and no filter, returning array of 1
Arc::new(Int64Array::from_value(1, values.len()))
}
(Some(nulls), None) => {
// If there are any nulls in input values -- casting `nulls` (true for values, false for nulls)
// of input array to Int64
let nulls = BooleanArray::new(nulls.into_inner(), None);
compute::cast(&nulls, &DataType::Int64)?
}
(None, Some(filter)) => {
// If there is only filter
// - applying filter null mask to filter values by bitand filter values and nulls buffers
// (using buffers guarantees absence of nulls in result)
// - casting result of bitand to Int64 array
let (filter_values, filter_nulls) = filter.clone().into_parts();

let state_buf = match filter_nulls {
Some(filter_nulls) => &filter_values & filter_nulls.inner(),
None => filter_values,
};

let boolean_state = BooleanArray::new(state_buf, None);
compute::cast(&boolean_state, &DataType::Int64)?
}
(Some(nulls), Some(filter)) => {
// For both input nulls and filter
// - applying filter null mask to filter values by bitand filter values and nulls buffers
// (using buffers guarantees absence of nulls in result)
// - applying values null mask to filter buffer by another bitand on filter result and
// nulls from input values
// - casting result to Int64 array
let (filter_values, filter_nulls) = filter.clone().into_parts();

let filter_buf = match filter_nulls {
Some(filter_nulls) => &filter_values & filter_nulls.inner(),
None => filter_values,
};
let state_buf = &filter_buf & nulls.inner();

let boolean_state = BooleanArray::new(state_buf, None);
compute::cast(&boolean_state, &DataType::Int64)?
}
let nulls = filtered_null_mask(opt_filter, values);

let state_array: ArrayRef = if let Some(nulls) = nulls {
// a null (`false`) entry in the filtered mask means we should output a
// count of 0 for that value.
//
// cast kernel does the following conversion:

// * `true` -> `1`
// * `false` -> `0`
let nulls = BooleanArray::new(nulls.into_inner(), None);
compute::cast(&nulls, &DataType::Int64)?
} else {
// all input values contribute a 1
Arc::new(Int64Array::from_value(1, values.len()))
};

Ok(vec![state_array])
Expand Down

0 comments on commit cee9356

Please sign in to comment.