diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 039538814..4ed88f003 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -347,6 +347,13 @@ object CometConf extends ShimCometConf {
       .booleanConf
       .createWithDefault(false)
 
+  val COMET_NATIVE_OPTIMIZER_ENABLED: ConfigEntry[Boolean] =
+    conf("spark.comet.exec.optimizer.enabled")
+      .internal()
+      .doc("Enable the DataFusion physical optimizer for native plans.")
+      .booleanConf
+      .createWithDefault(true)
+
   val COMET_WORKER_THREADS: ConfigEntry[Int] =
     conf("spark.comet.workerThreads")
      .internal()
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index b9b882824..6f9717e19 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -58,6 +58,7 @@ use datafusion::functions_aggregate::sum::sum_udaf;
 use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder;
 use datafusion::physical_plan::windows::BoundedWindowAggExec;
 use datafusion::physical_plan::InputOrderMode;
+use datafusion::physical_planner::DefaultPhysicalPlanner;
 use datafusion::{
     arrow::{compute::SortOptions, datatypes::SchemaRef},
     common::DataFusionError,
@@ -138,18 +139,6 @@ pub struct PhysicalPlanner {
     session_ctx: Arc<SessionContext>,
 }
 
-impl Default for PhysicalPlanner {
-    fn default() -> Self {
-        let session_ctx = Arc::new(SessionContext::new());
-        let execution_props = ExecutionProps::new();
-        Self {
-            exec_context_id: TEST_EXEC_CONTEXT_ID,
-            execution_props,
-            session_ctx,
-        }
-    }
-}
-
 impl PhysicalPlanner {
     pub fn new(session_ctx: Arc<SessionContext>) -> Self {
         let execution_props = ExecutionProps::new();
@@ -1115,6 +1104,17 @@ impl PhysicalPlanner {
         }
     }
 
+    pub fn optimize_plan(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+    ) -> Result<Arc<dyn ExecutionPlan>, ExecutionError> {
+        // optimize the physical plan
+        let datafusion_planner = DefaultPhysicalPlanner::default();
+        datafusion_planner
+            .optimize_physical_plan(plan, &self.session_ctx.state(), |_, _| {})
+            .map_err(|e| e.into())
+    }
+
     fn parse_join_parameters(
         &self,
         inputs: &mut Vec<Arc<GlobalRef>>,
@@ -1967,10 +1967,12 @@ mod tests {
     use arrow_array::{DictionaryArray, Int32Array, StringArray};
     use arrow_schema::DataType;
     use datafusion::{physical_plan::common::collect, prelude::SessionContext};
+    use datafusion_expr::execution_props::ExecutionProps;
     use tokio::sync::mpsc;
 
     use crate::execution::{datafusion::planner::PhysicalPlanner, operators::InputBatch};
 
+    use crate::execution::datafusion::planner::TEST_EXEC_CONTEXT_ID;
     use crate::execution::operators::ExecutionError;
     use datafusion_comet_proto::{
         spark_expression::expr::ExprStruct::*,
@@ -1979,6 +1981,18 @@ mod tests {
         spark_operator::{operator::OpStruct, Operator},
     };
 
+    impl Default for PhysicalPlanner {
+        fn default() -> Self {
+            let session_ctx = Arc::new(SessionContext::default());
+            let execution_props = ExecutionProps::new();
+            Self {
+                exec_context_id: TEST_EXEC_CONTEXT_ID,
+                execution_props,
+                session_ctx,
+            }
+        }
+    }
+
     #[test]
     fn test_unpack_dictionary_primitive() {
         let op_scan = Operator {
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 3ad822cc4..a79c3127b 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -17,11 +17,14 @@
 
 //! Define JNI APIs which can be called from Java/Scala.
 
+use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
 use arrow::{
     datatypes::DataType as ArrowDataType,
     ffi::{FFI_ArrowArray, FFI_ArrowSchema},
 };
 use arrow_array::RecordBatch;
+use datafusion::execution::session_state::SessionStateBuilder;
+use datafusion::physical_optimizer::projection_pushdown::ProjectionPushdown;
 use datafusion::{
     execution::{
         disk_manager::DiskManagerConfig,
@@ -42,8 +45,6 @@ use jni::{
 };
 use std::{collections::HashMap, sync::Arc, task::Poll};
 
-use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
-
 use crate::{
     errors::{try_unwrap_or_throw, CometError, CometResult},
     execution::{
@@ -92,6 +93,8 @@ struct ExecutionContext {
     pub debug_native: bool,
     /// Whether to write native plans with metrics to stdout
     pub explain_native: bool,
+    /// Whether to enable the physical optimizer
+    pub enable_optimizer: bool,
 }
 
 /// Accept serialized query plan and return the address of the native query plan.
@@ -132,6 +135,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         // Whether we've enabled additional debugging on the native side
         let debug_native = parse_bool(&configs, "debug_native")?;
         let explain_native = parse_bool(&configs, "explain_native")?;
+        let enable_optimizer = parse_bool(&configs, "native_optimizer")?;
 
         let worker_threads = configs
             .get("worker_threads")
@@ -184,6 +188,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
             session_ctx: Arc::new(session),
             debug_native,
             explain_native,
+            enable_optimizer,
         });
 
         Ok(Box::into_raw(exec_context) as i64)
@@ -249,7 +254,14 @@ fn prepare_datafusion_session_context(
 
     let runtime = RuntimeEnv::new(rt_config).unwrap();
 
-    let mut session_ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime));
+    let state = SessionStateBuilder::new()
+        .with_config(session_config)
+        .with_runtime_env(Arc::new(runtime))
+        .with_default_features()
+        .with_physical_optimizer_rules(vec![Arc::new(ProjectionPushdown::new())])
+        .build();
+
+    let mut session_ctx = SessionContext::new_with_state(state);
 
     datafusion_functions_nested::register_all(&mut session_ctx)?;
 
@@ -355,6 +367,13 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
                 &mut exec_context.input_sources.clone(),
             )?;
 
+            // optimize the physical plan
+            let root_op = if exec_context.enable_optimizer {
+                planner.optimize_plan(root_op)?
+            } else {
+                root_op
+            };
+
             exec_context.root_op = Some(Arc::clone(&root_op));
             exec_context.scans = scans;
diff --git a/native/core/src/execution/operators/copy.rs b/native/core/src/execution/operators/copy.rs
index d6c095a77..5350747ed 100644
--- a/native/core/src/execution/operators/copy.rs
+++ b/native/core/src/execution/operators/copy.rs
@@ -116,10 +116,9 @@ impl ExecutionPlan for CopyExec {
         self: Arc<Self>,
         children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
-        let input = Arc::clone(&self.input);
-        let new_input = input.with_new_children(children)?;
+        assert!(children.len() == 1);
         Ok(Arc::new(CopyExec {
-            input: new_input,
+            input: Arc::clone(&children[0]),
             schema: Arc::clone(&self.schema),
             cache: self.cache.clone(),
             metrics: self.metrics.clone(),
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index dcdc8ae92..d79099f40 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -23,7 +23,7 @@ import org.apache.spark._
 import org.apache.spark.sql.comet.CometMetricNode
 import org.apache.spark.sql.vectorized._
 
-import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_BLOCKING_THREADS, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION, COMET_EXPLAIN_NATIVE_ENABLED, COMET_WORKER_THREADS}
+import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_BLOCKING_THREADS, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION, COMET_EXPLAIN_NATIVE_ENABLED, COMET_NATIVE_OPTIMIZER_ENABLED, COMET_WORKER_THREADS}
 import org.apache.comet.vector.NativeUtil
 
 /**
@@ -86,6 +86,7 @@ class CometExecIterator(
     result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
     result.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
     result.put("explain_native", String.valueOf(COMET_EXPLAIN_NATIVE_ENABLED.get()))
+    result.put("native_optimizer", String.valueOf(COMET_NATIVE_OPTIMIZER_ENABLED.get()))
     result.put("worker_threads", String.valueOf(COMET_WORKER_THREADS.get()))
     result.put("blocking_threads", String.valueOf(COMET_BLOCKING_THREADS.get()))
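
Usage note: spark.comet.exec.optimizer.enabled is internal and defaults to true, so the DataFusion physical optimizer runs on every native plan unless explicitly disabled. A minimal sketch of toggling it per session for debugging, assuming an active SparkSession with the Comet extension loaded (the session setup is hypothetical; only the config key comes from this patch):

    // Hypothetical: `spark` is an active SparkSession with Comet enabled.
    // Disabling the flag makes executePlan skip planner.optimize_plan().
    spark.conf.set("spark.comet.exec.optimizer.enabled", "false")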