chore: Revise batch pull approach to follow the C Data Interface spec more closely
viirya committed Aug 30, 2024
1 parent e57ead4 commit ca7f7d2
Showing 4 changed files with 64 additions and 51 deletions.
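
At a high level, this commit moves batch transfer to the C Data Interface's consumer-allocated pattern: the native side allocates empty FFI_ArrowArray/FFI_ArrowSchema structs, passes their addresses to the JVM, and the JVM exports each column into them and returns the row count (or -1 at end of input), instead of the JVM allocating the structs and returning their addresses in a long[]. Below is a minimal sketch of that pattern in Rust, assuming arrow-rs's ffi module; `pull_one_column` and `jvm_next` are illustrative names standing in for the JNI plumbing, not APIs from this commit.

```rust
use arrow::array::{make_array, ArrayRef};
use arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema};

// `jvm_next` is a hypothetical stand-in for the JNI call into the JVM producer.
fn pull_one_column(jvm_next: impl Fn(i64, i64) -> i32) -> Option<ArrayRef> {
    // 1. The consumer (native side) allocates empty structs and takes their
    //    addresses, analogous to what the revised get_next() does with Arc::into_raw.
    let array_addr = Box::into_raw(Box::new(FFI_ArrowArray::empty())) as i64;
    let schema_addr = Box::into_raw(Box::new(FFI_ArrowSchema::empty())) as i64;

    // 2. The producer (the JVM) exports a vector into them; -1 signals end of input.
    let num_rows = jvm_next(array_addr, schema_addr);

    // 3. The consumer reclaims the struct allocations so they are freed on this side.
    let array = unsafe { Box::from_raw(array_addr as *mut FFI_ArrowArray) };
    let schema = unsafe { Box::from_raw(schema_addr as *mut FFI_ArrowSchema) };
    if num_rows == -1 {
        // The still-empty structs carry no release callback; dropping them is a no-op.
        return None;
    }

    // Importing moves the FFI_ArrowArray into the resulting ArrayData, which
    // invokes the producer's release callback when the data is eventually dropped.
    let data = unsafe { from_ffi(*array, &schema) }.expect("FFI import failed");
    Some(make_array(data))
}
```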
19 changes: 9 additions & 10 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -47,10 +47,10 @@ class NativeUtil {
* a list containing number of rows + pairs of memory addresses in the format of (address of
* Arrow array, address of Arrow schema)
*/
def exportBatch(batch: ColumnarBatch): Array[Long] = {
val exportedVectors = mutable.ArrayBuffer.empty[Long]
exportedVectors += batch.numRows()

def exportBatch(
arrayAddrs: Array[Long],
schemaAddrs: Array[Long],
batch: ColumnarBatch): Int = {
(0 until batch.numCols()).foreach { index =>
batch.column(index) match {
case a: CometVector =>
@@ -62,25 +62,24 @@ class NativeUtil {
null
}

val arrowSchema = ArrowSchema.allocateNew(allocator)
val arrowArray = ArrowArray.allocateNew(allocator)
// The array and schema structures are allocated by native side.
// Don't need to deallocate them here.
val arrowSchema = ArrowSchema.wrap(schemaAddrs(index))
val arrowArray = ArrowArray.wrap(arrayAddrs(index))
Data.exportVector(
allocator,
getFieldVector(valueVector, "export"),
provider,
arrowArray,
arrowSchema)

exportedVectors += arrowArray.memoryAddress()
exportedVectors += arrowSchema.memoryAddress()
case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
}
}

exportedVectors.toArray
batch.numRows()
}

/**
77 changes: 44 additions & 33 deletions native/core/src/execution/operators/scan.rs
@@ -28,8 +28,9 @@ use itertools::Itertools;
use arrow::compute::{cast_with_options, CastOptions};
use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
use arrow_data::ArrayData;
use arrow_data::ffi::FFI_ArrowArray;
use arrow_schema::{DataType, Field, Schema, SchemaRef};

use arrow_schema::ffi::FFI_ArrowSchema;
use crate::{
errors::CometError,
execution::{
@@ -46,9 +47,10 @@ use datafusion::{
};
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult};
use jni::{
objects::{GlobalRef, JLongArray, JObject, ReleaseMode},
sys::jlongArray,
objects::{GlobalRef, JObject},
};
use jni::objects::JValueGen;
use jni::sys::jsize;

/// ScanExec reads batches of data from Spark via JNI. The source of the scan could be a file
/// scan or the result of reading a broadcast or shuffle exchange.
@@ -86,7 +88,7 @@ impl ScanExec {
// may end up either unpacking dictionary arrays or dictionary-encoding arrays.
// Dictionary-encoded primitive arrays are always unpacked.
let first_batch = if let Some(input_source) = input_source.as_ref() {
ScanExec::get_next(exec_context_id, input_source.as_obj())?
ScanExec::get_next(exec_context_id, input_source.as_obj(), data_types.len())?
} else {
InputBatch::EOF
};
Expand Down Expand Up @@ -153,6 +155,7 @@ impl ScanExec {
let next_batch = ScanExec::get_next(
self.exec_context_id,
self.input_source.as_ref().unwrap().as_obj(),
self.data_types.len(),
)?;
*current_batch = Some(next_batch);
}
@@ -161,7 +164,7 @@
}

/// Invokes JNI call to get next batch.
fn get_next(exec_context_id: i64, iter: &JObject) -> Result<InputBatch, CometError> {
fn get_next(exec_context_id: i64, iter: &JObject, num_cols: usize) -> Result<InputBatch, CometError> {
if exec_context_id == TEST_EXEC_CONTEXT_ID {
// This is a unit test. We don't need to call JNI.
return Ok(InputBatch::EOF);
Expand All @@ -175,49 +178,57 @@ impl ScanExec {
}

let mut env = JVMClasses::get_env()?;
let batch_object: JObject = unsafe {
jni_call!(&mut env,
comet_batch_iterator(iter).next() -> JObject)?
};

if batch_object.is_null() {
return Err(CometError::from(ExecutionError::GeneralError(format!(
"Null batch object. Plan id: {}",
exec_context_id
))));
let mut array_addrs = Vec::with_capacity(num_cols);
let mut schema_addrs = Vec::with_capacity(num_cols);

for _ in 0..num_cols {
let arrow_array = Arc::new(FFI_ArrowArray::empty());
let arrow_schema = Arc::new(FFI_ArrowSchema::empty());
let (array_ptr, schema_ptr) = (Arc::into_raw(arrow_array) as i64, Arc::into_raw(arrow_schema) as i64);

array_addrs.push(array_ptr);
schema_addrs.push(schema_ptr);
}

let batch_object = unsafe { JLongArray::from_raw(batch_object.as_raw() as jlongArray) };
// Prepare the java array parameters
let long_array_addrs = env.new_long_array(num_cols as jsize)?;
let long_schema_addrs = env.new_long_array(num_cols as jsize)?;

let addresses = unsafe { env.get_array_elements(&batch_object, ReleaseMode::NoCopyBack)? };
env.set_long_array_region(&long_array_addrs, 0, &array_addrs)?;
env.set_long_array_region(&long_schema_addrs, 0, &schema_addrs)?;

// First element is the number of rows.
let num_rows = unsafe { *addresses.as_ptr() as i64 };
let array_obj = JObject::from(long_array_addrs);
let schema_obj = JObject::from(long_schema_addrs);

if num_rows < 0 {
return Ok(InputBatch::EOF);
}
let array_obj = JValueGen::Object(array_obj.as_ref());
let schema_obj = JValueGen::Object(schema_obj.as_ref());

let num_rows: i32 = unsafe {
jni_call!(&mut env,
comet_batch_iterator(iter).next(array_obj, schema_obj) -> i32)?
};

let array_num = addresses.len() - 1;
if array_num % 2 != 0 {
return Err(CometError::Internal(format!(
"Invalid number of Arrow Array addresses: {}",
array_num
)));
if num_rows == -1 {
return Ok(InputBatch::EOF);
}

let num_arrays = array_num / 2;
let array_elements = unsafe { addresses.as_ptr().add(1) };
let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_arrays);
let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_cols);

for i in 0..num_arrays {
let array_ptr = unsafe { *(array_elements.add(i * 2)) };
let schema_ptr = unsafe { *(array_elements.add(i * 2 + 1)) };
for i in 0..num_cols {
let array_ptr = array_addrs[i];
let schema_ptr = schema_addrs[i];
let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?;

// TODO: validate array input data

inputs.push(make_array(array_data));

// Drop the Arcs to avoid memory leak
unsafe {
Arc::from_raw(array_ptr as *const FFI_ArrowArray);
Arc::from_raw(schema_ptr as *const FFI_ArrowSchema);
}
}

Ok(InputBatch::new(inputs, Some(num_rows as usize)))
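
The import path above leans on arrow-rs FFI ownership semantics: from_ffi consumes the FFI_ArrowArray by value and ties the producer's release callback to the lifetime of the imported data. The following self-contained round trip exercises the same semantics, with to_ffi standing in for the JVM producer; this is a sketch of the standard arrow-rs path, not the Comet-specific ArrayData::from_spark helper.

```rust
use arrow::array::{make_array, Array, Int32Array};
use arrow::ffi::{from_ffi, to_ffi};

fn main() {
    let original = Int32Array::from(vec![1, 2, 3]);

    // Producer side: fill the C Data Interface structs. On the JVM this is
    // Data.exportVector; here arrow-rs's to_ffi stands in.
    let (ffi_array, ffi_schema) = to_ffi(&original.to_data()).unwrap();

    // Consumer side: import the structs and rebuild an ArrayRef, as the
    // loop in get_next() does for each column.
    let imported = make_array(unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap());
    assert_eq!(imported.len(), original.len());
}
```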
5 changes: 3 additions & 2 deletions native/core/src/jvm_bridge/batch_iterator.rs
@@ -21,6 +21,7 @@ use jni::{
signature::ReturnType,
JNIEnv,
};
use jni::signature::Primitive;

/// A struct that holds all the JNI methods and fields for JVM `CometBatchIterator` class.
pub struct CometBatchIterator<'a> {
@@ -37,8 +38,8 @@ impl<'a> CometBatchIterator<'a> {

Ok(CometBatchIterator {
class,
method_next: env.get_method_id(Self::JVM_CLASS, "next", "()[J")?,
method_next_ret: ReturnType::Array,
method_next: env.get_method_id(Self::JVM_CLASS, "next", "([J[J)I")?,
method_next_ret: ReturnType::Primitive(Primitive::Int),
})
}
}
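
The jni_call! macro used in scan.rs hides the raw method invocation behind this method ID. For reference, here is a sketch of the equivalent direct call with the jni crate, matching the ([J[J)I signature registered above; it assumes jni 0.21, and call_next is an illustrative helper, not Comet code.

```rust
use jni::errors::Result as JniResult;
use jni::objects::{JObject, JValue};
use jni::sys::jsize;
use jni::JNIEnv;

fn call_next(
    env: &mut JNIEnv,
    iter: &JObject,
    array_addrs: &[i64],
    schema_addrs: &[i64],
) -> JniResult<i32> {
    // Marshal the native address vectors into Java long[] arguments.
    let jarrays = env.new_long_array(array_addrs.len() as jsize)?;
    let jschemas = env.new_long_array(schema_addrs.len() as jsize)?;
    env.set_long_array_region(&jarrays, 0, array_addrs)?;
    env.set_long_array_region(&jschemas, 0, schema_addrs)?;

    let array_obj = JObject::from(jarrays);
    let schema_obj = JObject::from(jschemas);

    // Invoke `int next(long[], long[])` and unwrap the primitive result.
    env.call_method(
        iter,
        "next",
        "([J[J)I",
        &[JValue::Object(&array_obj), JValue::Object(&schema_obj)],
    )?
    .i()
}
```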
14 changes: 8 additions & 6 deletions spark/src/main/java/org/apache/comet/CometBatchIterator.java
@@ -40,17 +40,19 @@ public class CometBatchIterator {
}

/**
* Get the next batches of Arrow arrays. It will consume input iterator and return Arrow arrays by
* addresses. If the input iterator is done, it will return a one negative element array
* indicating the end of the iterator.
* Get the next batch of Arrow arrays.
*
* @param arrayAddrs The addresses of the ArrowArray structures.
* @param schemaAddrs The addresses of the ArrowSchema structures.
* @return the number of rows in the current batch, or -1 if there are no more batches.
*/
public long[] next() {
public int next(long[] arrayAddrs, long[] schemaAddrs) {
boolean hasBatch = input.hasNext();

if (!hasBatch) {
return new long[] {-1};
return -1;
}

return nativeUtil.exportBatch(input.next());
return nativeUtil.exportBatch(arrayAddrs, schemaAddrs, input.next());
}
}
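
To make the new contract concrete without a JVM, here is a pure-Rust mock of the CometBatchIterator.next contract: it writes exported structs into the caller-provided addresses and returns the row count, or -1 once the input is exhausted. MockBatchIterator is illustrative only, covering a single column, with arrow-rs's to_ffi standing in for Data.exportVector.

```rust
use arrow::array::{Array, Int32Array};
use arrow::ffi::{to_ffi, FFI_ArrowArray, FFI_ArrowSchema};

// Single-column batches, consumed from the back; order is irrelevant for the mock.
struct MockBatchIterator {
    batches: Vec<Int32Array>,
}

impl MockBatchIterator {
    // Mirrors `public int next(long[] arrayAddrs, long[] schemaAddrs)`.
    fn next(&mut self, array_addrs: &[i64], schema_addrs: &[i64]) -> i32 {
        let Some(batch) = self.batches.pop() else {
            return -1; // mirrors the Java side's end-of-iterator signal
        };
        let num_rows = batch.len() as i32;
        let (ffi_array, ffi_schema) = to_ffi(&batch.to_data()).unwrap();
        unsafe {
            // Export into the consumer-allocated structs, as Data.exportVector
            // does on the JVM side. Overwriting the empty structs is safe: they
            // carry no release callback. Buffer ownership now travels with the
            // written structs' release callbacks.
            std::ptr::write(array_addrs[0] as *mut FFI_ArrowArray, ffi_array);
            std::ptr::write(schema_addrs[0] as *mut FFI_ArrowSchema, ffi_schema);
        }
        num_rows
    }
}
```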
