From 3df75ac8eb1f51c17a8efbde0713b34382997ea8 Mon Sep 17 00:00:00 2001 From: kamille Date: Sat, 19 Oct 2024 18:36:44 +0800 Subject: [PATCH] impl basic append_batch --- .../src/aggregates/group_values/column.rs | 16 +- .../aggregates/group_values/group_column.rs | 209 ++++++++++-------- 2 files changed, 119 insertions(+), 106 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/column.rs b/datafusion/physical-plan/src/aggregates/group_values/column.rs index 8e90f883668d..26e7c22dcbbb 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/column.rs @@ -305,17 +305,13 @@ impl GroupValues for GroupValuesColumn { // 1.4 Vectorized append values for col_idx in 0..cols.len() { - let col_nullable = self.column_nullables_buffer[col_idx]; + let all_non_null = !self.column_nullables_buffer[col_idx]; let group_value = &mut self.group_values[col_idx]; - if col_nullable { - for &row in self.append_rows_buffer.iter() { - group_value.append_val(&cols[col_idx], row); - } - } else { - for &row in self.append_rows_buffer.iter() { - group_value.append_non_nullable_val(&cols[col_idx], row); - } - } + group_value.append_batch( + &cols[col_idx], + &self.append_rows_buffer, + all_non_null, + ); } Ok(()) diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs index 4a9b42d6f45f..a6961835edef 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs @@ -22,6 +22,7 @@ use arrow::array::GenericBinaryArray; use arrow::array::GenericStringArray; use arrow::array::OffsetSizeTrait; use arrow::array::PrimitiveArray; +use arrow::array::StringViewBuilder; use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray}; use arrow::buffer::OffsetBuffer; use arrow::buffer::ScalarBuffer; @@ -29,9 +30,11 @@ use arrow::datatypes::ByteArrayType; use arrow::datatypes::ByteViewType; use arrow::datatypes::DataType; use arrow::datatypes::GenericBinaryType; +use arrow_array::GenericByteArray; use arrow_array::GenericByteViewArray; use arrow_buffer::Buffer; use datafusion_common::utils::proxy::VecAllocExt; +use datafusion_expr::sqlparser::keywords::NULLABLE; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow_array::types::GenericStringType; @@ -59,9 +62,7 @@ pub trait GroupColumn: Send + Sync { /// Appends the row at `row` in `array` to this builder fn append_val(&mut self, array: &ArrayRef, row: usize); - fn append_non_nullable_val(&mut self, array: &ArrayRef, row: usize); - - fn append_batch(&mut self, array: &ArrayRef, rows: &[usize]); + fn append_batch(&mut self, array: &ArrayRef, rows: &[usize], all_non_null: bool); /// Returns the number of rows stored in this builder fn len(&self) -> usize; @@ -86,8 +87,6 @@ pub trait GroupColumn: Send + Sync { pub struct PrimitiveGroupValueBuilder { group_values: Vec, nulls: MaybeNullBufferBuilder, - nullable_call: usize, - non_nullable_call: usize, } impl PrimitiveGroupValueBuilder @@ -99,8 +98,6 @@ where Self { group_values: vec![], nulls: MaybeNullBufferBuilder::new(), - nullable_call: 0, - non_nullable_call: 0, } } } @@ -122,9 +119,35 @@ impl GroupColumn self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) } - fn append_batch(&mut self, array: &ArrayRef, rows: &[usize]) { - todo!() - } + fn append_batch(&mut self, array: &ArrayRef, rows: &[usize], all_non_null: bool) { + let arr = array.as_primitive::(); + match (NULLABLE, all_non_null) { + (true, true) => { + self.nulls.append_n(rows.len(), false); + self.group_values.reserve(rows.len()); + for &row in rows { + self.group_values.push(arr.value(row)); + } + } + (true, false) => { + for &row in rows { + if array.is_null(row) { + self.nulls.append(true); + self.group_values.push(T::default_value()); + } else { + self.nulls.append(false); + self.group_values.push(arr.value(row)); + } + } + } + (false, _) => { + self.group_values.reserve(rows.len()); + for &row in rows { + self.group_values.push(arr.value(row)); + } + } + } + } fn append_val(&mut self, array: &ArrayRef, row: usize) { // Perf: skip null check if input can't have nulls @@ -141,15 +164,6 @@ impl GroupColumn } } - fn append_non_nullable_val(&mut self, array: &ArrayRef, row: usize) { - if NULLABLE { - self.nulls.append(false); - self.group_values.push(array.as_primitive::().value(row)); - } else { - self.group_values.push(array.as_primitive::().value(row)); - } - } - fn len(&self) -> usize { self.group_values.len() } @@ -162,14 +176,8 @@ impl GroupColumn let Self { group_values, nulls, - nullable_call, - non_nullable_call, } = *self; - println!( - "### nullable_call:{nullable_call}, non_nullable_call:{non_nullable_call}" - ); - let nulls = nulls.build(); if !NULLABLE { assert!(nulls.is_none(), "unexpected nulls in non nullable input"); @@ -213,10 +221,6 @@ where offsets: Vec, /// Nulls nulls: MaybeNullBufferBuilder, - - nullable_call: usize, - - non_nullable_call: usize, } impl ByteGroupValueBuilder @@ -229,8 +233,36 @@ where buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), offsets: vec![O::default()], nulls: MaybeNullBufferBuilder::new(), - nullable_call: 0, - non_nullable_call: 0, + } + } + + fn append_batch_inner( + &mut self, + array: &ArrayRef, + rows: &[usize], + all_non_null: bool, + ) where + B: ByteArrayType, + { + let arr = array.as_bytes::(); + + if all_non_null { + self.nulls.append_n(rows.len(), false); + for &row in rows { + self.append_value(arr, row); + } + } else { + for &row in rows { + if arr.is_null(row) { + self.nulls.append(true); + // nulls need a zero length in the offset buffer + let offset = self.buffer.len(); + self.offsets.push(O::usize_as(offset)); + } else { + self.nulls.append(false); + self.append_value(arr, row); + } + } } } @@ -238,7 +270,6 @@ where where B: ByteArrayType, { - self.nullable_call += 1; let arr = array.as_bytes::(); if arr.is_null(row) { self.nulls.append(true); @@ -247,20 +278,15 @@ where self.offsets.push(O::usize_as(offset)); } else { self.nulls.append(false); - let value: &[u8] = arr.value(row).as_ref(); - self.buffer.append_slice(value); - self.offsets.push(O::usize_as(self.buffer.len())); + self.append_value(arr, row); } } - fn append_non_nullable_val_inner(&mut self, array: &ArrayRef, row: usize) + fn append_value(&mut self, array: &GenericByteArray, row: usize) where B: ByteArrayType, { - self.non_nullable_call += 1; - let arr = array.as_bytes::(); - self.nulls.append(false); - let value: &[u8] = arr.value(row).as_ref(); + let value: &[u8] = array.value(row).as_ref(); self.buffer.append_slice(value); self.offsets.push(O::usize_as(self.buffer.len())); } @@ -313,28 +339,35 @@ where } } - fn append_val(&mut self, column: &ArrayRef, row: usize) { - // Sanity array type + fn append_batch(&mut self, column: &ArrayRef, rows: &[usize], all_non_null: bool) { match self.output_type { OutputType::Binary => { debug_assert!(matches!( column.data_type(), DataType::Binary | DataType::LargeBinary )); - self.append_val_inner::>(column, row) + self.append_batch_inner::>( + column, + rows, + all_non_null, + ) } OutputType::Utf8 => { debug_assert!(matches!( column.data_type(), DataType::Utf8 | DataType::LargeUtf8 )); - self.append_val_inner::>(column, row) + self.append_batch_inner::>( + column, + rows, + all_non_null, + ) } _ => unreachable!("View types should use `ArrowBytesViewMap`"), }; } - fn append_non_nullable_val(&mut self, column: &ArrayRef, row: usize) { + fn append_val(&mut self, column: &ArrayRef, row: usize) { // Sanity array type match self.output_type { OutputType::Binary => { @@ -342,14 +375,14 @@ where column.data_type(), DataType::Binary | DataType::LargeBinary )); - self.append_non_nullable_val_inner::>(column, row) + self.append_val_inner::>(column, row) } OutputType::Utf8 => { debug_assert!(matches!( column.data_type(), DataType::Utf8 | DataType::LargeUtf8 )); - self.append_non_nullable_val_inner::>(column, row) + self.append_val_inner::>(column, row) } _ => unreachable!("View types should use `ArrowBytesViewMap`"), }; @@ -371,14 +404,8 @@ where mut buffer, offsets, nulls, - nullable_call, - non_nullable_call, } = *self; - println!( - "### nullable_call:{nullable_call}, non_nullable_call:{non_nullable_call}" - ); - let null_buffer = nulls.build(); // SAFETY: the offsets were constructed correctly in `insert_if_new` -- @@ -498,10 +525,6 @@ pub struct ByteViewGroupValueBuilder { /// Nulls nulls: MaybeNullBufferBuilder, - nullable_call: usize, - - non_nullable_call: usize, - /// phantom data so the type requires `` _phantom: PhantomData, } @@ -515,8 +538,6 @@ impl ByteViewGroupValueBuilder { max_block_size: BYTE_VIEW_MAX_BLOCK_SIZE, nulls: MaybeNullBufferBuilder::new(), _phantom: PhantomData {}, - nullable_call: 0, - non_nullable_call: 0, } } @@ -526,11 +547,34 @@ impl ByteViewGroupValueBuilder { self } - fn append_val_inner(&mut self, array: &ArrayRef, row: usize) - where - B: ByteViewType, - { - self.nullable_call += 1; + fn append_batch_inner( + &mut self, + array: &ArrayRef, + rows: &[usize], + all_non_null: bool, + ) { + let arr = array.as_byte_view::(); + + if all_non_null { + self.nulls.append_n(rows.len(), false); + for &row in rows { + self.append_value(arr, row); + } + } else { + for &row in rows { + // Null row case, set and return + if arr.is_valid(row) { + self.nulls.append(false); + self.append_value(arr, row); + } else { + self.nulls.append(true); + self.views.push(0); + } + } + } + } + + fn append_val_inner(&mut self, array: &ArrayRef, row: usize) { let arr = array.as_byte_view::(); // Null row case, set and return @@ -542,37 +586,14 @@ impl ByteViewGroupValueBuilder { // Not null row case self.nulls.append(false); - let value: &[u8] = arr.value(row).as_ref(); - - let value_len = value.len(); - let view = if value_len <= 12 { - make_view(value, 0, 0) - } else { - // Ensure big enough block to hold the value firstly - self.ensure_in_progress_big_enough(value_len); - - // Append value - let buffer_index = self.completed.len(); - let offset = self.in_progress.len(); - self.in_progress.extend_from_slice(value); - - make_view(value, buffer_index as u32, offset as u32) - }; - - // Append view - self.views.push(view); + self.append_value(arr, row); } - fn append_val_non_nullable_inner(&mut self, array: &ArrayRef, row: usize) + fn append_value(&mut self, array: &GenericByteViewArray, row: usize) where B: ByteViewType, { - self.non_nullable_call += 1; - let arr = array.as_byte_view::(); - - // Not null row case - self.nulls.append(false); - let value: &[u8] = arr.value(row).as_ref(); + let value: &[u8] = array.value(row).as_ref(); let value_len = value.len(); let view = if value_len <= 12 { @@ -703,10 +724,6 @@ impl ByteViewGroupValueBuilder { let views = ScalarBuffer::from(views); - println!( - "### nullable_call:{}, non_nullable_call:{}", - self.nullable_call, self.non_nullable_call - ); // Safety: // * all views were correctly made // * (if utf8): Input was valid Utf8 so buffer contents are @@ -892,8 +909,8 @@ impl GroupColumn for ByteViewGroupValueBuilder { self.append_val_inner(array, row) } - fn append_non_nullable_val(&mut self, array: &ArrayRef, row: usize) { - self.append_val_non_nullable_inner(array, row); + fn append_batch(&mut self, array: &ArrayRef, rows: &[usize], all_non_null: bool) { + self.append_batch_inner(array, rows, all_non_null); } fn len(&self) -> usize {