chore: Revise array import to more follow C Data Interface semantics (apache#905)

* chore: Revise array import to more follow C Data Interface semantics

* more

* fix

* For review

* Try

* check alignment

* Try

* Add comment
viirya authored Sep 6, 2024
1 parent a932cf7 commit 6c2c182
Showing 9 changed files with 162 additions and 71 deletions.
85 changes: 75 additions & 10 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -19,9 +19,13 @@

package org.apache.comet.vector

import java.nio.ByteOrder

import scala.collection.mutable

import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data}
import org.apache.arrow.c.NativeUtil.NULL
import org.apache.arrow.memory.util.MemoryUtil
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.spark.SparkException
@@ -56,6 +60,37 @@ class NativeUtil {
*/
private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider

/**
* Allocates Arrow structs for the given number of columns.
*
* @param numCols
* the number of columns
* @return
* a pair of Arrow arrays and Arrow schemas
*/
def allocateArrowStructs(numCols: Int): (Array[ArrowArray], Array[ArrowSchema]) = {
val arrays = new Array[ArrowArray](numCols)
val schemas = new Array[ArrowSchema](numCols)

(0 until numCols).foreach { index =>
val arrowSchema = ArrowSchema.allocateNew(allocator)

// Manually write NULL into the `release` slot of ArrowSchema because ArrowSchema doesn't
// provide a `markReleased` method. Per the C Data Interface, a NULL `release` callback marks
// the struct as released (empty), so the native side is free to overwrite it.
// The total size of ArrowSchema is 72 bytes and the `release` slot is at offset 56 in the struct.
val buffer =
MemoryUtil.directBuffer(arrowSchema.memoryAddress(), 72).order(ByteOrder.nativeOrder)
buffer.putLong(56, NULL);

val arrowArray = ArrowArray.allocateNew(allocator)
arrays(index) = arrowArray
schemas(index) = arrowSchema
}

(arrays, schemas)
}
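
The manual NULL write above leaves each freshly allocated ArrowSchema in the "released" (empty) state that the C Data Interface expects before the native side fills it in. A minimal sketch of that invariant from the JVM side (illustrative only, not part of the change; it reuses the `MemoryUtil` and `NULL` helpers imported above and assumes the 72-byte size and offset-56 layout noted in the comment):

// Illustrative check only: re-read the release slot and confirm the struct reads as released.
val (arrays, schemas) = allocateArrowStructs(1)
val releaseSlot = MemoryUtil
  .directBuffer(schemas(0).memoryAddress(), 72)
  .order(ByteOrder.nativeOrder)
  .getLong(56)
assert(releaseSlot == NULL, "expected a released (empty) ArrowSchema struct")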

/**
* Exports a Comet `ColumnarBatch` into a list of memory addresses that can be consumed by the
* native execution.
@@ -101,31 +136,61 @@ class NativeUtil {
batch.numRows()
}

/**
* Gets the next batch from native execution.
*
* @param numOutputCols
* The number of output columns
* @param func
* The function to call to get the next batch
* @return
* The next batch wrapped in a `ColumnarBatch`, or None if there are no more batches
*/
def getNextBatch(
numOutputCols: Int,
func: (Array[Long], Array[Long]) => Long): Option[ColumnarBatch] = {
val (arrays, schemas) = allocateArrowStructs(numOutputCols)

val arrayAddrs = arrays.map(_.memoryAddress())
val schemaAddrs = schemas.map(_.memoryAddress())

val result = func(arrayAddrs, schemaAddrs)

result match {
case -1 =>
// EOF
None
case numRows if numRows >= 0 =>
val cometVectors = importVector(arrays, schemas)
Some(new ColumnarBatch(cometVectors.toArray, numRows.toInt))
case flag =>
throw new IllegalStateException(s"Invalid native flag: $flag")
}
}
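
For context, a sketch of how a caller is expected to drive `getNextBatch` (this mirrors the `CometExecIterator` change later in this commit; the `nextBatch` wrapper, `plan`, and `nativeLib` are stand-ins for the caller's native plan handle and JNI binding, not names introduced here):

// Sketch: pre-allocated Arrow struct addresses go to the native side, which either moves its
// output into them and returns the row count, or returns -1 at end of output.
def nextBatch(nativeUtil: NativeUtil, numOutputCols: Int): Option[ColumnarBatch] =
  nativeUtil.getNextBatch(
    numOutputCols,
    (arrayAddrs, schemaAddrs) => nativeLib.executePlan(plan, arrayAddrs, schemaAddrs))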

/**
* Imports Arrow arrays and schemas filled by native execution, and returns a list of Comet vectors.
*
* @param arrayAddress
* a list containing pairs of Arrow addresses from the native side, in the format of (address
* of Arrow array, address of Arrow schema)
* @param arrays
* a list of Arrow arrays
* @param schemas
* a list of Arrow schemas
* @return
* a list of Comet vectors
*/
def importVector(arrayAddress: Array[Long]): Seq[CometVector] = {
def importVector(arrays: Array[ArrowArray], schemas: Array[ArrowSchema]): Seq[CometVector] = {
val arrayVectors = mutable.ArrayBuffer.empty[CometVector]

for (i <- arrayAddress.indices by 2) {
val arrowSchema = ArrowSchema.wrap(arrayAddress(i + 1))
val arrowArray = ArrowArray.wrap(arrayAddress(i))
(0 until arrays.length).foreach { i =>
val arrowSchema = schemas(i)
val arrowArray = arrays(i)

// Native execution should always have 'useDecimal128' set to true since it doesn't support
// other cases.
arrayVectors += CometVector.getVector(
importer.importVector(arrowArray, arrowSchema, dictionaryProvider),
true,
dictionaryProvider)

arrowArray.close()
arrowSchema.close()
}
arrayVectors.toSeq
}
71 changes: 37 additions & 34 deletions native/core/src/execution/jni_api.rs
@@ -17,10 +17,7 @@

//! Define JNI APIs which can be called from Java/Scala.
use arrow::{
datatypes::DataType as ArrowDataType,
ffi::{FFI_ArrowArray, FFI_ArrowSchema},
};
use arrow::datatypes::DataType as ArrowDataType;
use arrow_array::RecordBatch;
use datafusion::{
execution::{
@@ -78,8 +75,6 @@ struct ExecutionContext {
pub input_sources: Vec<Arc<GlobalRef>>,
/// The record batch stream to pull results from
pub stream: Option<SendableRecordBatchStream>,
/// The FFI arrays. We need to keep them alive here.
pub ffi_arrays: Vec<(Arc<FFI_ArrowArray>, Arc<FFI_ArrowSchema>)>,
/// Configurations for DF execution
pub conf: HashMap<String, String>,
/// The Tokio runtime used for async.
@@ -177,7 +172,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
scans: vec![],
input_sources,
stream: None,
ffi_arrays: vec![],
conf: configs,
runtime,
metrics,
@@ -265,14 +259,33 @@ fn parse_bool(conf: &HashMap<String, String>, name: &str) -> CometResult<bool> {
}

/// Prepares arrow arrays for output.
fn prepare_output(
unsafe fn prepare_output(
env: &mut JNIEnv,
array_addrs: jlongArray,
schema_addrs: jlongArray,
output_batch: RecordBatch,
exec_context: &mut ExecutionContext,
) -> CometResult<jlongArray> {
) -> CometResult<jlong> {
let array_address_array = JLongArray::from_raw(array_addrs);
let num_cols = env.get_array_length(&array_address_array)? as usize;

let array_addrs = env.get_array_elements(&array_address_array, ReleaseMode::NoCopyBack)?;
let array_addrs = &*array_addrs;

let schema_address_array = JLongArray::from_raw(schema_addrs);
let schema_addrs = env.get_array_elements(&schema_address_array, ReleaseMode::NoCopyBack)?;
let schema_addrs = &*schema_addrs;

let results = output_batch.columns();
let num_rows = output_batch.num_rows();

if results.len() != num_cols {
return Err(CometError::Internal(format!(
"Output column count mismatch: expected {num_cols}, got {}",
results.len()
)));
}

if exec_context.debug_native {
// Validate the output arrays.
for array in results.iter() {
@@ -283,35 +296,20 @@ fn prepare_output(
}
}

let return_flag = 1;

let long_array = env.new_long_array((results.len() * 2) as i32 + 2)?;
env.set_long_array_region(&long_array, 0, &[return_flag, num_rows as jlong])?;

let mut arrays = vec![];

let mut i = 0;
while i < results.len() {
let array_ref = results.get(i).ok_or(CometError::IndexOutOfBounds(i))?;
let (array, schema) = array_ref.to_data().to_spark()?;

unsafe {
let arrow_array = Arc::from_raw(array as *const FFI_ArrowArray);
let arrow_schema = Arc::from_raw(schema as *const FFI_ArrowSchema);
arrays.push((arrow_array, arrow_schema));
}
array_ref
.to_data()
.move_to_spark(array_addrs[i], schema_addrs[i])?;

env.set_long_array_region(&long_array, (i * 2) as i32 + 2, &[array, schema])?;
i += 1;
}

// Update metrics
update_metrics(env, exec_context)?;

// Record the pointer to allocated Arrow Arrays
exec_context.ffi_arrays = arrays;

Ok(long_array.into_raw())
Ok(num_rows as jlong)
}

/// Pull the next input from JVM. Note that we cannot pull input batches in
@@ -337,7 +335,9 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
e: JNIEnv,
_class: JClass,
exec_context: jlong,
) -> jlongArray {
array_addrs: jlongArray,
schema_addrs: jlongArray,
) -> jlong {
try_unwrap_or_throw(&e, |mut env| {
// Retrieve the query
let exec_context = get_execution_context(exec_context);
@@ -383,7 +383,13 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(

match poll_output {
Poll::Ready(Some(output)) => {
return prepare_output(&mut env, output?, exec_context);
return prepare_output(
&mut env,
array_addrs,
schema_addrs,
output?,
exec_context,
);
}
Poll::Ready(None) => {
// Reaches EOF of output.
@@ -399,10 +405,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
}
}

let long_array = env.new_long_array(1)?;
env.set_long_array_region(&long_array, 0, &[-1])?;

return Ok(long_array.into_raw());
return Ok(-1);
}
// A pending poll means there is more than one blocking operator, so we don't need to go
// back and forth between JVM and native. Just keep polling.
24 changes: 24 additions & 0 deletions native/core/src/execution/utils.rs
@@ -55,6 +55,9 @@ pub trait SparkArrowConvert {
/// Convert Arrow Arrays to C data interface.
/// It returns a tuple (ArrowArray address, ArrowSchema address).
fn to_spark(&self) -> Result<(i64, i64), ExecutionError>;

/// Move Arrow Arrays to C data interface.
fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError>;
}

impl SparkArrowConvert for ArrayData {
@@ -96,6 +99,27 @@ impl SparkArrowConvert for ArrayData {

Ok((array as i64, schema as i64))
}

/// Move this ArrayData to the given Arrow C data interface pointers.
fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> {
let array_ptr = array as *mut FFI_ArrowArray;
let schema_ptr = schema as *mut FFI_ArrowSchema;

let array_align = std::mem::align_of::<FFI_ArrowArray>();
let schema_align = std::mem::align_of::<FFI_ArrowSchema>();

// Check if the pointer alignment is correct for `replace`.
if array_ptr.align_offset(array_align) != 0 || schema_ptr.align_offset(schema_align) != 0 {
return Err(ExecutionError::ArrowError(
"Pointer alignment is not correct".to_string(),
));
}

unsafe { std::ptr::replace(array_ptr, FFI_ArrowArray::new(self)) };
unsafe { std::ptr::replace(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?) };

Ok(())
}
}

/// Converts a slice of bytes to i128. The bytes are serialized in big-endian order by
22 changes: 6 additions & 16 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -43,6 +43,7 @@ import org.apache.comet.vector.NativeUtil
class CometExecIterator(
val id: Long,
inputs: Seq[Iterator[ColumnarBatch]],
numOutputCols: Int,
protobufQueryPlan: Array[Byte],
nativeMetrics: CometMetricNode)
extends Iterator[ColumnarBatch] {
@@ -100,22 +101,11 @@ class CometExecIterator(
}

def getNextBatch(): Option[ColumnarBatch] = {
// we execute the native plan each time we need another output batch and this could
// result in multiple input batches being processed
val result = nativeLib.executePlan(plan)

result(0) match {
case -1 =>
// EOF
None
case 1 =>
val numRows = result(1)
val addresses = result.slice(2, result.length)
val cometVectors = nativeUtil.importVector(addresses)
Some(new ColumnarBatch(cometVectors.toArray, numRows.toInt))
case flag =>
throw new IllegalStateException(s"Invalid native flag: $flag")
}
nativeUtil.getNextBatch(
numOutputCols,
(arrayAddrs, schemaAddrs) => {
nativeLib.executePlan(plan, arrayAddrs, schemaAddrs)
})
}

override def hasNext: Boolean = {
10 changes: 6 additions & 4 deletions spark/src/main/scala/org/apache/comet/Native.scala
@@ -58,12 +58,14 @@ class Native extends NativeBase {
*
* @param plan
* the address of the native query plan.
* @param arrayAddrs
* the addresses of Arrow Array structures
* @param schemaAddrs
* the addresses of Arrow Schema structures
* @return
* an array containing: 1) the status flag (1 for normal returned arrays, -1 for end of
* output) 2) (optional) the number of rows if returned flag is 1 3) the addresses of output
* Arrow arrays
* the number of rows in the output batch, or -1 to signal the end of the output.
*/
@native def executePlan(plan: Long): Array[Long]
@native def executePlan(plan: Long, arrayAddrs: Array[Long], schemaAddrs: Array[Long]): Long

/**
* Release and drop the native query plan object and context object.
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ object CometExecUtils {
limit: Int): RDD[ColumnarBatch] = {
childPlan.mapPartitionsInternal { iter =>
val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
CometExec.getCometIterator(Seq(iter), limitOp)
CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp)
}
}

Original file line number Diff line number Diff line change
@@ -82,7 +82,7 @@ case class CometTakeOrderedAndProjectExec(
CometExecUtils
.getTopKNativePlan(child.output, sortOrder, child, limit)
.get
CometExec.getCometIterator(Seq(iter), topK)
CometExec.getCometIterator(Seq(iter), child.output.length, topK)
}
}

@@ -102,7 +102,7 @@ case class CometTakeOrderedAndProjectExec(
val topKAndProjection = CometExecUtils
.getProjectionNativePlan(projectList, child.output, sortOrder, child, limit)
.get
val it = CometExec.getCometIterator(Seq(iter), topKAndProjection)
val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection)
setSubqueries(it.id, this)

Option(TaskContext.get()).foreach { context =>
Original file line number Diff line number Diff line change
@@ -487,6 +487,7 @@ class CometShuffleWriteProcessor(

val cometIter = CometExec.getCometIterator(
Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
outputAttributes.length,
nativePlan,
nativeMetrics)
