From bebd74375a0201af432abe323252acb95a965092 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 29 Oct 2024 09:18:43 -0700 Subject: [PATCH] chore: Add safety check to CometBuffer --- .../comet/parquet/TestColumnReader.java | 5 --- native/core/benches/parquet_read.rs | 2 +- native/core/src/common/buffer.rs | 45 ++++++++++++++++--- native/core/src/execution/operators/copy.rs | 5 ++- native/core/src/parquet/mod.rs | 4 +- native/core/src/parquet/mutable_vector.rs | 14 +++--- native/core/src/parquet/read/column.rs | 5 ++- .../apache/comet/exec/CometExecSuite.scala | 4 +- 8 files changed, 60 insertions(+), 24 deletions(-) diff --git a/common/src/test/java/org/apache/comet/parquet/TestColumnReader.java b/common/src/test/java/org/apache/comet/parquet/TestColumnReader.java index d4e748a9b6..6118025c69 100644 --- a/common/src/test/java/org/apache/comet/parquet/TestColumnReader.java +++ b/common/src/test/java/org/apache/comet/parquet/TestColumnReader.java @@ -28,8 +28,6 @@ import scala.collection.JavaConverters; -import org.junit.Test; - import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FixedSizeBinaryVector; @@ -90,7 +88,6 @@ public class TestColumnReader { (v, i) -> v.getDecimal(i, 18, 10), (v, i) -> v.getDecimal(i, 19, 5)); - @Test public void testConstantVectors() { for (int i = 0; i < TYPES.size(); i++) { DataType type = TYPES.get(i); @@ -138,7 +135,6 @@ public void testConstantVectors() { } } - @Test public void testRowIndexColumnVectors() { StructField field = StructField.apply("f", LongType, false, null); int bigBatchSize = BATCH_SIZE * 2; @@ -174,7 +170,6 @@ public void testRowIndexColumnVectors() { reader.close(); } - @Test public void testIsFixedLength() { BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); diff --git a/native/core/benches/parquet_read.rs b/native/core/benches/parquet_read.rs index 1f8178cd22..13f21612f4 100644 --- a/native/core/benches/parquet_read.rs +++ b/native/core/benches/parquet_read.rs @@ -213,6 +213,6 @@ impl Iterator for TestColumnReader { } self.total_num_values_read += total; - Some(self.inner.current_batch()) + Some(self.inner.current_batch().unwrap()) } } diff --git a/native/core/src/common/buffer.rs b/native/core/src/common/buffer.rs index f24038a955..97abc311d9 100644 --- a/native/core/src/common/buffer.rs +++ b/native/core/src/common/buffer.rs @@ -16,6 +16,7 @@ // under the License. use crate::common::bit; +use crate::execution::operators::ExecutionError; use arrow::buffer::Buffer as ArrowBuffer; use std::{ alloc::{handle_alloc_error, Layout}, @@ -43,6 +44,8 @@ pub struct CometBuffer { capacity: usize, /// Whether this buffer owns the data it points to. owned: bool, + /// The allocation instance for this buffer. + allocation: Arc, } unsafe impl Sync for CometBuffer {} @@ -63,6 +66,7 @@ impl CometBuffer { len: aligned_capacity, capacity: aligned_capacity, owned: true, + allocation: Arc::new(CometBufferAllocation::new()), } } } @@ -84,6 +88,7 @@ impl CometBuffer { len, capacity, owned: false, + allocation: Arc::new(CometBufferAllocation::new()), } } @@ -163,11 +168,28 @@ impl CometBuffer { /// because of the iterator-style pattern, the content of the original mutable buffer will only /// be updated once upstream operators fully consumed the previous output batch. For breaking /// operators, they are responsible for copying content out of the buffers. - pub unsafe fn to_arrow(&self) -> ArrowBuffer { + pub unsafe fn to_arrow(&self) -> Result { let ptr = NonNull::new_unchecked(self.data.as_ptr()); - // Uses a dummy `Arc::new(0)` as `Allocation` to ensure the memory region pointed by - // `ptr` won't be freed when the returned `ArrowBuffer` goes out of scope. - ArrowBuffer::from_custom_allocation(ptr, self.len, Arc::new(0)) + self.check_reference()?; + Ok(ArrowBuffer::from_custom_allocation( + ptr, + self.len, + self.allocation.clone(), + )) + } + + /// Checks if this buffer is exclusively owned by Comet. If not, an error is returned. + /// We run this check when we want to update the buffer. If the buffer is also shared by + /// other components, e.g. one DataFusion operator stores the buffer, Comet cannot safely + /// modify the buffer. + pub fn check_reference(&self) -> Result<(), ExecutionError> { + if Arc::strong_count(&self.allocation) > 1 { + Err(ExecutionError::GeneralError( + "Error on modifying a buffer which is not exclusively owned by Comet".to_string(), + )) + } else { + Ok(()) + } } /// Resets this buffer by filling all bytes with zeros. @@ -242,12 +264,14 @@ impl PartialEq for CometBuffer { } } +/* impl From<&ArrowBuffer> for CometBuffer { fn from(value: &ArrowBuffer) -> Self { assert_eq!(value.len(), value.capacity()); CometBuffer::from_ptr(value.as_ptr(), value.len(), value.capacity()) } } + */ impl std::ops::Deref for CometBuffer { type Target = [u8]; @@ -264,6 +288,15 @@ impl std::ops::DerefMut for CometBuffer { } } +#[derive(Debug)] +struct CometBufferAllocation {} + +impl CometBufferAllocation { + fn new() -> Self { + Self {} + } +} + #[cfg(test)] mod tests { use super::*; @@ -319,7 +352,7 @@ mod tests { assert_eq!(b"aaaa bbbb cccc dddd", &buf.as_slice()[0..str.len()]); unsafe { - let immutable_buf: ArrowBuffer = buf.to_arrow(); + let immutable_buf: ArrowBuffer = buf.to_arrow().unwrap(); assert_eq!(64, immutable_buf.len()); assert_eq!(str, &immutable_buf.as_slice()[0..str.len()]); } @@ -335,7 +368,7 @@ mod tests { assert_eq!(b"hello comet", &buf.as_slice()[0..11]); unsafe { - let arrow_buf2 = buf.to_arrow(); + let arrow_buf2 = buf.to_arrow().unwrap(); assert_eq!(arrow_buf, arrow_buf2); } } diff --git a/native/core/src/execution/operators/copy.rs b/native/core/src/execution/operators/copy.rs index 8eeda8a5ad..3a3a47717b 100644 --- a/native/core/src/execution/operators/copy.rs +++ b/native/core/src/execution/operators/copy.rs @@ -258,7 +258,10 @@ fn copy_array(array: &dyn Array) -> ArrayRef { /// is a dictionary array, we will cast the dictionary array to primitive type /// (i.e., unpack the dictionary array) and copy the primitive array. If the input /// array is a primitive array, we simply copy the array. -fn copy_or_unpack_array(array: &Arc, mode: &CopyMode) -> Result { +pub(crate) fn copy_or_unpack_array( + array: &Arc, + mode: &CopyMode, +) -> Result { match array.data_type() { DataType::Dictionary(_, value_type) => { let options = CastOptions::default(); diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index 455f19929f..ab792e6ff1 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -542,8 +542,9 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_currentBatch( ) { try_unwrap_or_throw(&e, |_env| { let ctx = get_context(handle)?; + let reader = &mut ctx.column_reader; - let data = reader.current_batch(); + let data = reader.current_batch()?; data.move_to_spark(array_addr, schema_addr) .map_err(|e| e.into()) }) @@ -572,6 +573,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_closeColumnReader( try_unwrap_or_throw(&env, |_| { unsafe { let ctx = handle as *mut Context; + let _ = Box::from_raw(ctx); }; Ok(()) diff --git a/native/core/src/parquet/mutable_vector.rs b/native/core/src/parquet/mutable_vector.rs index 7f30d7d877..bacb2c7a32 100644 --- a/native/core/src/parquet/mutable_vector.rs +++ b/native/core/src/parquet/mutable_vector.rs @@ -18,6 +18,7 @@ use arrow::{array::ArrayData, datatypes::DataType as ArrowDataType}; use crate::common::{bit, CometBuffer}; +use crate::execution::operators::ExecutionError; const DEFAULT_ARRAY_LEN: usize = 4; @@ -192,7 +193,7 @@ impl ParquetMutableVector { /// This method is highly unsafe since it calls `CometBuffer::to_arrow` which leaks raw /// pointer to the memory region that are tracked by `CometBuffer`. Please see comments on /// `to_arrow` buffer to understand the motivation. - pub fn get_array_data(&mut self) -> ArrayData { + pub fn get_array_data(&mut self) -> Result { unsafe { let data_type = if let Some(d) = &self.dictionary { ArrowDataType::Dictionary( @@ -204,20 +205,19 @@ impl ParquetMutableVector { }; let mut builder = ArrayData::builder(data_type) .len(self.num_values) - .add_buffer(self.value_buffer.to_arrow()) - .null_bit_buffer(Some(self.validity_buffer.to_arrow())) + .add_buffer(self.value_buffer.to_arrow()?) + .null_bit_buffer(Some(self.validity_buffer.to_arrow()?)) .null_count(self.num_nulls); if Self::is_binary_type(&self.arrow_type) && self.dictionary.is_none() { let child = &mut self.children[0]; - builder = builder.add_buffer(child.value_buffer.to_arrow()); + builder = builder.add_buffer(child.value_buffer.to_arrow()?); } if let Some(d) = &mut self.dictionary { - builder = builder.add_child_data(d.get_array_data()); + builder = builder.add_child_data(d.get_array_data()?); } - - builder.build_unchecked() + Ok(builder.build_unchecked()) } } diff --git a/native/core/src/parquet/read/column.rs b/native/core/src/parquet/read/column.rs index 73f8df9560..3dc19db622 100644 --- a/native/core/src/parquet/read/column.rs +++ b/native/core/src/parquet/read/column.rs @@ -39,6 +39,7 @@ use super::{ }; use crate::common::{bit, bit::log2}; +use crate::execution::operators::ExecutionError; /// Maximum number of decimal digits an i32 can represent const DECIMAL_MAX_INT_DIGITS: i32 = 9; @@ -601,7 +602,7 @@ impl ColumnReader { } #[inline] - pub fn current_batch(&mut self) -> ArrayData { + pub fn current_batch(&mut self) -> Result { make_func_mut!(self, current_batch) } @@ -684,7 +685,7 @@ impl TypedColumnReader { /// Note: the caller must make sure the returned Arrow vector is fully consumed before calling /// `read_batch` again. #[inline] - pub fn current_batch(&mut self) -> ArrayData { + pub fn current_batch(&mut self) -> Result { self.vector.get_array_data() } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 99007d0c91..3ace67301c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -67,8 +67,10 @@ class CometExecSuite extends CometTestBase { test("TopK operator should return correct results on dictionary column with nulls") { withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { withTable("test_data") { + val data = (0 to 8000) + .flatMap(_ => Seq((1, null, "A"), (2, "BBB", "B"), (3, "BBB", "B"), (4, "BBB", "B"))) val tableDF = spark.sparkContext - .parallelize(Seq((1, null, "A"), (2, "BBB", "B"), (3, "BBB", "B"), (4, "BBB", "B")), 3) + .parallelize(data, 3) .toDF("c1", "c2", "c3") tableDF .coalesce(1)