diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala
index 6f8d7cde703b..4ee153173c5c 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala
@@ -235,7 +235,18 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
     }
   }
 
-  private def fallbackToRowBasedPlan(outputsColumnar: Boolean): SparkPlan = {
+  private def fallbackToRowBasedPlan(glutenPlan: SparkPlan, outputsColumnar: Boolean): SparkPlan = {
+    // Propagate fallback reason to vanilla SparkPlan
+    glutenPlan.foreach {
+      case _: GlutenPlan =>
+      case p: SparkPlan if TransformHints.isNotTransformable(p) && p.logicalLink.isDefined =>
+        originalPlan
+          .find(_.logicalLink.exists(_.fastEquals(p.logicalLink.get)))
+          .filterNot(TransformHints.isNotTransformable)
+          .foreach(origin => TransformHints.tag(origin, TransformHints.getHint(p)))
+      case _ =>
+    }
+
     val planWithTransitions = Transitions.insertTransitions(originalPlan, outputsColumnar)
     planWithTransitions
   }
@@ -259,7 +270,7 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
     //  Scan Parquet
     //      |
     //  ColumnarToRow
-    val vanillaSparkPlan = fallbackToRowBasedPlan(outputsColumnar)
+    val vanillaSparkPlan = fallbackToRowBasedPlan(plan, outputsColumnar)
     val vanillaSparkTransitionCost = countTransitionCostForVanillaSparkPlan(vanillaSparkPlan)
     if (
       GlutenConfig.getConf.fallbackPreferColumnar &&
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala
index b85dd6a3518e..6860d6a12958 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala
@@ -23,7 +23,9 @@ import org.apache.gluten.utils.BackendTestUtils
 
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
 import org.apache.spark.sql.{GlutenSQLTestsTrait, Row}
+import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.ui.{GlutenSQLAppStatusStore, SparkListenerSQLExecutionStart}
 import org.apache.spark.status.ElementTrackingStore
 
@@ -161,4 +163,67 @@ class GlutenFallbackSuite extends GlutenSQLTestsTrait with AdaptiveSparkPlanHelp
       }
     }
   }
+
+  test("Add logical link to rewritten spark plan") {
+    val events = new ArrayBuffer[GlutenPlanFallbackEvent]
+    val listener = new SparkListener {
+      override def onOtherEvent(event: SparkListenerEvent): Unit = {
+        event match {
+          case e: GlutenPlanFallbackEvent => events.append(e)
+          case _ =>
+        }
+      }
+    }
+    spark.sparkContext.addSparkListener(listener)
+    withSQLConf(GlutenConfig.EXPRESSION_BLACK_LIST.key -> "add") {
+      try {
+        val df = spark.sql("select sum(id + 1) from range(10)")
+        df.collect()
+        spark.sparkContext.listenerBus.waitUntilEmpty()
+        val project = find(df.queryExecution.executedPlan) {
+          _.isInstanceOf[ProjectExec]
+        }
+        assert(project.isDefined)
+        assert(
+          events.exists(_.fallbackNodeToReason.values.toSet
+            .exists(_.contains("Not supported to map spark function name"))))
+      } finally {
+        spark.sparkContext.removeSparkListener(listener)
+      }
+    }
+  }
+
+  test("ExpandFallbackPolicy should propagate fallback reason to vanilla SparkPlan") {
+    val events = new ArrayBuffer[GlutenPlanFallbackEvent]
+    val listener = new SparkListener {
+      override def onOtherEvent(event: SparkListenerEvent): Unit = {
+        event match {
+          case e: GlutenPlanFallbackEvent => events.append(e)
+          case _ =>
+        }
+      }
+    }
+    spark.sparkContext.addSparkListener(listener)
+    spark.range(10).selectExpr("id as c1", "id as c2").write.format("parquet").saveAsTable("t")
+    withTable("t") {
+      withSQLConf(
+        GlutenConfig.EXPRESSION_BLACK_LIST.key -> "max",
+        GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1") {
+        try {
+          val df = spark.sql("select c2, max(c1) as id from t group by c2")
+          df.collect()
+          spark.sparkContext.listenerBus.waitUntilEmpty()
+          val agg = collect(df.queryExecution.executedPlan) { case a: HashAggregateExec => a }
+          assert(agg.size == 2)
+          assert(
+            events.count(
+              _.fallbackNodeToReason.values.toSet.exists(_.contains(
+                "Could not find a valid substrait mapping name for max"
+              ))) == 2)
+        } finally {
+          spark.sparkContext.removeSparkListener(listener)
+        }
+      }
+    }
+  }
 }
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala
index 9e8c7e54291a..fd6aa047558f 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
 import org.apache.spark.sql.{GlutenSQLTestsTrait, Row}
 import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.ui.{GlutenSQLAppStatusStore, SparkListenerSQLExecutionStart}
 import org.apache.spark.status.ElementTrackingStore
 
@@ -168,18 +169,52 @@ class GlutenFallbackSuite extends GlutenSQLTestsTrait with AdaptiveSparkPlanHelp
     withSQLConf(GlutenConfig.EXPRESSION_BLACK_LIST.key -> "add") {
       try {
         val df = spark.sql("select sum(id + 1) from range(10)")
-        spark.sparkContext.listenerBus.waitUntilEmpty()
         df.collect()
+        spark.sparkContext.listenerBus.waitUntilEmpty()
         val project = find(df.queryExecution.executedPlan) {
           _.isInstanceOf[ProjectExec]
         }
         assert(project.isDefined)
-        events.exists(
-          _.fallbackNodeToReason.values.toSet
-            .contains("Project: Not supported to map spark function name"))
+        assert(
+          events.exists(_.fallbackNodeToReason.values.toSet
+            .exists(_.contains("Not supported to map spark function name"))))
       } finally {
         spark.sparkContext.removeSparkListener(listener)
       }
     }
   }
+
+  test("ExpandFallbackPolicy should propagate fallback reason to vanilla SparkPlan") {
+    val events = new ArrayBuffer[GlutenPlanFallbackEvent]
+    val listener = new SparkListener {
+      override def onOtherEvent(event: SparkListenerEvent): Unit = {
+        event match {
+          case e: GlutenPlanFallbackEvent => events.append(e)
+          case _ =>
+        }
+      }
+    }
+    spark.sparkContext.addSparkListener(listener)
+    spark.range(10).selectExpr("id as c1", "id as c2").write.format("parquet").saveAsTable("t")
+    withTable("t") {
+      withSQLConf(
+        GlutenConfig.EXPRESSION_BLACK_LIST.key -> "max",
+        GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1") {
+        try {
+          val df = spark.sql("select c2, max(c1) as id from t group by c2")
+          df.collect()
+          spark.sparkContext.listenerBus.waitUntilEmpty()
+          val agg = collect(df.queryExecution.executedPlan) { case a: HashAggregateExec => a }
+          assert(agg.size == 2)
+          assert(
+            events.count(
+              _.fallbackNodeToReason.values.toSet.exists(_.contains(
+                "Could not find a valid substrait mapping name for max"
+              ))) == 2)
+        } finally {
+          spark.sparkContext.removeSparkListener(listener)
+        }
+      }
+    }
+  }
 }
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala
index 9e8c7e54291a..fd6aa047558f 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
 import org.apache.spark.sql.{GlutenSQLTestsTrait, Row}
 import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.ui.{GlutenSQLAppStatusStore, SparkListenerSQLExecutionStart}
 import org.apache.spark.status.ElementTrackingStore
 
@@ -168,18 +169,52 @@ class GlutenFallbackSuite extends GlutenSQLTestsTrait with AdaptiveSparkPlanHelp
     withSQLConf(GlutenConfig.EXPRESSION_BLACK_LIST.key -> "add") {
       try {
         val df = spark.sql("select sum(id + 1) from range(10)")
-        spark.sparkContext.listenerBus.waitUntilEmpty()
         df.collect()
+        spark.sparkContext.listenerBus.waitUntilEmpty()
         val project = find(df.queryExecution.executedPlan) {
           _.isInstanceOf[ProjectExec]
         }
        assert(project.isDefined)
-        events.exists(
-          _.fallbackNodeToReason.values.toSet
-            .contains("Project: Not supported to map spark function name"))
+        assert(
+          events.exists(_.fallbackNodeToReason.values.toSet
+            .exists(_.contains("Not supported to map spark function name"))))
       } finally {
         spark.sparkContext.removeSparkListener(listener)
       }
     }
   }
+
+  test("ExpandFallbackPolicy should propagate fallback reason to vanilla SparkPlan") {
+    val events = new ArrayBuffer[GlutenPlanFallbackEvent]
+    val listener = new SparkListener {
+      override def onOtherEvent(event: SparkListenerEvent): Unit = {
+        event match {
+          case e: GlutenPlanFallbackEvent => events.append(e)
+          case _ =>
+        }
+      }
+    }
+    spark.sparkContext.addSparkListener(listener)
+    spark.range(10).selectExpr("id as c1", "id as c2").write.format("parquet").saveAsTable("t")
+    withTable("t") {
+      withSQLConf(
+        GlutenConfig.EXPRESSION_BLACK_LIST.key -> "max",
+        GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1") {
+        try {
+          val df = spark.sql("select c2, max(c1) as id from t group by c2")
+          df.collect()
+          spark.sparkContext.listenerBus.waitUntilEmpty()
+          val agg = collect(df.queryExecution.executedPlan) { case a: HashAggregateExec => a }
+          assert(agg.size == 2)
+          assert(
+            events.count(
+              _.fallbackNodeToReason.values.toSet.exists(_.contains(
+                "Could not find a valid substrait mapping name for max"
+              ))) == 2)
+        } finally {
+          spark.sparkContext.removeSparkListener(listener)
+        }
+      }
+    }
+  }
 }