Skip to content

Commit

Permalink
reuse hash
Browse files (browse the repository at this point in the history)
  • Loading branch information
XiangpengHao committed Jul 29, 2024
1 parent efdd2e4 commit edb5a1f
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 22 deletions.
66 changes: 54 additions & 12 deletions datafusion/physical-expr-common/src/binary_view_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::array::cast::AsArray;
use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder};
use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_common::utils::proxy::RawTableAllocExt;
use std::fmt::Debug;
use std::sync::Arc;

Expand Down Expand Up @@ -207,6 +207,7 @@ where
values,
make_payload_fn,
observe_payload_fn,
None,
)
}
OutputType::Utf8View => {
Expand All @@ -215,6 +216,43 @@ where
values,
make_payload_fn,
observe_payload_fn,
None,
)
}
_ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
};
}

/// Similar to [`Self::insert_if_new`] but allows the caller to provide the
/// hash values for the values in `values` instead of computing them
pub fn insert_if_new_with_hash<MP, OP>(
&mut self,
values: &ArrayRef,
make_payload_fn: MP,
observe_payload_fn: OP,
provided_hash: &Vec<u64>,
) where
MP: FnMut(Option<&[u8]>) -> V,
OP: FnMut(V),
{
// Sanity check array type
match self.output_type {
OutputType::BinaryView => {
assert!(matches!(values.data_type(), DataType::BinaryView));
self.insert_if_new_inner::<MP, OP, BinaryViewType>(
values,
make_payload_fn,
observe_payload_fn,
Some(provided_hash),
)
}
OutputType::Utf8View => {
assert!(matches!(values.data_type(), DataType::Utf8View));
self.insert_if_new_inner::<MP, OP, StringViewType>(
values,
make_payload_fn,
observe_payload_fn,
Some(provided_hash),
)
}
_ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
Expand All @@ -234,19 +272,26 @@ where
values: &ArrayRef,
mut make_payload_fn: MP,
mut observe_payload_fn: OP,
provided_hash: Option<&Vec<u64>>,
) where
MP: FnMut(Option<&[u8]>) -> V,
OP: FnMut(V),
B: ByteViewType,
{
// step 1: compute hashes
let batch_hashes = &mut self.hashes_buffer;
batch_hashes.clear();
batch_hashes.resize(values.len(), 0);
create_hashes(&[values.clone()], &self.random_state, batch_hashes)
// hash is supported for all types and create_hashes only
// returns errors for unsupported types
.unwrap();
let batch_hashes = match provided_hash {
Some(h) => h,
None => {
let batch_hashes = &mut self.hashes_buffer;
batch_hashes.clear();
batch_hashes.resize(values.len(), 0);
create_hashes(&[values.clone()], &self.random_state, batch_hashes)
// hash is supported for all types and create_hashes only
// returns errors for unsupported types
.unwrap();
batch_hashes
}
};

// step 2: insert each value into the set, if not already present
let values = values.as_byte_view::<B>();
Expand Down Expand Up @@ -353,9 +398,7 @@ where
/// Return the total size, in bytes, of memory used to store the data in
/// this set, not including `self`
pub fn size(&self) -> usize {
self.map_size
+ self.builder.allocated_size()
+ self.hashes_buffer.allocated_size()
self.map_size + self.builder.allocated_size()
}
}

Expand All @@ -369,7 +412,6 @@ where
.field("map_size", &self.map_size)
.field("view_builder", &self.builder)
.field("random_state", &self.random_state)
.field("hashes_buffer", &self.hashes_buffer)
.finish()
}
}
Expand Down
51 changes: 41 additions & 10 deletions datafusion/physical-plan/src/aggregates/group_values/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, Rows, SortField};
use arrow_array::{Array, ArrayRef, StringViewArray};
use arrow_schema::{DataType, SchemaRef};
// use datafusion_common::hash_utils::create_hashes;
use datafusion_common::hash_utils::{combine_hashes, create_hashes};
use datafusion_common::{DataFusionError, Result};
use datafusion_execution::memory_pool::proxy::RawTableAllocExt;
use datafusion_expr::EmitTo;
Expand Down Expand Up @@ -81,6 +81,11 @@ pub struct GroupValuesRows {
/// [`Row`]: arrow::row::Row
group_values: Option<Rows>,

/// reused buffer to store hashes
final_hash_buffer: Vec<u64>,

tmp_hash_buffer: Vec<u64>,

/// reused buffer to store rows
rows_buffer: Rows,

Expand Down Expand Up @@ -121,21 +126,34 @@ impl GroupValuesRows {
map,
map_size: 0,
group_values: None,
final_hash_buffer: Default::default(),
tmp_hash_buffer: Default::default(),
rows_buffer,
var_len_map,
random_state: Default::default(),
})
}

fn transform_col_to_fixed_len(&mut self, input: &[ArrayRef]) -> Vec<ArrayRef> {
let n_rows = input[0].len();
// 1.1 Calculate the group keys for the group values
let final_hash_buffer = &mut self.final_hash_buffer;
final_hash_buffer.clear();
final_hash_buffer.resize(n_rows, 0);
let tmp_hash_buffer = &mut self.tmp_hash_buffer;
tmp_hash_buffer.clear();
tmp_hash_buffer.resize(n_rows, 0);

let mut cur_var_len_idx = 0;
let transformed_cols: Vec<ArrayRef> = input
.iter()
.map(|c| {
if let DataType::Utf8View = c.data_type() {
create_hashes(&[Arc::clone(c)], &self.random_state, tmp_hash_buffer)
.unwrap();
let mut var_groups = Vec::with_capacity(c.len());
let group_values = &mut self.var_len_map[cur_var_len_idx];
group_values.map.insert_if_new(
group_values.map.insert_if_new_with_hash(
c,
|_value| {
let group_idx = group_values.num_groups;
Expand All @@ -145,12 +163,27 @@ impl GroupValuesRows {
|group_idx| {
var_groups.push(group_idx);
},
tmp_hash_buffer,
);
cur_var_len_idx += 1;
final_hash_buffer
.iter_mut()
.zip(tmp_hash_buffer.iter())
.for_each(|(result, tmp)| {
*result = combine_hashes(*result, *tmp);
});
std::sync::Arc::new(arrow_array::UInt32Array::from(var_groups))
as ArrayRef
} else {
c.clone()
create_hashes(&[Arc::clone(c)], &self.random_state, tmp_hash_buffer)
.unwrap();
final_hash_buffer
.iter_mut()
.zip(tmp_hash_buffer.iter())
.for_each(|(result, tmp)| {
*result = combine_hashes(*result, *tmp);
});
Arc::clone(c)
}
})
.collect();
Expand Down Expand Up @@ -186,7 +219,7 @@ impl GroupValuesRows {
StringViewArray::new_unchecked(
views.into(),
map_content.data_buffers().to_vec(),
map_content.nulls().map(|v| v.clone()),
map_content.nulls().cloned(),
)
};
cur_var_len_idx += 1;
Expand Down Expand Up @@ -216,14 +249,12 @@ impl GroupValues for GroupValuesRows {
// tracks to which group each of the input rows belongs
groups.clear();

for row in group_rows.iter() {
let hash = self.random_state.hash_one(row.as_ref());
let entry = self.map.get_mut(hash, |(_hash, group_idx)| {
for (row, hash) in group_rows.iter().zip(self.final_hash_buffer.iter()) {
let entry = self.map.get_mut(*hash, |(_hash, group_idx)| {
// verify that a group that we are inserting with hash is
// actually the same key value as the group in
// existing_idx (aka group_values @ row)
row.as_ref().len() == group_values.row(*group_idx).as_ref().len()
&& row == group_values.row(*group_idx)
row == group_values.row(*group_idx)
});
let group_idx = match entry {
// Existing group_index for this group value
Expand All @@ -236,7 +267,7 @@ impl GroupValues for GroupValuesRows {

// for hasher function, use precomputed hash value
self.map.insert_accounted(
(hash, group_idx),
(*hash, group_idx),
|(hash, _group_index)| *hash,
&mut self.map_size,
);
Expand Down

0 comments on commit edb5a1f

Please sign in to comment.