From 9a3f8d11546ffa68579cf3c68726d8e0db1a6b06 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 26 Sep 2024 17:23:21 -0400 Subject: [PATCH] Minor: Encapsulate type check in GroupValuesColumn, avoid panic (#12620) * Encapsulate type check in GroupValuesColumn, avoid panic * Update datafusion/physical-plan/src/aggregates/group_values/column_wise.rs Co-authored-by: Oleks V * Clarify what supported means * return not implemented error * Fixup doc link --------- Co-authored-by: Oleks V --- .../src/aggregates/group_values/column.rs | 44 +++++++++++++++++-- .../src/aggregates/group_values/mod.rs | 29 +----------- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/column.rs b/datafusion/physical-plan/src/aggregates/group_values/column.rs index 311f48ba9839..977b40922f7c 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/column.rs @@ -27,9 +27,9 @@ use arrow::datatypes::{ }; use arrow::record_batch::RecordBatch; use arrow_array::{Array, ArrayRef}; -use arrow_schema::{DataType, SchemaRef}; +use arrow_schema::{DataType, Schema, SchemaRef}; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; @@ -67,6 +67,7 @@ pub struct GroupValuesColumn { } impl GroupValuesColumn { + /// Create a new instance of GroupValuesColumn if supported for the specified schema pub fn try_new(schema: SchemaRef) -> Result { let map = RawTable::with_capacity(0); Ok(Self { @@ -78,6 +79,41 @@ impl GroupValuesColumn { random_state: Default::default(), }) } + + /// Returns true if [`GroupValuesColumn`] supported for the specified schema + pub fn supported_schema(schema: &Schema) -> bool { + schema + .fields() + .iter() + .map(|f| f.data_type()) + .all(Self::supported_type) + } + + /// Returns true if the specified data type is supported by [`GroupValuesColumn`] + /// + /// In order to be supported, there must be a specialized implementation of + /// [`GroupColumn`] for the data type, instantiated in [`Self::intern`] + fn supported_type(data_type: &DataType) -> bool { + matches!( + *data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::Date32 + | DataType::Date64 + ) + } } impl GroupValues for GroupValuesColumn { @@ -154,7 +190,9 @@ impl GroupValues for GroupValuesColumn { let b = ByteGroupValueBuilder::::new(OutputType::Binary); v.push(Box::new(b) as _) } - dt => todo!("{dt} not impl"), + dt => { + return not_impl_err!("{dt} not supported in GroupValuesColumn") + } } } self.group_values = v; diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 3e0474d4c2d0..9256631fa578 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -96,36 +96,9 @@ pub fn new_group_values(schema: SchemaRef) -> Result> { } } - if schema - .fields() - .iter() - .map(|f| f.data_type()) - .all(has_row_like_feature) - { + if GroupValuesColumn::supported_schema(schema.as_ref()) { Ok(Box::new(GroupValuesColumn::try_new(schema)?)) } else { Ok(Box::new(GroupValuesRows::try_new(schema)?)) } } - -fn has_row_like_feature(data_type: &DataType) -> bool { - matches!( - *data_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Binary - | DataType::LargeBinary - | DataType::Date32 - | DataType::Date64 - ) -}