diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index e45897a92..ea2bc9727 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -375,7 +375,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( exec_context.stream = Some(stream); } else { // Pull input batches - pull_input_batches(exec_context)?; + // pull_input_batches(exec_context)?; } loop { @@ -416,7 +416,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( update_metrics(&mut env, exec_context)?; // Pull input batches - pull_input_batches(exec_context)?; + // pull_input_batches(exec_context)?; // Output not ready yet continue; diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index 7d75f7f1c..18c76a35a 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -15,16 +15,6 @@ // specific language governing permissions and limitations // under the License. -use futures::Stream; -use itertools::Itertools; -use std::rc::Rc; -use std::{ - any::Any, - pin::Pin, - sync::{Arc, Mutex}, - task::{Context, Poll}, -}; - use crate::{ errors::CometError, execution::{ @@ -48,9 +38,22 @@ use datafusion::{ physical_plan::{ExecutionPlan, *}, }; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult}; +use futures::{FutureExt, Stream}; +use itertools::Itertools; use jni::objects::JValueGen; use jni::objects::{GlobalRef, JObject}; use jni::sys::jsize; +use std::future::Future; +use std::rc::Rc; +use std::{ + any::Any, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + thread, +}; +use tokio::runtime::{Handle, Runtime}; +use tokio::task::JoinHandle; /// ScanExec reads batches of data from Spark via JNI. The source of the scan could be a file /// scan or the result of reading a broadcast or shuffle exchange. @@ -163,10 +166,21 @@ impl ScanExec { Ok(()) } + fn get_next_handle( + exec_context_id: i64, + iter: &JObject<'static>, + num_cols: usize, + ) -> JoinHandle> { + let handle = Handle::current(); + let raw_object = iter.as_raw(); + let object = unsafe { JObject::from_raw(raw_object) }; + handle.spawn(async move { ScanExec::get_next(exec_context_id, &object, num_cols) }) + } + /// Invokes JNI call to get next batch. fn get_next( exec_context_id: i64, - iter: &JObject, + iter: &JObject<'static>, num_cols: usize, ) -> Result { if exec_context_id == TEST_EXEC_CONTEXT_ID { @@ -343,6 +357,7 @@ struct ScanStream<'a> { scan: ScanExec, /// Schema representing the data schema: SchemaRef, + handle: Arc>>>>, /// Metrics baseline_metrics: BaselineMetrics, /// Cast options @@ -358,6 +373,7 @@ impl<'a> ScanStream<'a> { Self { scan, schema, + handle: Arc::new(Mutex::new(None)), baseline_metrics, cast_options: CastOptions::default(), cast_time, @@ -399,7 +415,7 @@ impl<'a> ScanStream<'a> { impl<'a> Stream for ScanStream<'a> { type Item = DataFusionResult; - fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut timer = self.baseline_metrics.elapsed_compute().timer(); let mut scan_batch = self.scan.batch.try_lock().unwrap(); @@ -407,8 +423,42 @@ impl<'a> Stream for ScanStream<'a> { let input_batch = if let Some(batch) = input_batch { batch } else { - timer.stop(); - return Poll::Pending; + let mut current = self.handle.try_lock().unwrap(); + + if let Some(ref mut handle) = &mut *current { + if let Poll::Ready(batch) = handle.poll_unpin(cx) { + match batch { + Ok(batch) => { + *current = None; + &batch? + } + Err(e) => { + return Poll::Ready(Some(Err(DataFusionError::Execution( + e.to_string(), + )))); + } + } + } else { + return Poll::Pending; + } + /* + let result = handle.join().unwrap(); + match result { + Ok(batch) => batch, + Err(e) => { + return Poll::Ready(Some(Err(arrow_datafusion_err!(e)))); + } + } + */ + } else { + *current = Some(ScanExec::get_next_handle( + self.scan.exec_context_id, + self.scan.input_source.as_ref().unwrap().as_obj(), + self.scan.data_types.len(), + )); + + return Poll::Pending; + } }; let result = match input_batch {