diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 89f79c9cdf..18b4d20d62 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -47,10 +47,10 @@ class NativeUtil {
    *   a list containing number of rows + pairs of memory addresses in the format of (address of
    *   Arrow array, address of Arrow schema)
    */
-  def exportBatch(batch: ColumnarBatch): Array[Long] = {
-    val exportedVectors = mutable.ArrayBuffer.empty[Long]
-    exportedVectors += batch.numRows()
-
+  def exportBatch(
+      arrayAddrs: Array[Long],
+      schemaAddrs: Array[Long],
+      batch: ColumnarBatch): Int = {
     (0 until batch.numCols()).foreach { index =>
       batch.column(index) match {
         case a: CometVector =>
@@ -62,17 +62,16 @@ class NativeUtil {
             null
           }
 
-          val arrowSchema = ArrowSchema.allocateNew(allocator)
-          val arrowArray = ArrowArray.allocateNew(allocator)
+          // The array and schema structures are allocated by the native side.
+          // They don't need to be deallocated here.
+          val arrowSchema = ArrowSchema.wrap(schemaAddrs(index))
+          val arrowArray = ArrowArray.wrap(arrayAddrs(index))
           Data.exportVector(
             allocator,
             getFieldVector(valueVector, "export"),
             provider,
             arrowArray,
             arrowSchema)
-
-          exportedVectors += arrowArray.memoryAddress()
-          exportedVectors += arrowSchema.memoryAddress()
         case c =>
           throw new SparkException(
             "Comet execution only takes Arrow Arrays, but got " +
@@ -80,7 +79,7 @@ class NativeUtil {
       }
     }
 
-    exportedVectors.toArray
+    batch.numRows()
   }
 
   /**
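(Illustration, not part of the patch.) The JVM side now exports each column directly into caller-provided Arrow C Data Interface structs instead of allocating its own and returning their addresses. For readers unfamiliar with that handshake, here is a minimal sketch of the same export/import round trip written entirely in Rust with arrow-rs, assuming the arrow crate with its ffi feature enabled; to_ffi and from_ffi are arrow-rs helpers standing in for what Data.exportVector does on the JVM side and ArrayData::from_spark does on the native side:

use arrow::array::{make_array, Array, Int32Array};
use arrow::ffi::{from_ffi, to_ffi};

fn main() {
    let original = Int32Array::from(vec![1, 2, 3]);

    // Export into C Data Interface structs. In the patch, Data.exportVector
    // plays this role on the JVM side.
    let (ffi_array, ffi_schema) = to_ffi(&original.to_data()).unwrap();

    // Import from those structs. In the patch, ArrayData::from_spark plays
    // this role on the native side.
    let data = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();
    let roundtrip = make_array(data);

    assert_eq!(roundtrip.len(), original.len());
}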
diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs
index 59616efbb2..6a4ef22f10 100644
--- a/native/core/src/execution/operators/scan.rs
+++ b/native/core/src/execution/operators/scan.rs
@@ -28,8 +28,9 @@ use itertools::Itertools;
 use arrow::compute::{cast_with_options, CastOptions};
 use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
 use arrow_data::ArrayData;
+use arrow_data::ffi::FFI_ArrowArray;
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
-
+use arrow_schema::ffi::FFI_ArrowSchema;
 use crate::{
     errors::CometError,
     execution::{
@@ -46,9 +47,10 @@ use datafusion::{
 };
 use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult};
 use jni::{
-    objects::{GlobalRef, JLongArray, JObject, ReleaseMode},
-    sys::jlongArray,
+    objects::{GlobalRef, JObject},
 };
+use jni::objects::JValueGen;
+use jni::sys::jsize;
 
 /// ScanExec reads batches of data from Spark via JNI. The source of the scan could be a file
 /// scan or the result of reading a broadcast or shuffle exchange.
@@ -86,7 +88,7 @@ impl ScanExec {
         // may end up either unpacking dictionary arrays or dictionary-encoding arrays.
         // Dictionary-encoded primitive arrays are always unpacked.
         let first_batch = if let Some(input_source) = input_source.as_ref() {
-            ScanExec::get_next(exec_context_id, input_source.as_obj())?
+            ScanExec::get_next(exec_context_id, input_source.as_obj(), data_types.len())?
         } else {
             InputBatch::EOF
         };
@@ -153,6 +155,7 @@ impl ScanExec {
                 let next_batch = ScanExec::get_next(
                     self.exec_context_id,
                     self.input_source.as_ref().unwrap().as_obj(),
+                    self.data_types.len(),
                 )?;
                 *current_batch = Some(next_batch);
             }
@@ -161,7 +164,7 @@ impl ScanExec {
     }
 
     /// Invokes JNI call to get next batch.
-    fn get_next(exec_context_id: i64, iter: &JObject) -> Result<InputBatch, CometError> {
+    fn get_next(exec_context_id: i64, iter: &JObject, num_cols: usize) -> Result<InputBatch, CometError> {
         if exec_context_id == TEST_EXEC_CONTEXT_ID {
             // This is a unit test. We don't need to call JNI.
             return Ok(InputBatch::EOF);
         }
@@ -175,49 +178,57 @@ impl ScanExec {
 
         let mut env = JVMClasses::get_env()?;
 
-        let batch_object: JObject = unsafe {
-            jni_call!(&mut env,
-                comet_batch_iterator(iter).next() -> JObject)?
-        };
-
-        if batch_object.is_null() {
-            return Err(CometError::from(ExecutionError::GeneralError(format!(
-                "Null batch object. Plan id: {}",
-                exec_context_id
-            ))));
+        let mut array_addrs = Vec::with_capacity(num_cols);
+        let mut schema_addrs = Vec::with_capacity(num_cols);
+
+        for _ in 0..num_cols {
+            let arrow_array = Arc::new(FFI_ArrowArray::empty());
+            let arrow_schema = Arc::new(FFI_ArrowSchema::empty());
+            let (array_ptr, schema_ptr) = (
+                Arc::into_raw(arrow_array) as i64,
+                Arc::into_raw(arrow_schema) as i64,
+            );
+
+            array_addrs.push(array_ptr);
+            schema_addrs.push(schema_ptr);
         }
 
-        let batch_object = unsafe { JLongArray::from_raw(batch_object.as_raw() as jlongArray) };
+        // Prepare the Java array parameters.
+        let long_array_addrs = env.new_long_array(num_cols as jsize)?;
+        let long_schema_addrs = env.new_long_array(num_cols as jsize)?;
 
-        let addresses = unsafe { env.get_array_elements(&batch_object, ReleaseMode::NoCopyBack)? };
+        env.set_long_array_region(&long_array_addrs, 0, &array_addrs)?;
+        env.set_long_array_region(&long_schema_addrs, 0, &schema_addrs)?;
 
-        // First element is the number of rows.
-        let num_rows = unsafe { *addresses.as_ptr() as i64 };
+        let array_obj = JObject::from(long_array_addrs);
+        let schema_obj = JObject::from(long_schema_addrs);
 
-        if num_rows < 0 {
-            return Ok(InputBatch::EOF);
-        }
+        let array_obj = JValueGen::Object(array_obj.as_ref());
+        let schema_obj = JValueGen::Object(schema_obj.as_ref());
+
+        let num_rows: i32 = unsafe {
+            jni_call!(&mut env,
+                comet_batch_iterator(iter).next(array_obj, schema_obj) -> i32)?
+        };
 
-        let array_num = addresses.len() - 1;
-        if array_num % 2 != 0 {
-            return Err(CometError::Internal(format!(
-                "Invalid number of Arrow Array addresses: {}",
-                array_num
-            )));
+        if num_rows == -1 {
+            return Ok(InputBatch::EOF);
         }
 
-        let num_arrays = array_num / 2;
-        let array_elements = unsafe { addresses.as_ptr().add(1) };
-        let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_arrays);
+        let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_cols);
 
-        for i in 0..num_arrays {
-            let array_ptr = unsafe { *(array_elements.add(i * 2)) };
-            let schema_ptr = unsafe { *(array_elements.add(i * 2 + 1)) };
+        for i in 0..num_cols {
+            let array_ptr = array_addrs[i];
+            let schema_ptr = schema_addrs[i];
             let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?;
 
             // TODO: validate array input data
 
             inputs.push(make_array(array_data));
+
+            // Drop the Arcs to avoid memory leak
+            unsafe {
+                Arc::from_raw(array_ptr as *const FFI_ArrowArray);
+                Arc::from_raw(schema_ptr as *const FFI_ArrowSchema);
+            }
         }
 
         Ok(InputBatch::new(inputs, Some(num_rows as usize)))
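(Illustration, not part of the patch.) The subtle piece above is ownership of the pre-allocated FFI structs: get_next leaks each one to a raw address with Arc::into_raw, lets the JVM export into that memory, and reclaims it with Arc::from_raw after the import so each struct is dropped exactly once. A standalone sketch of that round trip, with a hypothetical fill_from_jvm placeholder where the real code calls CometBatchIterator.next over JNI:

use std::sync::Arc;

use arrow_data::ffi::FFI_ArrowArray;
use arrow_schema::ffi::FFI_ArrowSchema;

// Hypothetical stand-in for CometBatchIterator.next(), which receives these
// addresses over JNI and exports an Arrow vector into them.
fn fill_from_jvm(_array_addr: i64, _schema_addr: i64) {}

fn main() {
    // Leak empty FFI structs to raw addresses, as get_next does per column.
    let array_addr = Arc::into_raw(Arc::new(FFI_ArrowArray::empty())) as i64;
    let schema_addr = Arc::into_raw(Arc::new(FFI_ArrowSchema::empty())) as i64;

    fill_from_jvm(array_addr, schema_addr);

    // Reclaim ownership so both structs are dropped exactly once. Skipping
    // this step leaks them; doing it twice would double-free.
    unsafe {
        drop(Arc::from_raw(array_addr as *const FFI_ArrowArray));
        drop(Arc::from_raw(schema_addr as *const FFI_ArrowSchema));
    }
}

Nothing here is actually shared, so a Box::into_raw/Box::from_raw pair would express the same single-owner round trip; the patch uses Arc, which works identically for this purpose.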
diff --git a/native/core/src/jvm_bridge/batch_iterator.rs b/native/core/src/jvm_bridge/batch_iterator.rs
index 06f43a8ce4..b2f1190753 100644
--- a/native/core/src/jvm_bridge/batch_iterator.rs
+++ b/native/core/src/jvm_bridge/batch_iterator.rs
@@ -21,6 +21,7 @@ use jni::{
     signature::ReturnType,
     JNIEnv,
 };
+use jni::signature::Primitive;
 
 /// A struct that holds all the JNI methods and fields for JVM `CometBatchIterator` class.
 pub struct CometBatchIterator<'a> {
@@ -37,8 +38,8 @@ impl<'a> CometBatchIterator<'a> {
 
         Ok(CometBatchIterator {
             class,
-            method_next: env.get_method_id(Self::JVM_CLASS, "next", "()[J")?,
-            method_next_ret: ReturnType::Array,
+            method_next: env.get_method_id(Self::JVM_CLASS, "next", "([J[J)I")?,
+            method_next_ret: ReturnType::Primitive(Primitive::Int),
         })
     }
 }
diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java b/spark/src/main/java/org/apache/comet/CometBatchIterator.java
index 33603290ce..accd57c208 100644
--- a/spark/src/main/java/org/apache/comet/CometBatchIterator.java
+++ b/spark/src/main/java/org/apache/comet/CometBatchIterator.java
@@ -40,17 +40,19 @@ public class CometBatchIterator {
   }
 
   /**
-   * Get the next batches of Arrow arrays. It will consume input iterator and return Arrow arrays
-   * by addresses. If the input iterator is done, it will return a one negative element array
-   * indicating the end of the iterator.
+   * Get the next batch of Arrow arrays.
+   *
+   * @param arrayAddrs The addresses of the ArrowArray structures.
+   * @param schemaAddrs The addresses of the ArrowSchema structures.
+   * @return the number of rows in the current batch, or -1 if there are no more batches.
    */
-  public long[] next() {
+  public int next(long[] arrayAddrs, long[] schemaAddrs) {
     boolean hasBatch = input.hasNext();
 
     if (!hasBatch) {
-      return new long[] {-1};
+      return -1;
     }
 
-    return nativeUtil.exportBatch(input.next());
+    return nativeUtil.exportBatch(arrayAddrs, schemaAddrs, input.next());
   }
 }