Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Sep 17, 2024
1 parent 2108180 commit d9abd3a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 16 deletions.
4 changes: 2 additions & 2 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
exec_context.stream = Some(stream);
} else {
// Pull input batches
pull_input_batches(exec_context)?;
// pull_input_batches(exec_context)?;
}

loop {
Expand Down Expand Up @@ -416,7 +416,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
update_metrics(&mut env, exec_context)?;

// Pull input batches
pull_input_batches(exec_context)?;
// pull_input_batches(exec_context)?;

// Output not ready yet
continue;
Expand Down
78 changes: 64 additions & 14 deletions native/core/src/execution/operators/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use futures::Stream;
use itertools::Itertools;
use std::rc::Rc;
use std::{
any::Any,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};

use crate::{
errors::CometError,
execution::{
Expand All @@ -48,9 +38,22 @@ use datafusion::{
physical_plan::{ExecutionPlan, *},
};
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult};
use futures::{FutureExt, Stream};
use itertools::Itertools;
use jni::objects::JValueGen;
use jni::objects::{GlobalRef, JObject};
use jni::sys::jsize;
use std::future::Future;
use std::rc::Rc;
use std::{
any::Any,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
thread,
};
use tokio::runtime::{Handle, Runtime};
use tokio::task::JoinHandle;

/// ScanExec reads batches of data from Spark via JNI. The source of the scan could be a file
/// scan or the result of reading a broadcast or shuffle exchange.
Expand Down Expand Up @@ -163,10 +166,21 @@ impl ScanExec {
Ok(())
}

/// Schedules a [`ScanExec::get_next`] call as a task on the current Tokio runtime and
/// returns that task's join handle instead of the batch itself.
///
/// `exec_context_id` and `num_cols` are forwarded to `get_next` unchanged. `iter` is
/// re-wrapped from its raw JNI pointer so that an owned `JObject` can be moved into the
/// spawned task.
///
/// # Panics
///
/// `Handle::current()` panics when this is called outside of a Tokio runtime context.
fn get_next_handle(
    exec_context_id: i64,
    iter: &JObject<'static>,
    num_cols: usize,
) -> JoinHandle<Result<InputBatch, CometError>> {
    let runtime = Handle::current();

    // SAFETY: NOTE(review) — this creates a second `JObject` aliasing the same raw
    // reference as `iter`. It is only sound if that reference stays valid for as long
    // as the spawned task may run, and may be used from the thread that runs the task
    // (presumably it is a JNI global reference owned by the scan; confirm at the caller).
    let task_iter = unsafe { JObject::from_raw(iter.as_raw()) };

    // The async block has no await point, so the whole `get_next` call runs
    // synchronously within a single poll of the spawned task.
    let fetch = async move { ScanExec::get_next(exec_context_id, &task_iter, num_cols) };
    runtime.spawn(fetch)
}

/// Invokes JNI call to get next batch.
fn get_next(
exec_context_id: i64,
iter: &JObject,
iter: &JObject<'static>,
num_cols: usize,
) -> Result<InputBatch, CometError> {
if exec_context_id == TEST_EXEC_CONTEXT_ID {
Expand Down Expand Up @@ -343,6 +357,7 @@ struct ScanStream<'a> {
scan: ScanExec,
/// Schema representing the data
schema: SchemaRef,
handle: Arc<Mutex<Option<JoinHandle<Result<InputBatch, CometError>>>>>,
/// Metrics
baseline_metrics: BaselineMetrics,
/// Cast options
Expand All @@ -358,6 +373,7 @@ impl<'a> ScanStream<'a> {
Self {
scan,
schema,
handle: Arc::new(Mutex::new(None)),
baseline_metrics,
cast_options: CastOptions::default(),
cast_time,
Expand Down Expand Up @@ -399,16 +415,50 @@ impl<'a> ScanStream<'a> {
impl<'a> Stream for ScanStream<'a> {
type Item = DataFusionResult<RecordBatch>;

fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut timer = self.baseline_metrics.elapsed_compute().timer();
let mut scan_batch = self.scan.batch.try_lock().unwrap();

let input_batch = &*scan_batch;
let input_batch = if let Some(batch) = input_batch {
batch
} else {
timer.stop();
return Poll::Pending;
let mut current = self.handle.try_lock().unwrap();

if let Some(ref mut handle) = &mut *current {
if let Poll::Ready(batch) = handle.poll_unpin(cx) {
match batch {
Ok(batch) => {
*current = None;
&batch?
}
Err(e) => {
return Poll::Ready(Some(Err(DataFusionError::Execution(
e.to_string(),
))));
}
}
} else {
return Poll::Pending;
}
/*
let result = handle.join().unwrap();
match result {
Ok(batch) => batch,
Err(e) => {
return Poll::Ready(Some(Err(arrow_datafusion_err!(e))));
}
}
*/
} else {
*current = Some(ScanExec::get_next_handle(
self.scan.exec_context_id,
self.scan.input_source.as_ref().unwrap().as_obj(),
self.scan.data_types.len(),
));

return Poll::Pending;
}
};

let result = match input_batch {
Expand Down

0 comments on commit d9abd3a

Please sign in to comment.