From 6a2d88da90d540c9babb4b91f3d10c10a00e3e31 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Wed, 25 Sep 2024 05:58:38 +0800 Subject: [PATCH] Avoid RowConverter for multi column grouping (10% faster clickbench queries) (#12269) * row like group values to avoid rowconverter Signed-off-by: jayzhan211 * comment out unused Signed-off-by: jayzhan211 * implement to Arrow's builder Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * switch back to vector Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 * optimize for non-null Signed-off-by: jayzhan211 * use truncate Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * fix first N bug Signed-off-by: jayzhan211 * fix null check Signed-off-by: jayzhan211 * fast path null Signed-off-by: jayzhan211 * fix bug Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix error Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 * adjust spill mode max mem Signed-off-by: jayzhan211 * revert test_create_external_table_with_terminator_with_newlines_in_values Signed-off-by: jayzhan211 * fix null handle bug Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * support binary Signed-off-by: jayzhan211 * add binary test Signed-off-by: jayzhan211 * use Vec instead of Option> Signed-off-by: jayzhan211 * add test and doc Signed-off-by: jayzhan211 * debug assert Signed-off-by: jayzhan211 * mv & rename Signed-off-by: jayzhan211 * fix take_n logic Signed-off-by: jayzhan211 * address comment Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- .../physical-expr-common/src/binary_map.rs | 2 +- .../aggregates/group_values/column_wise.rs | 314 ++++++++++++ .../group_values/group_value_row.rs | 456 ++++++++++++++++++ .../src/aggregates/group_values/mod.rs | 37 +- .../physical-plan/src/aggregates/mod.rs | 5 +- .../sqllogictest/test_files/group_by.slt | 59 +++ datafusion/sqllogictest/test_files/window.slt | 21 + 7 files changed, 891 insertions(+), 3 deletions(-) create mode 100644 datafusion/physical-plan/src/aggregates/group_values/column_wise.rs create mode 100644 datafusion/physical-plan/src/aggregates/group_values/group_value_row.rs diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index d21bdb3434c4..f320ebcc06b5 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -237,7 +237,7 @@ where /// The size, in number of entries, of the initial hash table const INITIAL_MAP_CAPACITY: usize = 128; /// The initial size, in bytes, of the string data -const INITIAL_BUFFER_CAPACITY: usize = 8 * 1024; +pub const INITIAL_BUFFER_CAPACITY: usize = 8 * 1024; impl ArrowBytesMap where V: Debug + PartialEq + Eq + Clone + Copy + Default, diff --git a/datafusion/physical-plan/src/aggregates/group_values/column_wise.rs b/datafusion/physical-plan/src/aggregates/group_values/column_wise.rs new file mode 100644 index 000000000000..b35d58701b5c --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/column_wise.rs @@ -0,0 +1,314 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::group_value_row::{ + ArrayRowEq, ByteGroupValueBuilder, PrimitiveGroupValueBuilder, +}; +use crate::aggregates::group_values::GroupValues; +use ahash::RandomState; +use arrow::compute::cast; +use arrow::datatypes::{ + Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, + Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow::record_batch::RecordBatch; +use arrow_array::{Array, ArrayRef}; +use arrow_schema::{DataType, SchemaRef}; +use datafusion_common::hash_utils::create_hashes; +use datafusion_common::{DataFusionError, Result}; +use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_expr::EmitTo; +use datafusion_physical_expr::binary_map::OutputType; + +use hashbrown::raw::RawTable; + +/// Compare GroupValue Rows column by column +pub struct GroupValuesColumn { + /// The output schema + schema: SchemaRef, + + /// Logically maps group values to a group_index in + /// [`Self::group_values`] and in each accumulator + /// + /// Uses the raw API of hashbrown to avoid actually storing the + /// keys (group values) in the table + /// + /// keys: u64 hashes of the GroupValue + /// values: (hash, group_index) + map: RawTable<(u64, usize)>, + + /// The size of `map` in bytes + map_size: usize, + + /// The actual group by values, stored column-wise. Compare from + /// the left to right, each column is stored as `ArrayRowEq`. 
+    /// This has been shown to be faster than the row format
+    group_values: Vec<Box<dyn ArrayRowEq>>,
+
+    /// reused buffer to store hashes
+    hashes_buffer: Vec<u64>,
+
+    /// Random state for creating hashes
+    random_state: RandomState,
+}
+
+impl GroupValuesColumn {
+    pub fn try_new(schema: SchemaRef) -> Result<Self> {
+        let map = RawTable::with_capacity(0);
+        Ok(Self {
+            schema,
+            map,
+            map_size: 0,
+            group_values: vec![],
+            hashes_buffer: Default::default(),
+            random_state: Default::default(),
+        })
+    }
+}
+
+impl GroupValues for GroupValuesColumn {
+    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
+        let n_rows = cols[0].len();
+
+        if self.group_values.is_empty() {
+            let mut v = Vec::with_capacity(cols.len());
+
+            for f in self.schema.fields().iter() {
+                let nullable = f.is_nullable();
+                match f.data_type() {
+                    &DataType::Int8 => {
+                        let b = PrimitiveGroupValueBuilder::<Int8Type>::new(nullable);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::Int16 => {
+                        let b = PrimitiveGroupValueBuilder::<Int16Type>::new(nullable);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::Int32 => {
+                        let b = PrimitiveGroupValueBuilder::<Int32Type>::new(nullable);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::Int64 => {
+                        let b = PrimitiveGroupValueBuilder::<Int64Type>::new(nullable);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::UInt8 => {
+                        let b = PrimitiveGroupValueBuilder::<UInt8Type>::new(nullable);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::UInt16 => {
+                        let b = PrimitiveGroupValueBuilder::<UInt16Type>::new(nullable);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::UInt32 => {
+                        let b = PrimitiveGroupValueBuilder::<UInt32Type>::new(nullable);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::UInt64 => {
+                        let b = PrimitiveGroupValueBuilder::<UInt64Type>::new(nullable);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::Float32 => {
+                        let b = PrimitiveGroupValueBuilder::<Float32Type>::new(nullable);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::Float64 => {
+                        let b = PrimitiveGroupValueBuilder::<Float64Type>::new(nullable);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::Date32 => {
+                        let b = PrimitiveGroupValueBuilder::<Date32Type>::new(nullable);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::Date64 => {
+                        let b = PrimitiveGroupValueBuilder::<Date64Type>::new(nullable);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::Utf8 => {
+                        let b = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::LargeUtf8 => {
+                        let b = ByteGroupValueBuilder::<i64>::new(OutputType::Utf8);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::Binary => {
+                        let b = ByteGroupValueBuilder::<i32>::new(OutputType::Binary);
+                        v.push(Box::new(b) as _)
+                    }
+                    &DataType::LargeBinary => {
+                        let b = ByteGroupValueBuilder::<i64>::new(OutputType::Binary);
+                        v.push(Box::new(b) as _)
+                    }
+                    dt => todo!("{dt} not impl"),
+                }
+            }
+            self.group_values = v;
+        }
+
+        // tracks to which group each of the input rows belongs
+        groups.clear();
+
+        // 1.1 Calculate the group keys for the group values
+        let batch_hashes = &mut self.hashes_buffer;
+        batch_hashes.clear();
+        batch_hashes.resize(n_rows, 0);
+        create_hashes(cols, &self.random_state, batch_hashes)?;
+
+        for (row, &target_hash) in batch_hashes.iter().enumerate() {
+            let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| {
+                // Somewhat surprisingly, this closure can be called even if the
+                // hash doesn't match, so check the hash first with an integer
+                // comparison to avoid the more expensive comparison with the
+                // group value.
https://github.com/apache/datafusion/pull/11718 + if target_hash != *exist_hash { + return false; + } + + fn check_row_equal( + array_row: &dyn ArrayRowEq, + lhs_row: usize, + array: &ArrayRef, + rhs_row: usize, + ) -> bool { + array_row.equal_to(lhs_row, array, rhs_row) + } + + for (i, group_val) in self.group_values.iter().enumerate() { + if !check_row_equal(group_val.as_ref(), *group_idx, &cols[i], row) { + return false; + } + } + + true + }); + + let group_idx = match entry { + // Existing group_index for this group value + Some((_hash, group_idx)) => *group_idx, + // 1.2 Need to create new entry for the group + None => { + // Add new entry to aggr_state and save newly created index + // let group_idx = group_values.num_rows(); + // group_values.push(group_rows.row(row)); + + let mut checklen = 0; + let group_idx = self.group_values[0].len(); + for (i, group_value) in self.group_values.iter_mut().enumerate() { + group_value.append_val(&cols[i], row); + let len = group_value.len(); + if i == 0 { + checklen = len; + } else { + debug_assert_eq!(checklen, len); + } + } + + // for hasher function, use precomputed hash value + self.map.insert_accounted( + (target_hash, group_idx), + |(hash, _group_index)| *hash, + &mut self.map_size, + ); + group_idx + } + }; + groups.push(group_idx); + } + + Ok(()) + } + + fn size(&self) -> usize { + let group_values_size: usize = self.group_values.iter().map(|v| v.size()).sum(); + group_values_size + self.map_size + self.hashes_buffer.allocated_size() + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn len(&self) -> usize { + if self.group_values.is_empty() { + return 0; + } + + self.group_values[0].len() + } + + fn emit(&mut self, emit_to: EmitTo) -> Result> { + let mut output = match emit_to { + EmitTo::All => { + let group_values = std::mem::take(&mut self.group_values); + debug_assert!(self.group_values.is_empty()); + + group_values + .into_iter() + .map(|v| v.build()) + .collect::>() + } + EmitTo::First(n) => { + let output = self + .group_values + .iter_mut() + .map(|v| v.take_n(n)) + .collect::>(); + + // SAFETY: self.map outlives iterator and is not modified concurrently + unsafe { + for bucket in self.map.iter() { + // Decrement group index by n + match bucket.as_ref().1.checked_sub(n) { + // Group index was >= n, shift value down + Some(sub) => bucket.as_mut().1 = sub, + // Group index was < n, so remove from table + None => self.map.erase(bucket), + } + } + } + + output + } + }; + + // TODO: Materialize dictionaries in group keys (#7647) + for (field, array) in self.schema.fields.iter().zip(&mut output) { + let expected = field.data_type(); + if let DataType::Dictionary(_, v) = expected { + let actual = array.data_type(); + if v.as_ref() != actual { + return Err(DataFusionError::Internal(format!( + "Converted group rows expected dictionary of {v} got {actual}" + ))); + } + *array = cast(array.as_ref(), expected)?; + } + } + + Ok(output) + } + + fn clear_shrink(&mut self, batch: &RecordBatch) { + let count = batch.num_rows(); + self.group_values.clear(); + self.map.clear(); + self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared + self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>(); + self.hashes_buffer.clear(); + self.hashes_buffer.shrink_to(count); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_value_row.rs b/datafusion/physical-plan/src/aggregates/group_values/group_value_row.rs new file mode 100644 index 000000000000..ad8da37e7ca0 --- /dev/null 
+++ b/datafusion/physical-plan/src/aggregates/group_values/group_value_row.rs
@@ -0,0 +1,456 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::BooleanBufferBuilder;
+use arrow::array::BufferBuilder;
+use arrow::array::GenericBinaryArray;
+use arrow::array::GenericStringArray;
+use arrow::array::OffsetSizeTrait;
+use arrow::array::PrimitiveArray;
+use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray};
+use arrow::buffer::NullBuffer;
+use arrow::buffer::OffsetBuffer;
+use arrow::buffer::ScalarBuffer;
+use arrow::datatypes::ArrowNativeType;
+use arrow::datatypes::ByteArrayType;
+use arrow::datatypes::DataType;
+use arrow::datatypes::GenericBinaryType;
+use arrow::datatypes::GenericStringType;
+use datafusion_common::utils::proxy::VecAllocExt;
+
+use std::sync::Arc;
+use std::vec;
+
+use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY};
+
+/// Trait for group values column-wise row comparison
+///
+/// Implementations of this trait store an in-progress collection of group values
+/// (similar to various builders in Arrow-rs) that allow for quick comparison to
+/// incoming rows.
+///
+pub trait ArrayRowEq: Send + Sync {
+    /// Returns equal if the row stored in this builder at `lhs_row` is equal to
+    /// the row in `array` at `rhs_row`
+    fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool;
+    /// Appends the row at `row` in `array` to this builder
+    fn append_val(&mut self, array: &ArrayRef, row: usize);
+    /// Returns the number of rows stored in this builder
+    fn len(&self) -> usize;
+    /// Returns the number of bytes used by this [`ArrayRowEq`]
+    fn size(&self) -> usize;
+    /// Builds a new array from all of the stored rows
+    fn build(self: Box<Self>) -> ArrayRef;
+    /// Builds a new array from the first `n` stored rows, shifting the
+    /// remaining rows to the start of the builder
+    fn take_n(&mut self, n: usize) -> ArrayRef;
+}
+
+pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType> {
+    group_values: Vec<T::Native>,
+    nulls: Vec<bool>,
+    // whether the array contains at least one null, for fast non-null path
+    has_null: bool,
+    nullable: bool,
+}
+
+impl<T> PrimitiveGroupValueBuilder<T>
+where
+    T: ArrowPrimitiveType,
+{
+    pub fn new(nullable: bool) -> Self {
+        Self {
+            group_values: vec![],
+            nulls: vec![],
+            has_null: false,
+            nullable,
+        }
+    }
+}
+
+impl<T: ArrowPrimitiveType> ArrayRowEq for PrimitiveGroupValueBuilder<T> {
+    fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
+        // non-null fast path
+        // both non-null
+        if !self.nullable {
+            return self.group_values[lhs_row]
+                == array.as_primitive::<T>().value(rhs_row);
+        }
+
+        // lhs is non-null
+        if self.nulls[lhs_row] {
+            if array.is_null(rhs_row) {
+                return false;
+            }
+
+            return self.group_values[lhs_row]
+                == array.as_primitive::<T>().value(rhs_row);
+        }
+
+        array.is_null(rhs_row)
+    }
+
+    fn append_val(&mut self, array: &ArrayRef, row: usize) {
+        if self.nullable && array.is_null(row) {
+            self.group_values.push(T::default_value());
+            self.nulls.push(false);
+            self.has_null = true;
+        } else {
+            let elem = array.as_primitive::<T>().value(row);
+            self.group_values.push(elem);
+            self.nulls.push(true);
+        }
+    }
+
+    fn len(&self) -> usize {
+        self.group_values.len()
+    }
+
+    fn size(&self) -> usize {
+        self.group_values.allocated_size() + self.nulls.allocated_size()
+    }
+
+    fn build(self: Box<Self>) -> ArrayRef {
+        if self.has_null {
+            Arc::new(PrimitiveArray::<T>::new(
+                ScalarBuffer::from(self.group_values),
+                Some(NullBuffer::from(self.nulls)),
+            ))
+        } else {
+            Arc::new(PrimitiveArray::<T>::new(
+                ScalarBuffer::from(self.group_values),
+                None,
+            ))
+        }
+    }
+
+    fn take_n(&mut self, n: usize) -> ArrayRef {
+        if self.has_null {
+            let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
+            let first_n_nulls = self.nulls.drain(0..n).collect::<Vec<_>>();
+            Arc::new(PrimitiveArray::<T>::new(
+                ScalarBuffer::from(first_n),
+                Some(NullBuffer::from(first_n_nulls)),
+            ))
+        } else {
+            let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
+            self.nulls.truncate(self.nulls.len() - n);
+            Arc::new(PrimitiveArray::<T>::new(ScalarBuffer::from(first_n), None))
+        }
+    }
+}
+
+pub struct ByteGroupValueBuilder<O>
+where
+    O: OffsetSizeTrait,
+{
+    output_type: OutputType,
+    buffer: BufferBuilder<u8>,
+    /// Offsets into `buffer` for each distinct value. These offsets are used
+    /// directly to create the final `GenericBinaryArray`. The `i`th string is
+    /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values
+    /// are stored as a zero length string.
+    offsets: Vec<O>,
+    /// Null indexes in offsets, if `i` is in nulls, `offsets[i]` should be equal to `offsets[i+1]`
+    nulls: Vec<usize>,
+}
+
+impl<O> ByteGroupValueBuilder<O>
+where
+    O: OffsetSizeTrait,
+{
+    pub fn new(output_type: OutputType) -> Self {
+        Self {
+            output_type,
+            buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY),
+            offsets: vec![O::default()],
+            nulls: vec![],
+        }
+    }
+
+    fn append_val_inner<B>(&mut self, array: &ArrayRef, row: usize)
+    where
+        B: ByteArrayType,
+    {
+        let arr = array.as_bytes::<B>();
+        if arr.is_null(row) {
+            self.nulls.push(self.len());
+            // nulls need a zero length in the offset buffer
+            let offset = self.buffer.len();
+
+            self.offsets.push(O::usize_as(offset));
+            return;
+        }
+
+        let value: &[u8] = arr.value(row).as_ref();
+        self.buffer.append_slice(value);
+        self.offsets.push(O::usize_as(self.buffer.len()));
+    }
+
+    fn equal_to_inner<B>(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool
+    where
+        B: ByteArrayType,
+    {
+        // Handle nulls
+        let is_lhs_null = self.nulls.iter().any(|null_idx| *null_idx == lhs_row);
+        let arr = array.as_bytes::<B>();
+        if is_lhs_null {
+            return arr.is_null(rhs_row);
+        } else if arr.is_null(rhs_row) {
+            return false;
+        }
+
+        let arr = array.as_bytes::<B>();
+        let rhs_elem: &[u8] = arr.value(rhs_row).as_ref();
+        let rhs_elem_len = arr.value_length(rhs_row).as_usize();
+        debug_assert_eq!(rhs_elem_len, rhs_elem.len());
+        let l = self.offsets[lhs_row].as_usize();
+        let r = self.offsets[lhs_row + 1].as_usize();
+        let existing_elem = unsafe { self.buffer.as_slice().get_unchecked(l..r) };
+        rhs_elem == existing_elem
+    }
+}
+
+impl<O> ArrayRowEq for ByteGroupValueBuilder<O>
+where
+    O: OffsetSizeTrait,
+{
+    fn equal_to(&self, lhs_row: usize, column: &ArrayRef, rhs_row: usize) -> bool {
+        // Sanity array type
+        match self.output_type {
+            OutputType::Binary => {
+                debug_assert!(matches!(
+                    column.data_type(),
+                    DataType::Binary | DataType::LargeBinary
+                ));
+                self.equal_to_inner::<GenericBinaryType<O>>(lhs_row, column, rhs_row)
+            }
+            OutputType::Utf8 => {
+                debug_assert!(matches!(
+                    column.data_type(),
+                    DataType::Utf8 | DataType::LargeUtf8
+                ));
+                self.equal_to_inner::<GenericStringType<O>>(lhs_row, column, rhs_row)
+            }
+            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
+        }
+    }
+
+    fn append_val(&mut self, column: &ArrayRef, row: usize) {
+        // Sanity array type
+        match self.output_type {
+            OutputType::Binary => {
+                debug_assert!(matches!(
+                    column.data_type(),
+                    DataType::Binary | DataType::LargeBinary
+                ));
+                self.append_val_inner::<GenericBinaryType<O>>(column, row)
+            }
+            OutputType::Utf8 => {
+                debug_assert!(matches!(
+                    column.data_type(),
+                    DataType::Utf8 | DataType::LargeUtf8
+                ));
+                self.append_val_inner::<GenericStringType<O>>(column, row)
+            }
+            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
+        };
+    }
+
+    fn len(&self) -> usize {
+        self.offsets.len() - 1
+    }
+
+    fn size(&self) -> usize {
+        self.buffer.capacity() * std::mem::size_of::<u8>()
+            + self.offsets.allocated_size()
+            + self.nulls.allocated_size()
+    }
+
+    fn build(self: Box<Self>) -> ArrayRef {
+        let Self {
+            output_type,
+            mut buffer,
+            offsets,
+            nulls,
+        } = *self;
+
+        let null_buffer = if nulls.is_empty() {
+            None
+        } else {
+            // Only make a `NullBuffer` if there was a null value
+            let num_values = offsets.len() - 1;
+            let mut bool_builder = BooleanBufferBuilder::new(num_values);
+            bool_builder.append_n(num_values, true);
+            nulls.into_iter().for_each(|null_index| {
+                bool_builder.set_bit(null_index, false);
+            });
+            Some(NullBuffer::from(bool_builder.finish()))
+        };
+
+        // SAFETY: the offsets were constructed correctly in `append_val_inner` --
+        // monotonically increasing, overflows were checked.
+        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
+        let values = buffer.finish();
+        match output_type {
+            OutputType::Binary => {
+                // SAFETY: the offsets were constructed correctly
+                Arc::new(unsafe {
+                    GenericBinaryArray::new_unchecked(offsets, values, null_buffer)
+                })
+            }
+            OutputType::Utf8 => {
+                // SAFETY:
+                // 1. the offsets were constructed safely
+                //
+                // 2. we asserted the input arrays were all the correct type and
+                //    thus since all the values that went in were valid (e.g. utf8)
+                //    so are all the values that come out
+                Arc::new(unsafe {
+                    GenericStringArray::new_unchecked(offsets, values, null_buffer)
+                })
+            }
+            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
+        }
+    }
+
+    fn take_n(&mut self, n: usize) -> ArrayRef {
+        debug_assert!(self.len() >= n);
+
+        let null_buffer = if self.nulls.is_empty() {
+            None
+        } else {
+            // Only make a `NullBuffer` if there was a null value
+            let mut bool_builder = BooleanBufferBuilder::new(n);
+            bool_builder.append_n(n, true);
+
+            let mut new_nulls = vec![];
+            self.nulls.iter().for_each(|null_index| {
+                if *null_index < n {
+                    bool_builder.set_bit(*null_index, false);
+                } else {
+                    new_nulls.push(null_index - n);
+                }
+            });
+
+            self.nulls = new_nulls;
+            Some(NullBuffer::from(bool_builder.finish()))
+        };
+
+        let first_remaining_offset = O::as_usize(self.offsets[n]);
+
+        // Given offsets like [0, 2, 4, 5] and n = 1, we expect the remaining offsets to be
+        // [0, 2, 3]. We first split into the first_n offsets [0, 2] and the remaining offsets [2, 4, 5],
+        // and then shift the remaining offsets to start from 0: [2, 4, 5] -> [0, 2, 3].
+        let mut first_n_offsets = self.offsets.drain(0..n).collect::<Vec<_>>();
+        let offset_n = *self.offsets.first().unwrap();
+        self.offsets
+            .iter_mut()
+            .for_each(|offset| *offset = offset.sub(offset_n));
+        first_n_offsets.push(offset_n);
+
+        // SAFETY: the offsets were constructed correctly in `append_val_inner` --
+        // monotonically increasing, overflows were checked.
+        let offsets =
+            unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(first_n_offsets)) };
+
+        let mut remaining_buffer =
+            BufferBuilder::new(self.buffer.len() - first_remaining_offset);
+        // TODO: The current approach copies the remaining bytes and truncates the original buffer.
+        // Find a way to avoid the copy by splitting the original buffer into two.
+        remaining_buffer.append_slice(&self.buffer.as_slice()[first_remaining_offset..]);
+        self.buffer.truncate(first_remaining_offset);
+        let values = self.buffer.finish();
+        self.buffer = remaining_buffer;
+
+        match self.output_type {
+            OutputType::Binary => {
+                // SAFETY: the offsets were constructed correctly
+                Arc::new(unsafe {
+                    GenericBinaryArray::new_unchecked(offsets, values, null_buffer)
+                })
+            }
+            OutputType::Utf8 => {
+                // SAFETY:
+                // 1. the offsets were constructed safely
+                //
+                // 2. we asserted the input arrays were all the correct type and
+                //    thus since all the values that went in were valid (e.g.
utf8) + // so are all the values that come out + Arc::new(unsafe { + GenericStringArray::new_unchecked(offsets, values, null_buffer) + }) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::{ArrayRef, StringArray}; + use datafusion_physical_expr::binary_map::OutputType; + + use super::{ArrayRowEq, ByteGroupValueBuilder}; + + #[test] + fn test_take_n() { + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + let array = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef; + // a, null, null + builder.append_val(&array, 0); + builder.append_val(&array, 1); + builder.append_val(&array, 1); + + // (a, null) remaining: null + let output = builder.take_n(2); + assert_eq!(&output, &array); + + // null, a, null, a + builder.append_val(&array, 0); + builder.append_val(&array, 1); + builder.append_val(&array, 0); + + // (null, a) remaining: (null, a) + let output = builder.take_n(2); + let array = Arc::new(StringArray::from(vec![None, Some("a")])) as ArrayRef; + assert_eq!(&output, &array); + + let array = Arc::new(StringArray::from(vec![ + Some("a"), + None, + Some("longstringfortest"), + ])) as ArrayRef; + + // null, a, longstringfortest, null, null + builder.append_val(&array, 2); + builder.append_val(&array, 1); + builder.append_val(&array, 1); + + // (null, a, longstringfortest, null) remaining: (null) + let output = builder.take_n(4); + let array = Arc::new(StringArray::from(vec![ + None, + Some("a"), + Some("longstringfortest"), + None, + ])) as ArrayRef; + assert_eq!(&output, &array); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index be7ac934d7bc..275cc7fcbf4e 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -25,7 +25,9 @@ pub(crate) mod primitive; use datafusion_expr::EmitTo; use primitive::GroupValuesPrimitive; +mod column_wise; mod row; +use column_wise::GroupValuesColumn; use row::GroupValuesRows; mod bytes; @@ -33,6 +35,8 @@ mod bytes_view; use bytes::GroupValuesByes; use datafusion_physical_expr::binary_map::OutputType; +mod group_value_row; + /// An interning store for group keys pub trait GroupValues: Send { /// Calculates the `groups` for each input row of `cols` @@ -92,5 +96,36 @@ pub fn new_group_values(schema: SchemaRef) -> Result> { } } - Ok(Box::new(GroupValuesRows::try_new(schema)?)) + if schema + .fields() + .iter() + .map(|f| f.data_type()) + .all(has_row_like_feature) + { + Ok(Box::new(GroupValuesColumn::try_new(schema)?)) + } else { + Ok(Box::new(GroupValuesRows::try_new(schema)?)) + } +} + +fn has_row_like_feature(data_type: &DataType) -> bool { + matches!( + *data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::Date32 + | DataType::Date64 + ) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 637b2b87dd14..2bdaed479655 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1363,7 +1363,8 @@ mod tests { .build()?]; let task_ctx = if spill { - new_spill_ctx(4, 1000) + // 
adjust the max memory size to have the partial aggregate result for spill mode. + new_spill_ctx(4, 500) } else { Arc::new(TaskContext::default()) }; @@ -1381,6 +1382,8 @@ mod tests { common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let expected = if spill { + // In spill mode, we test with the limited memory, if the mem usage exceeds, + // we trigger the early emit rule, which turns out the partial aggregate result. vec![ "+---+-----+-----------------+", "| a | b | COUNT(1)[count] |", diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 73bfd9844609..86651f6ce43c 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -5148,3 +5148,62 @@ NULL statement ok drop table test_case_expr + +statement ok +drop table t; + +# TODO: Current grouping set result is not align with Postgres and DuckDB, we might want to change the result +# See https://github.com/apache/datafusion/issues/12570 +# test multi group by for binary type with nulls +statement ok +create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, null), (null, 0xb), (null, 0xb); + +query I?I +select a, b, count(*) from t group by grouping sets ((a, b), (a), (b)); +---- +1 0a 2 +2 NULL 2 +NULL 0b 4 +1 NULL 2 +NULL NULL 3 +NULL 0a 2 + +statement ok +drop table t; + +# test multi group by for binary type without nulls +statement ok +create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, 0xb), (3, 0xb), (3, 0xb); + +query I?I +select a, b, count(*) from t group by grouping sets ((a, b), (a), (b)); +---- +1 0a 2 +2 0b 1 +3 0b 2 +1 NULL 2 +2 NULL 1 +3 NULL 2 +NULL 0a 2 +NULL 0b 3 + +statement ok +drop table t; + +# test multi group by int + utf8 +statement ok +create table t(a int, b varchar) as values (1, 'a'), (1, 'a'), (2, 'ab'), (3, 'abc'), (3, 'cba'), (null, null), (null, 'a'), (null, null), (null, 'a'), (1, 'null'); + +query ITI rowsort +select a, b, count(*) from t group by a, b; +---- +1 a 2 +1 null 1 +2 ab 1 +3 abc 1 +3 cba 1 +NULL NULL 2 +NULL a 2 + +statement ok +drop table t; diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 1f90b94aee11..7fee84f9bcd9 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4873,3 +4873,24 @@ SELECT NTH_VALUE('+Inf'::Double, v1) OVER (PARTITION BY v1) FROM t1; statement ok DROP TABLE t1; + +statement ok +create table t(c1 int, c2 varchar) as values (1, 'a'), (2, 'b'), (1, 'a'), (3, null), (null, 'a4'), (null, 'de'); + +# test multi group FirstN mode with nulls +query ITI +SELECT * +FROM (SELECT c1, c2, ROW_NUMBER() OVER() as rn + FROM t + LIMIT 5) +GROUP BY rn +ORDER BY rn; +---- +1 a 1 +2 b 2 +1 a 3 +3 NULL 4 +NULL a4 5 + +statement ok +drop table t
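
For context, a minimal usage sketch (not part of the patch) of how the column-wise builders added in group_value_row.rs compose through the `ArrayRowEq` trait: one builder per group-by column, append a row as a new group value, then compare later rows column by column. The `sketch` function name, the example arrays, and the assert are illustrative assumptions; only `PrimitiveGroupValueBuilder`, `ByteGroupValueBuilder`, `ArrayRowEq`, and their methods come from the code above (they live in a crate-private module, so this does not compile outside the crate).

    use std::sync::Arc;

    use arrow::datatypes::Int64Type;
    use arrow_array::{ArrayRef, Int64Array, StringArray};
    use datafusion_physical_expr::binary_map::OutputType;
    // `PrimitiveGroupValueBuilder`, `ByteGroupValueBuilder`, and `ArrayRowEq`
    // are assumed in scope from the group_value_row module added above.

    fn sketch() {
        // One builder per group-by column, as GroupValuesColumn::intern sets up.
        let mut id_col = PrimitiveGroupValueBuilder::<Int64Type>::new(true);
        let mut name_col = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);

        let ids = Arc::new(Int64Array::from(vec![Some(1), None])) as ArrayRef;
        let names = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef;

        // Store row 0 as the first group value (group index 0).
        id_col.append_val(&ids, 0);
        name_col.append_val(&names, 0);

        // Row 1 belongs to group 0 only if every column matches; here it does not,
        // because row 1 is null in both columns while group 0 is not.
        let same_group = id_col.equal_to(0, &ids, 1) && name_col.equal_to(0, &names, 1);
        assert!(!same_group);

        // Emitting rebuilds one ArrayRef per group-by column.
        let _arrays: Vec<ArrayRef> =
            vec![Box::new(id_col).build(), Box::new(name_col).build()];
    }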