From 86167fd8cb64cd4884246eb0274dcb0336bf4587 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 17 Mar 2024 17:16:31 -0700 Subject: [PATCH] feat: Remove COMET_EXEC_BROADCAST_ENABLED --- .../scala/org/apache/comet/CometConf.scala | 9 +++-- .../comet/CometSparkSessionExtensions.scala | 38 ++++++++++++------- .../apache/comet/exec/CometExecSuite.scala | 4 +- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index de49fdfb0b..bd2e04d0ca 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -139,12 +139,13 @@ object CometConf { .booleanConf .createWithDefault(false) - val COMET_EXEC_BROADCAST_ENABLED: ConfigEntry[Boolean] = + val COMET_EXEC_BROADCAST_FORCE_ENABLED: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.broadcast.enabled") .doc( - "Whether to enable broadcasting for Comet native operators. By default, " + - "this config is false. Note that this feature is not fully supported yet " + - "and only enabled for test purpose.") + "Whether to force enabling broadcasting for Comet native operators. By default, " + + "this config is false. Comet broadcast feature will be enabled automatically by " + + "Comet extension. But for unit tests, we need this feature to force enabling it " + + "for invalid cases. 
So this config is only used for unit test.") .booleanConf .createWithDefault(false) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 87c2265fcb..2d426fde0a 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.SparkSession import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} @@ -42,7 +41,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.comet.CometConf._ -import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported} +import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported} import org.apache.comet.parquet.{CometParquetScan, SupportsComet} import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.QueryPlanSerde @@ -368,15 +367,26 @@ class CometSparkSessionExtensions u } - case b: BroadcastExchangeExec - if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") && - isCometBroadCastEnabled(conf) => - QueryPlanSerde.operator2Proto(b) match { - case Some(nativeOp) => - val 
cometOp = CometBroadcastExchangeExec(b, b.child) - CometSinkPlaceHolder(nativeOp, b, cometOp) - case None => b + // `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast + // exchange. It is only used for Comet native execution. We only transform Spark broadcast + // exchange to Comet broadcast exchange if its downstream is a Comet native plan or if the + // broadcast exchange is forced to be enabled by Comet config. + case plan + if (isCometNative(plan) || isCometBroadCastForceEnabled(conf)) && + plan.children.exists(_.isInstanceOf[BroadcastExchangeExec]) => + val newChildren = plan.children.map { + case b: BroadcastExchangeExec + if isCometNative(b.child) && + isCometOperatorEnabled(conf, "broadcastExchangeExec") => + QueryPlanSerde.operator2Proto(b) match { + case Some(nativeOp) => + val cometOp = CometBroadcastExchangeExec(b, b.child) + CometSinkPlaceHolder(nativeOp, b, cometOp) + case None => b + } + case other => other } + plan.withNewChildren(newChildren) // Native shuffle for Comet operators case s: ShuffleExchangeExec @@ -547,11 +557,13 @@ object CometSparkSessionExtensions extends Logging { private[comet] def isCometOperatorEnabled(conf: SQLConf, operator: String): Boolean = { val operatorFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.enabled" - conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) + val operatorDisabledFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.disabled" + conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) && + !conf.getConfString(operatorDisabledFlag, "false").toBoolean } - private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = { - COMET_EXEC_BROADCAST_ENABLED.get(conf) + private[comet] def isCometBroadCastForceEnabled(conf: SQLConf): Boolean = { + COMET_EXEC_BROADCAST_FORCE_ENABLED.get(conf) } private[comet] def isCometShuffleEnabled(conf: SQLConf): Boolean = diff --git 
a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 6a34d4fe4a..1081acac67 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -73,7 +73,7 @@ class CometExecSuite extends CometTestBase { test("CometBroadcastExchangeExec") { assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") - withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") { + withSQLConf(CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.key -> "true") { withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") { withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_b") { val df = sql( @@ -99,7 +99,7 @@ class CometExecSuite extends CometTestBase { test("CometBroadcastExchangeExec: empty broadcast") { assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") - withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") { + withSQLConf(CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.key -> "true") { withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") { withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_b") { val df = sql(