Skip to content

Commit

Permalink
Split count_distinct.rs into separate modules (apache#9087)
Browse files Browse the repository at this point in the history
* Split count_distinct.rs into separate modules

* Remove unecessary typedef

* Rename

* improve module comments
  • Loading branch information
alamb authored Feb 1, 2024
1 parent 968c05f commit 8b50774
Show file tree
Hide file tree
Showing 3 changed files with 261 additions and 217 deletions.
257 changes: 41 additions & 216 deletions datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,36 @@
// specific language governing permissions and limitations
// under the License.

mod native;
mod strings;

use std::any::Any;
use std::cmp::Eq;
use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{DataType, Field, TimeUnit};
use arrow_array::types::{
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, Float32Type,
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType,
Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::PrimitiveArray;

use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;

use crate::aggregate::count_distinct::native::{
FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator,
};
use crate::aggregate::count_distinct::strings::StringDistinctCountAccumulator;
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};

type DistinctScalarValues = ScalarValue;

/// Expression for a COUNT(DISTINCT) aggregation.
#[derive(Debug)]
pub struct DistinctCount {
Expand Down Expand Up @@ -101,46 +98,46 @@ impl AggregateExpr for DistinctCount {
use TimeUnit::*;

Ok(match &self.state_data_type {
Int8 => Box::new(NativeDistinctCountAccumulator::<Int8Type>::new()),
Int16 => Box::new(NativeDistinctCountAccumulator::<Int16Type>::new()),
Int32 => Box::new(NativeDistinctCountAccumulator::<Int32Type>::new()),
Int64 => Box::new(NativeDistinctCountAccumulator::<Int64Type>::new()),
UInt8 => Box::new(NativeDistinctCountAccumulator::<UInt8Type>::new()),
UInt16 => Box::new(NativeDistinctCountAccumulator::<UInt16Type>::new()),
UInt32 => Box::new(NativeDistinctCountAccumulator::<UInt32Type>::new()),
UInt64 => Box::new(NativeDistinctCountAccumulator::<UInt64Type>::new()),
Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new()),
Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new()),
Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new()),
Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new()),
UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new()),
UInt16 => Box::new(PrimitiveDistinctCountAccumulator::<UInt16Type>::new()),
UInt32 => Box::new(PrimitiveDistinctCountAccumulator::<UInt32Type>::new()),
UInt64 => Box::new(PrimitiveDistinctCountAccumulator::<UInt64Type>::new()),
Decimal128(_, _) => {
Box::new(NativeDistinctCountAccumulator::<Decimal128Type>::new())
Box::new(PrimitiveDistinctCountAccumulator::<Decimal128Type>::new())
}
Decimal256(_, _) => {
Box::new(NativeDistinctCountAccumulator::<Decimal256Type>::new())
Box::new(PrimitiveDistinctCountAccumulator::<Decimal256Type>::new())
}

Date32 => Box::new(NativeDistinctCountAccumulator::<Date32Type>::new()),
Date64 => Box::new(NativeDistinctCountAccumulator::<Date64Type>::new()),
Time32(Millisecond) => {
Box::new(NativeDistinctCountAccumulator::<Time32MillisecondType>::new())
}
Date32 => Box::new(PrimitiveDistinctCountAccumulator::<Date32Type>::new()),
Date64 => Box::new(PrimitiveDistinctCountAccumulator::<Date64Type>::new()),
Time32(Millisecond) => Box::new(PrimitiveDistinctCountAccumulator::<
Time32MillisecondType,
>::new()),
Time32(Second) => {
Box::new(NativeDistinctCountAccumulator::<Time32SecondType>::new())
}
Time64(Microsecond) => {
Box::new(NativeDistinctCountAccumulator::<Time64MicrosecondType>::new())
Box::new(PrimitiveDistinctCountAccumulator::<Time32SecondType>::new())
}
Time64(Microsecond) => Box::new(PrimitiveDistinctCountAccumulator::<
Time64MicrosecondType,
>::new()),
Time64(Nanosecond) => {
Box::new(NativeDistinctCountAccumulator::<Time64NanosecondType>::new())
Box::new(PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new())
}
Timestamp(Microsecond, _) => Box::new(NativeDistinctCountAccumulator::<
Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
TimestampMicrosecondType,
>::new()),
Timestamp(Millisecond, _) => Box::new(NativeDistinctCountAccumulator::<
Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
TimestampMillisecondType,
>::new()),
Timestamp(Nanosecond, _) => {
Box::new(NativeDistinctCountAccumulator::<TimestampNanosecondType>::new())
}
Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
TimestampNanosecondType,
>::new()),
Timestamp(Second, _) => {
Box::new(NativeDistinctCountAccumulator::<TimestampSecondType>::new())
Box::new(PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new())
}

Float16 => Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()),
Expand Down Expand Up @@ -175,9 +172,13 @@ impl PartialEq<dyn Any> for DistinctCount {
}
}

/// General purpose distinct accumulator that works for any DataType by using
/// [`ScalarValue`]. Some types have specialized accumulators that are (much)
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
/// [`StringDistinctCountAccumulator`]
#[derive(Debug)]
struct DistinctCountAccumulator {
values: HashSet<DistinctScalarValues, RandomState>,
values: HashSet<ScalarValue, RandomState>,
state_data_type: DataType,
}

Expand All @@ -186,7 +187,7 @@ impl DistinctCountAccumulator {
// This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types
fn fixed_size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
+ (std::mem::size_of::<ScalarValue>() * self.values.capacity())
+ self
.values
.iter()
Expand All @@ -199,7 +200,7 @@ impl DistinctCountAccumulator {
// calculates the size as accurate as possible, call to this method is expensive
fn full_size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
+ (std::mem::size_of::<ScalarValue>() * self.values.capacity())
+ self
.values
.iter()
Expand Down Expand Up @@ -260,182 +261,6 @@ impl Accumulator for DistinctCountAccumulator {
}
}

#[derive(Debug)]
struct NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
T::Native: Eq + Hash,
{
values: HashSet<T::Native, RandomState>,
}

impl<T> NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
T::Native: Eq + Hash,
{
fn new() -> Self {
Self {
values: HashSet::default(),
}
}
}

impl<T> Accumulator for NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send + Debug,
T::Native: Eq + Hash,
{
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().cloned(),
)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr));
Ok(vec![ScalarValue::List(list)])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}

let arr = as_primitive_array::<T>(&values[0])?;
arr.iter().for_each(|value| {
if let Some(value) = value {
self.values.insert(value);
}
});

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);

let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
let list = as_primitive_array::<T>(&list)?;
self.values.extend(list.values())
};
Ok(())
})
}

fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
}

fn size(&self) -> usize {
let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
/ 7)
.next_power_of_two();

// Size of accumulator
// + size of entry * number of buckets
// + 1 byte for each bucket
// + fixed size of HashSet
std::mem::size_of_val(self)
+ std::mem::size_of::<T::Native>() * estimated_buckets
+ estimated_buckets
+ std::mem::size_of_val(&self.values)
}
}

#[derive(Debug)]
struct FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
{
values: HashSet<Hashable<T::Native>, RandomState>,
}

impl<T> FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
{
fn new() -> Self {
Self {
values: HashSet::default(),
}
}
}

impl<T> Accumulator for FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send + Debug,
{
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().map(|v| v.0),
)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr));
Ok(vec![ScalarValue::List(list)])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}

let arr = as_primitive_array::<T>(&values[0])?;
arr.iter().for_each(|value| {
if let Some(value) = value {
self.values.insert(Hashable(value));
}
});

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);

let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
let list = as_primitive_array::<T>(&list)?;
self.values
.extend(list.values().iter().map(|v| Hashable(*v)));
};
Ok(())
})
}

fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
}

fn size(&self) -> usize {
let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
/ 7)
.next_power_of_two();

// Size of accumulator
// + size of entry * number of buckets
// + 1 byte for each bucket
// + fixed size of HashSet
std::mem::size_of_val(self)
+ std::mem::size_of::<T::Native>() * estimated_buckets
+ estimated_buckets
+ std::mem::size_of_val(&self.values)
}
}

#[cfg(test)]
mod tests {
use arrow::array::{
Expand Down
Loading

0 comments on commit 8b50774

Please sign in to comment.