diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 53af19e80b3f..59b4e2d866be 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -384,7 +384,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "fastrand", + "fastrand 1.9.0", "hex", "http", "hyper", @@ -404,7 +404,7 @@ checksum = "1fcdb2f7acbc076ff5ad05e7864bdb191ca70a6fd07668dc3a1a8bcd051de5ae" dependencies = [ "aws-smithy-async", "aws-smithy-types", - "fastrand", + "fastrand 1.9.0", "tokio", "tracing", "zeroize", @@ -550,7 +550,7 @@ dependencies = [ "aws-smithy-http-tower", "aws-smithy-types", "bytes", - "fastrand", + "fastrand 1.9.0", "http", "http-body", "hyper", @@ -1068,6 +1068,7 @@ dependencies = [ "flate2", "futures", "glob", + "half", "hashbrown 0.14.0", "indexmap 2.0.0", "itertools 0.11.0", @@ -1362,6 +1363,12 @@ dependencies = [ "instant", ] +[[package]] +name = "fastrand" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764" + [[package]] name = "fd-lock" version = "3.0.13" @@ -1369,7 +1376,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef033ed5e9bad94e55838ca0ca906db0e043f517adda0c8b79c7a8c66c93c1b5" dependencies = [ "cfg-if", - "rustix 0.38.4", + "rustix", "windows-sys", ] @@ -1794,17 +1801,6 @@ version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" -[[package]] -name = "io-lifetimes" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" -dependencies = [ - "hermit-abi 0.3.2", - "libc", - "windows-sys", -] - [[package]] name = "ipnet" version = "2.8.0" @@ -1945,12 +1941,6 @@ dependencies = [ "libc", ] -[[package]] -name = "linux-raw-sys" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" - [[package]] name = "linux-raw-sys" version = "0.4.3" @@ -2694,20 +2684,6 @@ dependencies = [ "semver", ] -[[package]] -name = "rustix" -version = "0.37.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d69718bf81c6127a49dc64e44a742e8bb9213c0ff8869a22c308f84c1d4ab06" -dependencies = [ - "bitflags 1.3.2", - "errno", - "io-lifetimes", - "libc", - "linux-raw-sys 0.3.8", - "windows-sys", -] - [[package]] name = "rustix" version = "0.38.4" @@ -2717,7 +2693,7 @@ dependencies = [ "bitflags 2.3.3", "errno", "libc", - "linux-raw-sys 0.4.3", + "linux-raw-sys", "windows-sys", ] @@ -2882,18 +2858,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.171" +version = "1.0.173" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30e27d1e4fd7659406c492fd6cfaf2066ba8773de45ca75e855590f856dc34a9" +checksum = "e91f70896d6720bc714a4a57d22fc91f1db634680e65c8efe13323f1fa38d53f" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.171" +version = "1.0.173" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389894603bd18c46fa56231694f8d827779c0951a667087194cf9de94ed24682" +checksum = "a6250dde8342e0232232be9ca3db7aa40aceb5a3e5dd9bddbc00d99a007cde49" dependencies = [ "proc-macro2", "quote", @@ -3109,15 +3085,14 @@ dependencies = [ [[package]] name = 
"tempfile" -version = "3.6.0" +version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6" +checksum = "5486094ee78b2e5038a6382ed7645bc084dc2ec433426ca4c3cb61e2007b8998" dependencies = [ - "autocfg", "cfg-if", - "fastrand", + "fastrand 2.0.0", "redox_syscall 0.3.5", - "rustix 0.37.23", + "rustix", "windows-sys", ] @@ -3745,18 +3720,18 @@ checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" [[package]] name = "zstd" -version = "0.12.3+zstd.1.5.2" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76eea132fb024e0e13fd9c2f5d5d595d8a967aa72382ac2f9d39fcc95afd0806" +checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" dependencies = [ "zstd-safe", ] [[package]] name = "zstd-safe" -version = "6.0.5+zstd.1.5.4" +version = "6.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d56d9e60b4b1758206c238a10165fbcae3ca37b01744e394c463463f6529d23b" +checksum = "ee98ffd0b48ee95e6c5168188e44a54550b1564d9d530ee21d5f0eaed1069581" dependencies = [ "libc", "zstd-sys", diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index a95a8266052f..eaed7f00207c 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -71,6 +71,7 @@ datafusion-sql = { path = "../sql", version = "27.0.0" } flate2 = { version = "1.0.24", optional = true } futures = "0.3" glob = "0.3.0" +half = { version = "2.1", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } indexmap = "2.0.0" itertools = "0.11" diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs b/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs new file mode 100644 index 000000000000..46f372b6ad28 --- /dev/null +++ b/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use arrow_array::{downcast_primitive, ArrayRef};
+use arrow_schema::SchemaRef;
+use datafusion_common::Result;
+use datafusion_physical_expr::EmitTo;
+
+mod primitive;
+use primitive::GroupValuesPrimitive;
+
+mod row;
+use row::GroupValuesRows;
+
+/// An interning store for group keys
+pub trait GroupValues: Send {
+    /// Calculates the `groups` for each input row of `cols`
+    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()>;
+
+    /// Returns the number of bytes used by this [`GroupValues`]
+    fn size(&self) -> usize;
+
+    /// Returns true if this [`GroupValues`] is empty
+    fn is_empty(&self) -> bool;
+
+    /// The number of values stored in this [`GroupValues`]
+    fn len(&self) -> usize;
+
+    /// Emits the group values
+    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
+}
+
+pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {
+    if schema.fields.len() == 1 {
+        let d = schema.fields[0].data_type();
+
+        macro_rules! downcast_helper {
+            ($t:ty, $d:ident) => {
+                return Ok(Box::new(GroupValuesPrimitive::<$t>::new($d.clone())))
+            };
+        }
+
+        downcast_primitive! {
+            d => (downcast_helper, d),
+            _ => {}
+        }
+    }
+
+    Ok(Box::new(GroupValuesRows::try_new(schema)?))
+}
diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs b/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs
new file mode 100644
index 000000000000..7b8691c67fdd
--- /dev/null
+++ b/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs
@@ -0,0 +1,209 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::physical_plan::aggregates::group_values::GroupValues;
+use ahash::RandomState;
+use arrow::array::BooleanBufferBuilder;
+use arrow::buffer::NullBuffer;
+use arrow::datatypes::i256;
+use arrow_array::cast::AsArray;
+use arrow_array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray};
+use arrow_schema::DataType;
+use datafusion_common::Result;
+use datafusion_execution::memory_pool::proxy::VecAllocExt;
+use datafusion_physical_expr::EmitTo;
+use half::f16;
+use hashbrown::raw::RawTable;
+use std::sync::Arc;
+
+/// A trait to allow hashing of floating point numbers
+trait HashValue {
+    fn hash(self, state: &RandomState) -> u64;
+}
+
+macro_rules! hash_integer {
+    ($($t:ty),+) => {
+        $(impl HashValue for $t {
+            #[cfg(not(feature = "force_hash_collisions"))]
+            fn hash(self, state: &RandomState) -> u64 {
+                state.hash_one(self)
+            }
+
+            #[cfg(feature = "force_hash_collisions")]
+            fn hash(self, _state: &RandomState) -> u64 {
+                0
+            }
+        })+
+    };
+}
+hash_integer!(i8, i16, i32, i64, i128, i256);
+hash_integer!(u8, u16, u32, u64);
+
+macro_rules! hash_float {
+    ($($t:ty),+) => {
+        $(impl HashValue for $t {
+            #[cfg(not(feature = "force_hash_collisions"))]
+            fn hash(self, state: &RandomState) -> u64 {
+                state.hash_one(self.to_bits())
+            }
+
+            #[cfg(feature = "force_hash_collisions")]
+            fn hash(self, _state: &RandomState) -> u64 {
+                0
+            }
+        })+
+    };
+}
+
+hash_float!(f16, f32, f64);
+
+/// A [`GroupValues`] storing a single column of primitive values
+///
+/// This specialization is significantly faster than using the more general
+/// purpose `Row`s format
+pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
+    /// The data type of the output array
+    data_type: DataType,
+    /// Stores the group index based on the hash of its value
+    ///
+    /// We don't store the hashes as hashing fixed width primitives
+    /// is fast enough for this not to benefit performance
+    map: RawTable<usize>,
+    /// The group index of the null value if any
+    null_group: Option<usize>,
+    /// The values for each group index
+    values: Vec<T::Native>,
+    /// The random state used to generate hashes
+    random_state: RandomState,
+}
+
+impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
+    pub fn new(data_type: DataType) -> Self {
+        assert!(PrimitiveArray::<T>::is_compatible(&data_type));
+        Self {
+            data_type,
+            map: RawTable::with_capacity(128),
+            values: Vec::with_capacity(128),
+            null_group: None,
+            random_state: Default::default(),
+        }
+    }
+}
+
+impl<T: ArrowPrimitiveType> GroupValues for GroupValuesPrimitive<T>
+where
+    T::Native: HashValue,
+{
+    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
+        assert_eq!(cols.len(), 1);
+        groups.clear();
+
+        for v in cols[0].as_primitive::<T>() {
+            let group_id = match v {
+                None => *self.null_group.get_or_insert_with(|| {
+                    let group_id = self.values.len();
+                    self.values.push(Default::default());
+                    group_id
+                }),
+                Some(key) => {
+                    let state = &self.random_state;
+                    let hash = key.hash(state);
+                    let insert = self.map.find_or_find_insert_slot(
+                        hash,
+                        |g| unsafe { self.values.get_unchecked(*g).is_eq(key) },
+                        |g| unsafe { self.values.get_unchecked(*g).hash(state) },
+                    );
+
+                    // SAFETY: No mutation occurred since find_or_find_insert_slot
+                    unsafe {
+                        match insert {
+                            Ok(v) => *v.as_ref(),
+                            Err(slot) => {
+                                let g = self.values.len();
+                                self.map.insert_in_slot(hash, slot, g);
+                                self.values.push(key);
+                                g
+                            }
+                        }
+                    }
+                }
+            };
+            groups.push(group_id)
+        }
+        Ok(())
+    }
+
+    fn size(&self) -> usize {
+        self.map.capacity() * std::mem::size_of::<usize>() + self.values.allocated_size()
+    }
+
+    fn is_empty(&self) -> bool {
+        self.values.is_empty()
+    }
+
+    fn len(&self) -> usize {
+        self.values.len()
+    }
+
+    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        fn build_primitive<T: ArrowPrimitiveType>(
+            values: Vec<T::Native>,
+            null_idx: Option<usize>,
+        ) -> PrimitiveArray<T> {
+            let nulls = null_idx.map(|null_idx| {
+                let mut buffer = BooleanBufferBuilder::new(values.len());
+                buffer.append_n(values.len(), true);
+                buffer.set_bit(null_idx, false);
+                unsafe { NullBuffer::new_unchecked(buffer.finish(), 1) }
+            });
+            PrimitiveArray::<T>::new(values.into(), nulls)
+        }
+
+        let array: PrimitiveArray<T> = match emit_to {
+            EmitTo::All => {
+                self.map.clear();
+                build_primitive(std::mem::take(&mut self.values), self.null_group.take())
+            }
+            EmitTo::First(n) => {
+                // SAFETY: self.map outlives iterator and is not modified concurrently
+                unsafe {
+                    for bucket in self.map.iter() {
+                        // Decrement group index by n
+                        match bucket.as_ref().checked_sub(n) {
+                            // Group index was >= n, shift value down
+                            Some(sub) => *bucket.as_mut() = sub,
+                            // Group index was < n, so remove from table
+                            None => self.map.erase(bucket),
+                        }
+                    }
+                }
+                let null_group = match &mut self.null_group {
+                    Some(v) if *v >= n => {
+                        *v -= n;
+                        None
+                    }
+                    Some(_) => self.null_group.take(),
+                    None => None,
+                };
+                let mut split = self.values.split_off(n);
+                std::mem::swap(&mut self.values, &mut split);
+                build_primitive(split, null_group)
+            }
+        };
+        Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
+    }
+}
diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/row.rs b/datafusion/core/src/physical_plan/aggregates/group_values/row.rs
new file mode 100644
index 000000000000..4eb660d52590
--- /dev/null
+++ b/datafusion/core/src/physical_plan/aggregates/group_values/row.rs
@@ -0,0 +1,184 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::physical_plan::aggregates::group_values::GroupValues;
+use ahash::RandomState;
+use arrow::row::{RowConverter, Rows, SortField};
+use arrow_array::ArrayRef;
+use arrow_schema::SchemaRef;
+use datafusion_common::Result;
+use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
+use datafusion_physical_expr::hash_utils::create_hashes;
+use datafusion_physical_expr::EmitTo;
+use hashbrown::raw::RawTable;
+
+/// A [`GroupValues`] making use of [`Rows`]
+pub struct GroupValuesRows {
+    /// Converter for the group values
+    row_converter: RowConverter,
+
+    /// Logically maps group values to a group_index in
+    /// [`Self::group_values`] and in each accumulator
+    ///
+    /// Uses the raw API of hashbrown to avoid actually storing the
+    /// keys (group values) in the table
+    ///
+    /// keys: u64 hashes of the GroupValue
+    /// values: (hash, group_index)
+    map: RawTable<(u64, usize)>,
+
+    /// The size of `map` in bytes
+    map_size: usize,
+
+    /// The actual group by values, stored in arrow [`Row`] format.
+    /// `group_values[i]` holds the group value for group_index `i`.
+    ///
+    /// The row format is used to compare group keys quickly and store
+    /// them efficiently in memory. Quick comparison is especially
+    /// important for multi-column group keys.
+    ///
+    /// [`Row`]: arrow::row::Row
+    group_values: Rows,
+
+    // buffer to be reused to store hashes
+    hashes_buffer: Vec<u64>,
+
+    /// Random state for creating hashes
+    random_state: RandomState,
+}
+
+impl GroupValuesRows {
+    pub fn try_new(schema: SchemaRef) -> Result<Self> {
+        let row_converter = RowConverter::new(
+            schema
+                .fields()
+                .iter()
+                .map(|f| SortField::new(f.data_type().clone()))
+                .collect(),
+        )?;
+
+        let map = RawTable::with_capacity(0);
+        let group_values = row_converter.empty_rows(0, 0);
+
+        Ok(Self {
+            row_converter,
+            map,
+            map_size: 0,
+            group_values,
+            hashes_buffer: Default::default(),
+            random_state: Default::default(),
+        })
+    }
+}
+
+impl GroupValues for GroupValuesRows {
+    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
+        // Convert the group keys into the row format
+        // Avoid reallocation when https://github.com/apache/arrow-rs/issues/4479 is available
+        let group_rows = self.row_converter.convert_columns(cols)?;
+        let n_rows = group_rows.num_rows();
+
+        // tracks to which group each of the input rows belongs
+        groups.clear();
+
+        // 1.1 Calculate the group keys for the group values
+        let batch_hashes = &mut self.hashes_buffer;
+        batch_hashes.clear();
+        batch_hashes.resize(n_rows, 0);
+        create_hashes(cols, &self.random_state, batch_hashes)?;
+
+        for (row, &hash) in batch_hashes.iter().enumerate() {
+            let entry = self.map.get_mut(hash, |(_hash, group_idx)| {
+                // verify that a group that we are inserting with hash is
+                // actually the same key value as the group in
+                // existing_idx (aka group_values @ row)
+                group_rows.row(row) == self.group_values.row(*group_idx)
+            });
+
+            let group_idx = match entry {
+                // Existing group_index for this group value
+                Some((_hash, group_idx)) => *group_idx,
+                // 1.2 Need to create new entry for the group
+                None => {
+                    // Add new entry to aggr_state and save newly created index
+                    let group_idx = self.group_values.num_rows();
+                    self.group_values.push(group_rows.row(row));
+
+                    // for hasher function, use precomputed hash value
+                    self.map.insert_accounted(
+                        (hash, group_idx),
+                        |(hash, _group_index)| *hash,
+                        &mut self.map_size,
+                    );
+                    group_idx
+                }
+            };
+            groups.push(group_idx);
+        }
+
+        Ok(())
+    }
+
+    fn size(&self) -> usize {
+        self.row_converter.size()
+            + self.group_values.size()
+            + self.map_size
+            + self.hashes_buffer.allocated_size()
+    }
+
+    fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    fn len(&self) -> usize {
+        self.group_values.num_rows()
+    }
+
+    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        Ok(match emit_to {
+            EmitTo::All => {
+                // Eventually we may also want to clear the hash table here
+                self.row_converter.convert_rows(&self.group_values)?
+            }
+            EmitTo::First(n) => {
+                let groups_rows = self.group_values.iter().take(n);
+                let output = self.row_converter.convert_rows(groups_rows)?;
+                // Clear out first n group keys by copying them to a new Rows.
+                // TODO file some ticket in arrow-rs to make this more efficient?
+ let mut new_group_values = self.row_converter.empty_rows(0, 0); + for row in self.group_values.iter().skip(n) { + new_group_values.push(row); + } + std::mem::swap(&mut new_group_values, &mut self.group_values); + + // SAFETY: self.map outlives iterator and is not modified concurrently + unsafe { + for bucket in self.map.iter() { + // Decrement group index by n + match bucket.as_ref().1.checked_sub(n) { + // Group index was >= n, shift value down + Some(sub) => bucket.as_mut().1 = sub, + // Group index was < n, so remove from table + None => self.map.erase(bucket), + } + } + } + output + } + }) + } +} diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index 5b4e6dbdf024..f35b186f0815 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -44,6 +44,7 @@ use std::any::Any; use std::collections::HashMap; use std::sync::Arc; +mod group_values; mod no_grouping; mod order; mod row_hash; diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs index e3ac5c49a94b..4613a2e46443 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs @@ -25,12 +25,10 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::vec; -use ahash::RandomState; -use arrow::row::{RowConverter, Rows, SortField}; -use datafusion_physical_expr::hash_utils::create_hashes; use futures::ready; use futures::stream::{Stream, StreamExt}; +use crate::physical_plan::aggregates::group_values::{new_group_values, GroupValues}; use crate::physical_plan::aggregates::{ evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode, PhysicalGroupBy, @@ -41,10 +39,9 @@ use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use datafusion_common::Result; -use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; -use hashbrown::raw::RawTable; #[derive(Debug, Clone)] /// This object tracks the aggregation phase (input/output) @@ -59,181 +56,6 @@ pub(crate) enum ExecutionState { use super::order::GroupOrdering; use super::AggregateExec; -/// An interning store for group keys -trait GroupValues: Send { - /// Calculates the `groups` for each input row of `cols` - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; - - /// Returns the number of bytes used by this [`GroupValues`] - fn size(&self) -> usize; - - /// Returns true if this [`GroupValues`] is empty - fn is_empty(&self) -> bool; - - /// The number of values stored in this [`GroupValues`] - fn len(&self) -> usize; - - /// Emits the group values - fn emit(&mut self, emit_to: EmitTo) -> Result>; -} - -/// A [`GroupValues`] making use of [`Rows`] -struct GroupValuesRows { - /// Converter for the group values - row_converter: RowConverter, - - /// Logically maps group values to a group_index in - /// [`Self::group_values`] and in each accumulator - /// - /// Uses the raw API of hashbrown to avoid actually storing the - /// keys (group values) in the table - /// - /// keys: u64 hashes of the GroupValue - /// values: (hash, group_index) - map: RawTable<(u64, usize)>, - - /// 
The size of `map` in bytes - map_size: usize, - - /// The actual group by values, stored in arrow [`Row`] format. - /// `group_values[i]` holds the group value for group_index `i`. - /// - /// The row format is used to compare group keys quickly and store - /// them efficiently in memory. Quick comparison is especially - /// important for multi-column group keys. - /// - /// [`Row`]: arrow::row::Row - group_values: Rows, - - // buffer to be reused to store hashes - hashes_buffer: Vec, - - /// Random state for creating hashes - random_state: RandomState, -} - -impl GroupValuesRows { - fn try_new(schema: SchemaRef) -> Result { - let row_converter = RowConverter::new( - schema - .fields() - .iter() - .map(|f| SortField::new(f.data_type().clone())) - .collect(), - )?; - - let map = RawTable::with_capacity(0); - let group_values = row_converter.empty_rows(0, 0); - - Ok(Self { - row_converter, - map, - map_size: 0, - group_values, - hashes_buffer: Default::default(), - random_state: Default::default(), - }) - } -} - -impl GroupValues for GroupValuesRows { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { - // Convert the group keys into the row format - // Avoid reallocation when https://github.com/apache/arrow-rs/issues/4479 is available - let group_rows = self.row_converter.convert_columns(cols)?; - let n_rows = group_rows.num_rows(); - - // tracks to which group each of the input rows belongs - groups.clear(); - - // 1.1 Calculate the group keys for the group values - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(n_rows, 0); - create_hashes(cols, &self.random_state, batch_hashes)?; - - for (row, &hash) in batch_hashes.iter().enumerate() { - let entry = self.map.get_mut(hash, |(_hash, group_idx)| { - // verify that a group that we are inserting with hash is - // actually the same key value as the group in - // existing_idx (aka group_values @ row) - group_rows.row(row) == self.group_values.row(*group_idx) - }); - - let group_idx = match entry { - // Existing group_index for this group value - Some((_hash, group_idx)) => *group_idx, - // 1.2 Need to create new entry for the group - None => { - // Add new entry to aggr_state and save newly created index - let group_idx = self.group_values.num_rows(); - self.group_values.push(group_rows.row(row)); - - // for hasher function, use precomputed hash value - self.map.insert_accounted( - (hash, group_idx), - |(hash, _group_index)| *hash, - &mut self.map_size, - ); - group_idx - } - }; - groups.push(group_idx); - } - - Ok(()) - } - - fn size(&self) -> usize { - self.row_converter.size() - + self.group_values.size() - + self.map_size - + self.hashes_buffer.allocated_size() - } - - fn is_empty(&self) -> bool { - self.len() == 0 - } - - fn len(&self) -> usize { - self.group_values.num_rows() - } - - fn emit(&mut self, emit_to: EmitTo) -> Result> { - Ok(match emit_to { - EmitTo::All => { - // Eventually we may also want to clear the hash table here - self.row_converter.convert_rows(&self.group_values)? - } - EmitTo::First(n) => { - let groups_rows = self.group_values.iter().take(n); - let output = self.row_converter.convert_rows(groups_rows)?; - // Clear out first n group keys by copying them to a new Rows. - // TODO file some ticket in arrow-rs to make this more efficent? 
- let mut new_group_values = self.row_converter.empty_rows(0, 0); - for row in self.group_values.iter().skip(n) { - new_group_values.push(row); - } - std::mem::swap(&mut new_group_values, &mut self.group_values); - - // SAFETY: self.map outlives iterator and is not modified concurrently - unsafe { - for bucket in self.map.iter() { - // Decrement group index by n - match bucket.as_ref().1.checked_sub(n) { - // Group index was >= n, shift value down - Some(sub) => bucket.as_mut().1 = sub, - // Group index was < n, so remove from table - None => self.map.erase(bucket), - } - } - } - output - } - }) - } -} - /// Hash based Grouping Aggregator /// /// # Design Goals @@ -416,8 +238,7 @@ impl GroupedHashAggregateStream { .transpose()? .unwrap_or(GroupOrdering::None); - let group = Box::new(GroupValuesRows::try_new(group_schema)?); - + let group_values = new_group_values(group_schema)?; timer.done(); let exec_state = ExecutionState::ReadingInput; @@ -431,7 +252,7 @@ impl GroupedHashAggregateStream { filter_expressions, group_by: agg_group_by, reservation, - group_values: group, + group_values, current_group_indices: Default::default(), exec_state, baseline_metrics,
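For readers new to the interning pattern this patch relies on, the following is a minimal, self-contained sketch of the same idea: map each distinct group key to a dense group index, then emit the first n groups and shift the surviving indices down by n, as in the `EmitTo::First(n)` handling above. It is an illustration only, not code from the patch: it uses std's `HashMap` instead of hashbrown's raw-table API, handles a single `i64` column, and ignores nulls and memory accounting; the `SimpleGroups` type and its methods are hypothetical names invented for this example.

use std::collections::HashMap;

/// Simplified interning store: maps each distinct i64 key to a dense group index.
struct SimpleGroups {
    map: HashMap<i64, usize>, // key -> group index
    values: Vec<i64>,         // group index -> key
}

impl SimpleGroups {
    fn new() -> Self {
        Self {
            map: HashMap::new(),
            values: Vec::new(),
        }
    }

    /// Assign a dense group index to every input key (the analogue of `intern`).
    fn intern(&mut self, keys: &[i64], groups: &mut Vec<usize>) {
        groups.clear();
        for &key in keys {
            let next = self.values.len();
            let idx = *self.map.entry(key).or_insert(next);
            if idx == next {
                // First time this key was seen: remember its value.
                self.values.push(key);
            }
            groups.push(idx);
        }
    }

    /// Emit the first `n` groups and shift the surviving indices down by `n`,
    /// mirroring the `EmitTo::First(n)` handling in the patch.
    fn emit_first(&mut self, n: usize) -> Vec<i64> {
        let emitted: Vec<i64> = self.values.drain(..n).collect();
        // Drop emitted keys from the map and re-point surviving keys at their new slots.
        self.map.retain(|_, idx| match idx.checked_sub(n) {
            Some(shifted) => {
                *idx = shifted;
                true
            }
            None => false,
        });
        emitted
    }
}

fn main() {
    let mut store = SimpleGroups::new();
    let mut groups = Vec::new();

    store.intern(&[3, 7, 3, 9, 7], &mut groups);
    assert_eq!(groups, vec![0, 1, 0, 2, 1]); // 3 -> 0, 7 -> 1, 9 -> 2

    let emitted = store.emit_first(2); // emit the groups for keys 3 and 7
    assert_eq!(emitted, vec![3, 7]);

    store.intern(&[9, 5], &mut groups);
    assert_eq!(groups, vec![0, 1]); // 9 kept its shifted slot, 5 is a new group
}

The real implementations differ mainly in storage: `GroupValuesPrimitive` keeps only group indices in the hash table and compares raw native values, while `GroupValuesRows` stores `(hash, group_index)` pairs and compares arrow `Row`s.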