Skip to content

Commit

Permalink
[CORE] ExpandFallbackPolicy should propagate fallback reason to vanil…
Browse files Browse the repository at this point in the history
…la SparkPlan (#5971)
  • Loading branch information
ulysses-you authored Jun 5, 2024
1 parent 59aaa1c commit c9350fb
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,18 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
}
}

private def fallbackToRowBasedPlan(outputsColumnar: Boolean): SparkPlan = {
private def fallbackToRowBasedPlan(glutenPlan: SparkPlan, outputsColumnar: Boolean): SparkPlan = {
// Propagate fallback reason to vanilla SparkPlan
glutenPlan.foreach {
case _: GlutenPlan =>
case p: SparkPlan if TransformHints.isNotTransformable(p) && p.logicalLink.isDefined =>
originalPlan
.find(_.logicalLink.exists(_.fastEquals(p.logicalLink.get)))
.filterNot(TransformHints.isNotTransformable)
.foreach(origin => TransformHints.tag(origin, TransformHints.getHint(p)))
case _ =>
}

val planWithTransitions = Transitions.insertTransitions(originalPlan, outputsColumnar)
planWithTransitions
}
Expand All @@ -259,7 +270,7 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
// Scan Parquet
// |
// ColumnarToRow
val vanillaSparkPlan = fallbackToRowBasedPlan(outputsColumnar)
val vanillaSparkPlan = fallbackToRowBasedPlan(plan, outputsColumnar)
val vanillaSparkTransitionCost = countTransitionCostForVanillaSparkPlan(vanillaSparkPlan)
if (
GlutenConfig.getConf.fallbackPreferColumnar &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import org.apache.gluten.utils.BackendTestUtils

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.{GlutenSQLTestsTrait, Row}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.ui.{GlutenSQLAppStatusStore, SparkListenerSQLExecutionStart}
import org.apache.spark.status.ElementTrackingStore

Expand Down Expand Up @@ -161,4 +163,67 @@ class GlutenFallbackSuite extends GlutenSQLTestsTrait with AdaptiveSparkPlanHelp
}
}
}

test("Add logical link to rewritten spark plan") {
val events = new ArrayBuffer[GlutenPlanFallbackEvent]
val listener = new SparkListener {
override def onOtherEvent(event: SparkListenerEvent): Unit = {
event match {
case e: GlutenPlanFallbackEvent => events.append(e)
case _ =>
}
}
}
spark.sparkContext.addSparkListener(listener)
withSQLConf(GlutenConfig.EXPRESSION_BLACK_LIST.key -> "add") {
try {
val df = spark.sql("select sum(id + 1) from range(10)")
df.collect()
spark.sparkContext.listenerBus.waitUntilEmpty()
val project = find(df.queryExecution.executedPlan) {
_.isInstanceOf[ProjectExec]
}
assert(project.isDefined)
assert(
events.exists(_.fallbackNodeToReason.values.toSet
.exists(_.contains("Not supported to map spark function name"))))
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}
}

test("ExpandFallbackPolicy should propagate fallback reason to vanilla SparkPlan") {
val events = new ArrayBuffer[GlutenPlanFallbackEvent]
val listener = new SparkListener {
override def onOtherEvent(event: SparkListenerEvent): Unit = {
event match {
case e: GlutenPlanFallbackEvent => events.append(e)
case _ =>
}
}
}
spark.sparkContext.addSparkListener(listener)
spark.range(10).selectExpr("id as c1", "id as c2").write.format("parquet").saveAsTable("t")
withTable("t") {
withSQLConf(
GlutenConfig.EXPRESSION_BLACK_LIST.key -> "max",
GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1") {
try {
val df = spark.sql("select c2, max(c1) as id from t group by c2")
df.collect()
spark.sparkContext.listenerBus.waitUntilEmpty()
val agg = collect(df.queryExecution.executedPlan) { case a: HashAggregateExec => a }
assert(agg.size == 2)
assert(
events.count(
_.fallbackNodeToReason.values.toSet.exists(_.contains(
"Could not find a valid substrait mapping name for max"
))) == 2)
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.{GlutenSQLTestsTrait, Row}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.ui.{GlutenSQLAppStatusStore, SparkListenerSQLExecutionStart}
import org.apache.spark.status.ElementTrackingStore

Expand Down Expand Up @@ -168,18 +169,52 @@ class GlutenFallbackSuite extends GlutenSQLTestsTrait with AdaptiveSparkPlanHelp
withSQLConf(GlutenConfig.EXPRESSION_BLACK_LIST.key -> "add") {
try {
val df = spark.sql("select sum(id + 1) from range(10)")
spark.sparkContext.listenerBus.waitUntilEmpty()
df.collect()
spark.sparkContext.listenerBus.waitUntilEmpty()
val project = find(df.queryExecution.executedPlan) {
_.isInstanceOf[ProjectExec]
}
assert(project.isDefined)
events.exists(
_.fallbackNodeToReason.values.toSet
.contains("Project: Not supported to map spark function name"))
assert(
events.exists(_.fallbackNodeToReason.values.toSet
.exists(_.contains("Not supported to map spark function name"))))
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}
}

test("ExpandFallbackPolicy should propagate fallback reason to vanilla SparkPlan") {
val events = new ArrayBuffer[GlutenPlanFallbackEvent]
val listener = new SparkListener {
override def onOtherEvent(event: SparkListenerEvent): Unit = {
event match {
case e: GlutenPlanFallbackEvent => events.append(e)
case _ =>
}
}
}
spark.sparkContext.addSparkListener(listener)
spark.range(10).selectExpr("id as c1", "id as c2").write.format("parquet").saveAsTable("t")
withTable("t") {
withSQLConf(
GlutenConfig.EXPRESSION_BLACK_LIST.key -> "max",
GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1") {
try {
val df = spark.sql("select c2, max(c1) as id from t group by c2")
df.collect()
spark.sparkContext.listenerBus.waitUntilEmpty()
val agg = collect(df.queryExecution.executedPlan) { case a: HashAggregateExec => a }
assert(agg.size == 2)
assert(
events.count(
_.fallbackNodeToReason.values.toSet.exists(_.contains(
"Could not find a valid substrait mapping name for max"
))) == 2)
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.{GlutenSQLTestsTrait, Row}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.ui.{GlutenSQLAppStatusStore, SparkListenerSQLExecutionStart}
import org.apache.spark.status.ElementTrackingStore

Expand Down Expand Up @@ -168,18 +169,52 @@ class GlutenFallbackSuite extends GlutenSQLTestsTrait with AdaptiveSparkPlanHelp
withSQLConf(GlutenConfig.EXPRESSION_BLACK_LIST.key -> "add") {
try {
val df = spark.sql("select sum(id + 1) from range(10)")
spark.sparkContext.listenerBus.waitUntilEmpty()
df.collect()
spark.sparkContext.listenerBus.waitUntilEmpty()
val project = find(df.queryExecution.executedPlan) {
_.isInstanceOf[ProjectExec]
}
assert(project.isDefined)
events.exists(
_.fallbackNodeToReason.values.toSet
.contains("Project: Not supported to map spark function name"))
assert(
events.exists(_.fallbackNodeToReason.values.toSet
.exists(_.contains("Not supported to map spark function name"))))
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}
}

test("ExpandFallbackPolicy should propagate fallback reason to vanilla SparkPlan") {
val events = new ArrayBuffer[GlutenPlanFallbackEvent]
val listener = new SparkListener {
override def onOtherEvent(event: SparkListenerEvent): Unit = {
event match {
case e: GlutenPlanFallbackEvent => events.append(e)
case _ =>
}
}
}
spark.sparkContext.addSparkListener(listener)
spark.range(10).selectExpr("id as c1", "id as c2").write.format("parquet").saveAsTable("t")
withTable("t") {
withSQLConf(
GlutenConfig.EXPRESSION_BLACK_LIST.key -> "max",
GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1") {
try {
val df = spark.sql("select c2, max(c1) as id from t group by c2")
df.collect()
spark.sparkContext.listenerBus.waitUntilEmpty()
val agg = collect(df.queryExecution.executedPlan) { case a: HashAggregateExec => a }
assert(agg.size == 2)
assert(
events.count(
_.fallbackNodeToReason.values.toSet.exists(_.contains(
"Could not find a valid substrait mapping name for max"
))) == 2)
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}
}
}
}

0 comments on commit c9350fb

Please sign in to comment.