feat: Use global tokio runtime per executor #1104

Status: Closed · wants to merge 5 commits
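In outline: instead of building a dedicated multi-threaded Tokio runtime for every native plan (sized by `spark.comet.workerThreads` and `spark.comet.blockingThreads`), the executor now lazily creates a single process-wide runtime that all plans share, and the two configs are removed. A minimal, self-contained sketch of the pattern, mirroring the names in the diff below (the `main` driver is illustrative only):

```rust
use once_cell::sync::Lazy;
use tokio::runtime::Runtime;

// One runtime per executor process, created on first use and shared by
// every query plan, instead of one thread pool per plan.
static TOKIO_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .expect("Failed to create Tokio runtime")
});

fn main() {
    // Callers block on async work through the shared runtime.
    let n = TOKIO_RUNTIME.block_on(async { 1 + 1 });
    assert_eq!(n, 2);
}
```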
16 changes: 0 additions & 16 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -423,22 +423,6 @@ object CometConf extends ShimCometConf {
       .booleanConf
       .createWithDefault(false)
 
-  val COMET_WORKER_THREADS: ConfigEntry[Int] =
-    conf("spark.comet.workerThreads")
-      .internal()
-      .doc("The number of worker threads used for Comet native execution. " +
-        "By default, this config is 4.")
-      .intConf
-      .createWithDefault(4)
-
-  val COMET_BLOCKING_THREADS: ConfigEntry[Int] =
-    conf("spark.comet.blockingThreads")
-      .internal()
-      .doc("The number of blocking threads used for Comet native execution. " +
-        "By default, this config is 10.")
-      .intConf
-      .createWithDefault(10)
-
   val COMET_BATCH_SIZE: ConfigEntry[Int] = conf("spark.comet.batchSize")
     .doc("The columnar batch size, i.e., the maximum number of rows that a batch can contain.")
     .intConf
28 changes: 15 additions & 13 deletions native/core/src/execution/jni_api.rs
@@ -63,6 +63,20 @@ use crate::execution::datafusion::spark_plan::SparkPlan;
 use crate::execution::operators::ScanExec;
 use log::info;
 
+use once_cell::sync::Lazy;
+
+static TOKIO_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
+    tokio::runtime::Builder::new_multi_thread()
+        .enable_all()
+        .build()
+        .expect("Failed to create Tokio runtime")
+});
+
+/// Function to get a handle to the global Tokio runtime
+pub fn get_runtime() -> &'static Runtime {
+    &TOKIO_RUNTIME
+}

> Contributor commented on `static TOKIO_RUNTIME`: Perhaps this should be once per session (a user could modify the config in the session or run different queries in a single job that might require independent configuration)?
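A hypothetical sketch of that per-session alternative; the map, the `i64` session key, and `get_session_runtime` are illustrative assumptions, not part of this PR:

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use once_cell::sync::Lazy;
use tokio::runtime::Runtime;

// Hypothetical registry of one runtime per session.
static SESSION_RUNTIMES: Lazy<Mutex<HashMap<i64, Arc<Runtime>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));

// Returns the runtime for `session_id`, creating it on first use, so that
// sessions with different configs could get independently tuned runtimes.
fn get_session_runtime(session_id: i64) -> Arc<Runtime> {
    let mut map = SESSION_RUNTIMES.lock().unwrap();
    map.entry(session_id)
        .or_insert_with(|| {
            Arc::new(
                tokio::runtime::Builder::new_multi_thread()
                    .enable_all()
                    .build()
                    .expect("Failed to create Tokio runtime"),
            )
        })
        .clone()
}
```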

 /// Comet native execution context. Kept alive across JNI calls.
 struct ExecutionContext {
     /// The id of the execution context.
@@ -77,8 +91,6 @@ struct ExecutionContext {
     pub input_sources: Vec<Arc<GlobalRef>>,
     /// The record batch stream to pull results from
     pub stream: Option<SendableRecordBatchStream>,
-    /// The Tokio runtime used for async.
-    pub runtime: Runtime,
     /// Native metrics
     pub metrics: Arc<GlobalRef>,
     /// The time it took to create the native plan and configure the context
@@ -108,8 +120,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     batch_size: jint,
     debug_native: jboolean,
     explain_native: jboolean,
-    worker_threads: jint,
-    blocking_threads: jint,
 ) -> jlong {
     try_unwrap_or_throw(&e, |mut env| {
         // Init JVM classes
@@ -123,13 +133,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         // Deserialize query plan
         let spark_plan = serde::deserialize_op(bytes.as_slice())?;
 
-        // Use multi-threaded tokio runtime to prevent blocking spawned tasks if any
-        let runtime = tokio::runtime::Builder::new_multi_thread()
-            .worker_threads(worker_threads as usize)
-            .max_blocking_threads(blocking_threads as usize)
-            .enable_all()
-            .build()?;
-
         let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?);
 
         // Get the global references of input sources
@@ -158,7 +161,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
             scans: vec![],
             input_sources,
             stream: None,
-            runtime,
             metrics,
             plan_creation_time,
             session_ctx: Arc::new(session),
@@ -344,7 +346,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
         loop {
             // Polling the stream.
             let next_item = exec_context.stream.as_mut().unwrap().next();
-            let poll_output = exec_context.runtime.block_on(async { poll!(next_item) });
+            let poll_output = get_runtime().block_on(async { poll!(next_item) });
 
             match poll_output {
                 Poll::Ready(Some(output)) => {
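For context, `block_on` plus `futures::poll!` performs exactly one poll of the stream per loop iteration, so `Pending` can hand control back to the caller (the JVM side, in Comet) instead of blocking until the next batch. A standalone sketch of the pattern, assuming the `futures` and `tokio` crates; the integer stream is a stand-in for the record batch stream:

```rust
use std::task::Poll;

use futures::{poll, stream, StreamExt};
use tokio::runtime::Runtime;

fn main() {
    let runtime = Runtime::new().expect("Failed to create Tokio runtime");
    let mut stream = stream::iter(vec![1, 2, 3]);
    loop {
        let next_item = stream.next();
        // poll! drives the future exactly once and returns the raw Poll.
        let poll_output = runtime.block_on(async { poll!(next_item) });
        match poll_output {
            Poll::Ready(Some(item)) => println!("got {item}"),
            Poll::Ready(None) => break,   // stream exhausted
            Poll::Pending => continue,    // nothing ready yet; poll again
        }
    }
}
```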
6 changes: 2 additions & 4 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -23,7 +23,7 @@ import org.apache.spark._
 import org.apache.spark.sql.comet.CometMetricNode
 import org.apache.spark.sql.vectorized._
 
-import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_BLOCKING_THREADS, COMET_DEBUG_ENABLED, COMET_EXPLAIN_NATIVE_ENABLED, COMET_WORKER_THREADS}
+import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXPLAIN_NATIVE_ENABLED}
 import org.apache.comet.vector.NativeUtil
 
 /**
@@ -68,9 +68,7 @@ class CometExecIterator(
       new CometTaskMemoryManager(id),
       batchSize = COMET_BATCH_SIZE.get(),
       debug = COMET_DEBUG_ENABLED.get(),
-      explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
-      workerThreads = COMET_WORKER_THREADS.get(),
-      blockingThreads = COMET_BLOCKING_THREADS.get())
+      explain = COMET_EXPLAIN_NATIVE_ENABLED.get())
   }
 
   private var nextBatch: Option[ColumnarBatch] = None
4 changes: 1 addition & 3 deletions spark/src/main/scala/org/apache/comet/Native.scala
@@ -51,9 +51,7 @@ class Native extends NativeBase {
       taskMemoryManager: CometTaskMemoryManager,
       batchSize: Int,
       debug: Boolean,
-      explain: Boolean,
-      workerThreads: Int,
-      blockingThreads: Int): Long
+      explain: Boolean): Long
 
   /**
    * Execute a native query plan based on given input Arrow arrays.