diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs
index 18bc6801aa60..66ab64220827 100644
--- a/datafusion/physical-expr-common/src/binary_view_map.rs
+++ b/datafusion/physical-expr-common/src/binary_view_map.rs
@@ -24,7 +24,7 @@ use arrow::array::cast::AsArray;
 use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder};
 use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType};
 use datafusion_common::hash_utils::create_hashes;
-use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
+use datafusion_common::utils::proxy::RawTableAllocExt;
 use std::fmt::Debug;
 use std::sync::Arc;
 
@@ -207,6 +207,7 @@ where
                     values,
                     make_payload_fn,
                     observe_payload_fn,
+                    None,
                 )
             }
             OutputType::Utf8View => {
@@ -215,6 +216,43 @@ where
                     values,
                     make_payload_fn,
                     observe_payload_fn,
+                    None,
+                )
+            }
+            _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
+        };
+    }
+
+    /// Similar to [`Self::insert_if_new`] but allows the caller to provide the
+    /// hash values for the values in `values` instead of computing them
+    pub fn insert_if_new_with_hash<MP, OP>(
+        &mut self,
+        values: &ArrayRef,
+        make_payload_fn: MP,
+        observe_payload_fn: OP,
+        provided_hash: &Vec<u64>,
+    ) where
+        MP: FnMut(Option<&[u8]>) -> V,
+        OP: FnMut(V),
+    {
+        // Sanity check array type
+        match self.output_type {
+            OutputType::BinaryView => {
+                assert!(matches!(values.data_type(), DataType::BinaryView));
+                self.insert_if_new_inner::<MP, OP, BinaryViewType>(
+                    values,
+                    make_payload_fn,
+                    observe_payload_fn,
+                    Some(provided_hash),
+                )
+            }
+            OutputType::Utf8View => {
+                assert!(matches!(values.data_type(), DataType::Utf8View));
+                self.insert_if_new_inner::<MP, OP, StringViewType>(
+                    values,
+                    make_payload_fn,
+                    observe_payload_fn,
+                    Some(provided_hash),
                 )
             }
             _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
         };
@@ -234,19 +272,26 @@ where
         values: &ArrayRef,
         mut make_payload_fn: MP,
         mut observe_payload_fn: OP,
+        provided_hash: Option<&Vec<u64>>,
     ) where
         MP: FnMut(Option<&[u8]>) -> V,
         OP: FnMut(V),
         B: ByteViewType,
     {
         // step 1: compute hashes
-        let batch_hashes = &mut self.hashes_buffer;
-        batch_hashes.clear();
-        batch_hashes.resize(values.len(), 0);
-        create_hashes(&[values.clone()], &self.random_state, batch_hashes)
-            // hash is supported for all types and create_hashes only
-            // returns errors for unsupported types
-            .unwrap();
+        let batch_hashes = match provided_hash {
+            Some(h) => h,
+            None => {
+                let batch_hashes = &mut self.hashes_buffer;
+                batch_hashes.clear();
+                batch_hashes.resize(values.len(), 0);
+                create_hashes(&[values.clone()], &self.random_state, batch_hashes)
+                    // hash is supported for all types and create_hashes only
+                    // returns errors for unsupported types
+                    .unwrap();
+                batch_hashes
+            }
+        };
 
         // step 2: insert each value into the set, if not already present
         let values = values.as_byte_view::<B>();
@@ -353,9 +398,7 @@ where
     /// Return the total size, in bytes, of memory used to store the data in
     /// this set, not including `self`
     pub fn size(&self) -> usize {
-        self.map_size
-            + self.builder.allocated_size()
-            + self.hashes_buffer.allocated_size()
+        self.map_size + self.builder.allocated_size()
     }
 }
 
@@ -369,7 +412,6 @@ where
             .field("map_size", &self.map_size)
             .field("view_builder", &self.builder)
             .field("random_state", &self.random_state)
-            .field("hashes_buffer", &self.hashes_buffer)
             .finish()
     }
 }
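// ---------------------------------------------------------------------------
// Example (not part of the patch): a minimal sketch of how a caller can drive
// the new `insert_if_new_with_hash` API added above. The crate paths, the
// payload type `u32`, and the helper name `intern_with_precomputed_hashes`
// are illustrative assumptions, not code from this PR. The point is that the
// column is hashed once with `create_hashes` and the same buffer is handed to
// the view map, so the map does not rehash the values itself. A caller should
// use a single `RandomState` consistently for all hashes fed to one map.
// ---------------------------------------------------------------------------
use std::sync::Arc;

use ahash::RandomState;
use arrow_array::ArrayRef;
use datafusion_common::hash_utils::create_hashes;
use datafusion_physical_expr_common::binary_map::OutputType;
use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap;

/// Interns a `Utf8View` column and returns one group id per row.
fn intern_with_precomputed_hashes(values: &ArrayRef, random_state: &RandomState) -> Vec<u32> {
    // Hash the column once; this buffer could be reused across batches.
    let mut hashes = vec![0u64; values.len()];
    create_hashes(&[Arc::clone(values)], random_state, &mut hashes).unwrap();

    // Reuse the precomputed hashes while interning the values.
    let mut map = ArrowBytesViewMap::<u32>::new(OutputType::Utf8View);
    let mut next_group = 0;
    let mut group_ids = Vec::with_capacity(values.len());
    map.insert_if_new_with_hash(
        values,
        // called once per distinct value: assign a fresh group id
        |_value| {
            let id = next_group;
            next_group += 1;
            id
        },
        // called once per row with that row's group id
        |id| group_ids.push(id),
        &hashes,
    );
    group_ids
}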
diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index 27fa9b092c89..a0f75716d162 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -26,7 +26,7 @@ use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, Rows, SortField};
 use arrow_array::{Array, ArrayRef, StringViewArray};
 use arrow_schema::{DataType, SchemaRef};
-// use datafusion_common::hash_utils::create_hashes;
+use datafusion_common::hash_utils::{combine_hashes, create_hashes};
 use datafusion_common::{DataFusionError, Result};
 use datafusion_execution::memory_pool::proxy::RawTableAllocExt;
 use datafusion_expr::EmitTo;
@@ -81,6 +81,11 @@ pub struct GroupValuesRows {
     /// [`Row`]: arrow::row::Row
     group_values: Option<Rows>,
 
+    /// reused buffer to store hashes
+    final_hash_buffer: Vec<u64>,
+
+    tmp_hash_buffer: Vec<u64>,
+
     /// reused buffer to store rows
     rows_buffer: Rows,
 
@@ -121,6 +126,8 @@ impl GroupValuesRows {
             map,
             map_size: 0,
             group_values: None,
+            final_hash_buffer: Default::default(),
+            tmp_hash_buffer: Default::default(),
             rows_buffer,
             var_len_map,
             random_state: Default::default(),
@@ -128,14 +135,25 @@ impl GroupValuesRows {
     }
 
     fn transform_col_to_fixed_len(&mut self, input: &[ArrayRef]) -> Vec<ArrayRef> {
+        let n_rows = input[0].len();
+        // 1.1 Calculate the group keys for the group values
+        let final_hash_buffer = &mut self.final_hash_buffer;
+        final_hash_buffer.clear();
+        final_hash_buffer.resize(n_rows, 0);
+        let tmp_hash_buffer = &mut self.tmp_hash_buffer;
+        tmp_hash_buffer.clear();
+        tmp_hash_buffer.resize(n_rows, 0);
+
         let mut cur_var_len_idx = 0;
         let transformed_cols: Vec<ArrayRef> = input
             .iter()
             .map(|c| {
                 if let DataType::Utf8View = c.data_type() {
+                    create_hashes(&[Arc::clone(c)], &self.random_state, tmp_hash_buffer)
+                        .unwrap();
                     let mut var_groups = Vec::with_capacity(c.len());
                     let group_values = &mut self.var_len_map[cur_var_len_idx];
-                    group_values.map.insert_if_new(
+                    group_values.map.insert_if_new_with_hash(
                         c,
                         |_value| {
                             let group_idx = group_values.num_groups;
@@ -145,12 +163,27 @@ impl GroupValuesRows {
                         |group_idx| {
                             var_groups.push(group_idx);
                         },
+                        tmp_hash_buffer,
                     );
                     cur_var_len_idx += 1;
 
+                    final_hash_buffer
+                        .iter_mut()
+                        .zip(tmp_hash_buffer.iter())
+                        .for_each(|(result, tmp)| {
+                            *result = combine_hashes(*result, *tmp);
+                        });
                    std::sync::Arc::new(arrow_array::UInt32Array::from(var_groups)) as ArrayRef
                 } else {
-                    c.clone()
+                    create_hashes(&[Arc::clone(c)], &self.random_state, tmp_hash_buffer)
+                        .unwrap();
+                    final_hash_buffer
+                        .iter_mut()
+                        .zip(tmp_hash_buffer.iter())
+                        .for_each(|(result, tmp)| {
+                            *result = combine_hashes(*result, *tmp);
+                        });
+                    Arc::clone(c)
                 }
             })
             .collect();
@@ -186,7 +219,7 @@ impl GroupValuesRows {
                     StringViewArray::new_unchecked(
                         views.into(),
                         map_content.data_buffers().to_vec(),
-                        map_content.nulls().map(|v| v.clone()),
+                        map_content.nulls().cloned(),
                     )
                 };
                 cur_var_len_idx += 1;
@@ -216,14 +249,12 @@ impl GroupValues for GroupValuesRows {
         // tracks to which group each of the input rows belongs
         groups.clear();
 
-        for row in group_rows.iter() {
-            let hash = self.random_state.hash_one(row.as_ref());
-            let entry = self.map.get_mut(hash, |(_hash, group_idx)| {
+        for (row, hash) in group_rows.iter().zip(self.final_hash_buffer.iter()) {
+            let entry = self.map.get_mut(*hash, |(_hash, group_idx)| {
                 // verify that a group that we are inserting with hash is
                 // actually the same key value as the group in
                 // existing_idx (aka group_values @ row)
-                row.as_ref().len() == group_values.row(*group_idx).as_ref().len()
-                    && row == group_values.row(*group_idx)
+                row == group_values.row(*group_idx)
             });
             let group_idx = match entry {
                // Existing group_index for this group value
@@ -236,7 +267,7 @@ impl GroupValues for GroupValuesRows {
                     // for hasher function, use precomputed hash value
                     self.map.insert_accounted(
-                        (hash, group_idx),
+                        (*hash, group_idx),
                         |(hash, _group_index)| *hash,
                         &mut self.map_size,
                     );
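// ---------------------------------------------------------------------------
// Example (not part of the patch): a minimal sketch of the hashing scheme the
// `row.rs` changes rely on. Each column is hashed once into a scratch buffer
// with `create_hashes` and folded into a per-row hash with `combine_hashes`,
// starting from a zeroed buffer, exactly as `transform_col_to_fixed_len` does
// above. The per-column hashes can also be handed to `insert_if_new_with_hash`,
// and the folded per-row hash is what the group-index `RawTable` is probed
// with, instead of rehashing the row bytes. The helper name `per_row_hashes`
// is illustrative, not part of this PR.
// ---------------------------------------------------------------------------
use std::sync::Arc;

use ahash::RandomState;
use arrow_array::{ArrayRef, Int64Array, StringViewArray};
use datafusion_common::hash_utils::{combine_hashes, create_hashes};

fn per_row_hashes(columns: &[ArrayRef], random_state: &RandomState) -> Vec<u64> {
    let n_rows = columns.first().map(|c| c.len()).unwrap_or(0);
    let mut final_hashes = vec![0u64; n_rows];
    let mut tmp_hashes = vec![0u64; n_rows];

    for column in columns {
        // Hash this column once into the scratch buffer (a single-column call
        // overwrites the buffer, so it does not need to be cleared) ...
        create_hashes(&[Arc::clone(column)], random_state, &mut tmp_hashes).unwrap();
        // ... then fold the column hash into the running per-row hash.
        for (acc, h) in final_hashes.iter_mut().zip(tmp_hashes.iter()) {
            *acc = combine_hashes(*acc, *h);
        }
    }
    final_hashes
}

fn main() {
    let columns: Vec<ArrayRef> = vec![
        Arc::new(StringViewArray::from_iter_values(["a", "bb", "a"])),
        Arc::new(Int64Array::from(vec![1, 2, 1])),
    ];
    let hashes = per_row_hashes(&columns, &RandomState::new());
    // rows 0 and 2 carry identical key values, so they get identical hashes
    assert_eq!(hashes[0], hashes[2]);
}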