[VL] Minor refactors on ColumnarRuleApplier #6086

Merged
merged 4 commits on Jun 14, 2024
Changes from all commits
@@ -16,8 +16,44 @@
*/
package org.apache.gluten.extension.columnar

import org.apache.gluten.GlutenConfig
import org.apache.gluten.metrics.GlutenTimeMetric
import org.apache.gluten.utils.LogLevelUtil

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.execution.SparkPlan

trait ColumnarRuleApplier {
def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan
}

object ColumnarRuleApplier {
class Executor(phase: String, rules: Seq[Rule[SparkPlan]]) extends RuleExecutor[SparkPlan] {
private val batch: Batch =
Batch(s"Columnar (Phase [$phase])", Once, rules.map(r => new LoggedRule(r)): _*)

// TODO: Remove this exclusion and make the batch pass Spark's idempotence check.
override protected val excludedOnceBatches: Set[String] = Set(batch.name)

override protected def batches: Seq[Batch] = List(batch)
}

private class LoggedRule(delegate: Rule[SparkPlan])
extends Rule[SparkPlan]
with Logging
with LogLevelUtil {
// Columnar plan change logging was added in https://github.com/apache/incubator-gluten/pull/456.
private val transformPlanLogLevel = GlutenConfig.getConf.transformPlanLogLevel
override val ruleName: String = delegate.ruleName

override def apply(plan: SparkPlan): SparkPlan = GlutenTimeMetric.withMillisTime {
logOnLevel(
transformPlanLogLevel,
s"Preparing to apply rule $ruleName on plan:\n${plan.toString}")
val out = delegate.apply(plan)
logOnLevel(transformPlanLogLevel, s"Plan after applying rule $ruleName:\n${out.toString}")
out
}(t => logOnLevel(transformPlanLogLevel, s"Applying rule $ruleName took $t ms."))
}
}
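
For orientation, a minimal usage sketch (not part of the diff) of the new `ColumnarRuleApplier.Executor`: `ExecutorSketch` and `NoopRule` are hypothetical names, and an active SparkSession with Gluten on the classpath is assumed, since `LoggedRule` reads `GlutenConfig` when the executor is constructed.

```scala
import org.apache.gluten.extension.columnar.ColumnarRuleApplier
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

object ExecutorSketch {
  // Hypothetical no-op rule; any Rule[SparkPlan] can be handed to the executor.
  object NoopRule extends Rule[SparkPlan] {
    override def apply(plan: SparkPlan): SparkPlan = plan
  }

  def runTransformPhase(plan: SparkPlan): SparkPlan = {
    // The rules become one once-only batch named "Columnar (Phase [transform])"; each rule is
    // wrapped in LoggedRule, which logs the plan before/after and how long the rule took.
    val executor = new ColumnarRuleApplier.Executor("transform", Seq(NoopRule))
    executor.execute(plan)
  }
}
```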
@@ -22,13 +22,12 @@ import org.apache.gluten.extension.columnar._
import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast}
import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions}
import org.apache.gluten.extension.columnar.util.AdaptiveContext
import org.apache.gluten.metrics.GlutenTimeMetric
import org.apache.gluten.utils.{LogLevelUtil, PhysicalPlanSelector}

import org.apache.spark.annotation.Experimental
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter, SparkPlan}
import org.apache.spark.util.SparkRuleUtil

@@ -47,41 +46,26 @@ class EnumeratedApplier(session: SparkSession)
with LogLevelUtil {
// An empirical value.
private val aqeStackTraceIndex = 16

private lazy val transformPlanLogLevel = GlutenConfig.getConf.transformPlanLogLevel
private lazy val planChangeLogger = new PlanChangeLogger[SparkPlan]()

private val adaptiveContext = AdaptiveContext(session, aqeStackTraceIndex)

override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan =
PhysicalPlanSelector.maybe(session, plan) {
val transformed = transformPlan(transformRules(outputsColumnar), plan, "transform")
val transformed =
transformPlan("transform", transformRules(outputsColumnar).map(_(session)), plan)
val postPlan = maybeAqe {
transformPlan(postRules(), transformed, "post")
transformPlan("post", postRules().map(_(session)), transformed)
}
val finalPlan = transformPlan(finalRules(), postPlan, "final")
val finalPlan = transformPlan("final", finalRules().map(_(session)), postPlan)
finalPlan
}

private def transformPlan(
getRules: List[SparkSession => Rule[SparkPlan]],
plan: SparkPlan,
step: String) = GlutenTimeMetric.withMillisTime {
logOnLevel(
transformPlanLogLevel,
s"${step}ColumnarTransitions preOverriden plan:\n${plan.toString}")
val overridden = getRules.foldLeft(plan) {
(p, getRule) =>
val rule = getRule(session)
val newPlan = rule(p)
planChangeLogger.logRule(rule.ruleName, p, newPlan)
newPlan
}
logOnLevel(
transformPlanLogLevel,
s"${step}ColumnarTransitions afterOverriden plan:\n${overridden.toString}")
overridden
}(t => logOnLevel(transformPlanLogLevel, s"${step}Transform SparkPlan took: $t ms."))
phase: String,
rules: Seq[Rule[SparkPlan]],
plan: SparkPlan): SparkPlan = {
val executor = new ColumnarRuleApplier.Executor(phase, rules)
executor.execute(plan)
}

private def maybeAqe[T](f: => T): T = {
adaptiveContext.setAdaptiveContext()
@@ -96,7 +80,7 @@ class EnumeratedApplier(session: SparkSession)
* Rules that let the planner create a suggested Gluten plan to be sent to `fallbackPolicies`, in
* which the plan will be broken down and a decision made on whether to fall it back or not.
*/
private def transformRules(outputsColumnar: Boolean): List[SparkSession => Rule[SparkPlan]] = {
private def transformRules(outputsColumnar: Boolean): Seq[SparkSession => Rule[SparkPlan]] = {
List(
(_: SparkSession) => RemoveTransitions,
(spark: SparkSession) => FallbackOnANSIMode(spark),
@@ -126,7 +110,7 @@
* Rules applied to non-fallen-back Gluten plans, doing post-cleanup work on the plan to make sure
* it is able to run and is compatible with Spark's execution engine.
*/
private def postRules(): List[SparkSession => Rule[SparkPlan]] =
private def postRules(): Seq[SparkSession => Rule[SparkPlan]] =
List(
(s: SparkSession) => RemoveTopmostColumnarToRow(s, adaptiveContext.isAdaptiveContext())) :::
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarPostRules() :::
@@ -137,7 +121,7 @@
* Rules consistently applied to all input plans after all other rules have been applied,
* regardless of whether the input plan is fallen back or not.
*/
private def finalRules(): List[SparkSession => Rule[SparkPlan]] = {
private def finalRules(): Seq[SparkSession => Rule[SparkPlan]] = {
List(
// The rule is required regardless of whether the stage is fallen back or not. Since
// ColumnarCachedBatchSerializer is statically registered to Spark without a columnar rule
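
A small sketch of the rule-builder pattern used here: rules are kept as `Seq[SparkSession => Rule[SparkPlan]]` builders and bound to the session with `.map(_(session))` right before a phase runs. `RuleBuilderSketch` and `MyExplainRule` are hypothetical stand-ins for illustration only.

```scala
import org.apache.gluten.extension.columnar.ColumnarRuleApplier
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

object RuleBuilderSketch {
  // Hypothetical session-dependent rule; stands in for rules like FallbackOnANSIMode(spark).
  case class MyExplainRule(spark: SparkSession) extends Rule[SparkPlan] {
    override def apply(plan: SparkPlan): SparkPlan = plan
  }

  // Rule lists are kept as builders so that session-dependent rules are constructed fresh
  // for each application of the applier.
  private val builders: Seq[SparkSession => Rule[SparkPlan]] =
    List((spark: SparkSession) => MyExplainRule(spark))

  def runPostPhase(session: SparkSession, plan: SparkPlan): SparkPlan = {
    // Bind the session to every builder, then hand the concrete rules to the executor.
    val rules: Seq[Rule[SparkPlan]] = builders.map(_(session))
    new ColumnarRuleApplier.Executor("post", rules).execute(plan)
  }
}
```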
@@ -23,12 +23,11 @@ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTable
import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions}
import org.apache.gluten.extension.columnar.util.AdaptiveContext
import org.apache.gluten.metrics.GlutenTimeMetric
import org.apache.gluten.utils.{LogLevelUtil, PhysicalPlanSelector}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter, SparkPlan}
import org.apache.spark.util.SparkRuleUtil

@@ -42,54 +41,39 @@ class HeuristicApplier(session: SparkSession)
with LogLevelUtil {
// This is an empirical value; it may need to be changed to support other versions of Spark.
private val aqeStackTraceIndex = 19

private lazy val transformPlanLogLevel = GlutenConfig.getConf.transformPlanLogLevel
private lazy val planChangeLogger = new PlanChangeLogger[SparkPlan]()

private val adaptiveContext = AdaptiveContext(session, aqeStackTraceIndex)

override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan =
override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = {
withTransformRules(transformRules(outputsColumnar)).apply(plan)
}

// Visible for testing.
def withTransformRules(transformRules: List[SparkSession => Rule[SparkPlan]]): Rule[SparkPlan] =
def withTransformRules(transformRules: Seq[SparkSession => Rule[SparkPlan]]): Rule[SparkPlan] =
plan =>
PhysicalPlanSelector.maybe(session, plan) {
val finalPlan = prepareFallback(plan) {
p =>
val suggestedPlan = transformPlan(transformRules, p, "transform")
transformPlan(fallbackPolicies(), suggestedPlan, "fallback") match {
val suggestedPlan = transformPlan("transform", transformRules.map(_(session)), p)
transformPlan("fallback", fallbackPolicies().map(_(session)), suggestedPlan) match {
case FallbackNode(fallbackPlan) =>
// We should use vanilla c2r rather than native c2r,
// and there should be no `GlutenPlan` anymore,
// so skip the `postRules()`.
fallbackPlan
case plan =>
transformPlan(postRules(), plan, "post")
transformPlan("post", postRules().map(_(session)), plan)
}
}
transformPlan(finalRules(), finalPlan, "final")
transformPlan("final", finalRules().map(_(session)), finalPlan)
}

private def transformPlan(
getRules: List[SparkSession => Rule[SparkPlan]],
plan: SparkPlan,
step: String) = GlutenTimeMetric.withMillisTime {
logOnLevel(
transformPlanLogLevel,
s"${step}ColumnarTransitions preOverridden plan:\n${plan.toString}")
val overridden = getRules.foldLeft(plan) {
(p, getRule) =>
val rule = getRule(session)
val newPlan = rule(p)
planChangeLogger.logRule(rule.ruleName, p, newPlan)
newPlan
}
logOnLevel(
transformPlanLogLevel,
s"${step}ColumnarTransitions afterOverridden plan:\n${overridden.toString}")
overridden
}(t => logOnLevel(transformPlanLogLevel, s"${step}Transform SparkPlan took: $t ms."))
phase: String,
rules: Seq[Rule[SparkPlan]],
plan: SparkPlan): SparkPlan = {
val executor = new ColumnarRuleApplier.Executor(phase, rules)
executor.execute(plan)
}

private def prepareFallback[T](plan: SparkPlan)(f: SparkPlan => T): T = {
adaptiveContext.setAdaptiveContext()
@@ -106,7 +90,7 @@ class HeuristicApplier(session: SparkSession)
* Rules that let the planner create a suggested Gluten plan to be sent to `fallbackPolicies`, in
* which the plan will be broken down and a decision made on whether to fall it back or not.
*/
private def transformRules(outputsColumnar: Boolean): List[SparkSession => Rule[SparkPlan]] = {
private def transformRules(outputsColumnar: Boolean): Seq[SparkSession => Rule[SparkPlan]] = {
List(
(_: SparkSession) => RemoveTransitions,
(spark: SparkSession) => FallbackOnANSIMode(spark),
@@ -138,7 +122,7 @@
* Rules that add wrapper `FallbackNode`s on top of the input plan, as hints to make the planner
* fall back the whole input plan to the original vanilla Spark plan.
*/
private def fallbackPolicies(): List[SparkSession => Rule[SparkPlan]] = {
private def fallbackPolicies(): Seq[SparkSession => Rule[SparkPlan]] = {
List(
(_: SparkSession) =>
ExpandFallbackPolicy(adaptiveContext.isAdaptiveContext(), adaptiveContext.originalPlan()))
@@ -148,7 +132,7 @@
* Rules applied to non-fallen-back Gluten plans, doing post-cleanup work on the plan to make sure
* it is able to run and is compatible with Spark's execution engine.
*/
private def postRules(): List[SparkSession => Rule[SparkPlan]] =
private def postRules(): Seq[SparkSession => Rule[SparkPlan]] =
List(
(s: SparkSession) => RemoveTopmostColumnarToRow(s, adaptiveContext.isAdaptiveContext())) :::
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarPostRules() :::
@@ -159,7 +143,7 @@
* Rules consistently applied to all input plans after all other rules have been applied,
* regardless of whether the input plan is fallen back or not.
*/
private def finalRules(): List[SparkSession => Rule[SparkPlan]] = {
private def finalRules(): Seq[SparkSession => Rule[SparkPlan]] = {
List(
// The rule is required regardless of whether the stage is fallen back or not. Since
// ColumnarCachedBatchSerializer is statically registered to Spark without a columnar rule
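
Because `withTransformRules` stays visible for testing, a test can drive the HeuristicApplier with a reduced rule list. The sketch below is illustrative only; it takes the method as a plain function value so it makes no assumption about the applier's package, and `WithTransformRulesSketch` is a hypothetical name.

```scala
import org.apache.gluten.extension.columnar.transition.RemoveTransitions
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

object WithTransformRulesSketch {
  // The applier's `withTransformRules` method is passed in as a function value, so this sketch
  // does not depend on how the HeuristicApplier itself is constructed or where it lives.
  def applyOnlyRemoveTransitions(
      withTransformRules: Seq[SparkSession => Rule[SparkPlan]] => Rule[SparkPlan],
      plan: SparkPlan): SparkPlan = {
    // A reduced rule list: strip existing transitions and do nothing else.
    val rule = withTransformRules(List((_: SparkSession) => RemoveTransitions))
    rule.apply(plan)
  }

  // Usage from a test, with `applier` being an already constructed HeuristicApplier:
  //   applyOnlyRemoveTransitions(applier.withTransformRules, somePlan)
}
```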
@@ -18,13 +18,9 @@ package org.apache.spark.sql

import org.apache.spark.{SparkContext, Success, TaskKilled}
import org.apache.spark.executor.ExecutorMetrics
import org.apache.spark.scheduler.{
SparkListener,
SparkListenerExecutorMetricsUpdate,
SparkListenerTaskEnd,
SparkListenerTaskStart
}
import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate, SparkListenerTaskEnd, SparkListenerTaskStart}
import org.apache.spark.sql.KillTaskListener.INIT_WAIT_TIME_MS
import org.apache.spark.sql.catalyst.QueryPlanningTracker

import com.google.common.base.Preconditions
import org.apache.commons.lang3.RandomUtils
@@ -50,7 +46,8 @@ object SparkQueryRunner {
"ProcessTreePythonVMemory",
"ProcessTreePythonRSSMemory",
"ProcessTreeOtherVMemory",
"ProcessTreeOtherRSSMemory")
"ProcessTreeOtherRSSMemory"
)

def runQuery(
spark: SparkSession,
@@ -82,25 +79,33 @@

println(s"Executing SQL query from resource path $queryPath...")
try {
val tracker = new QueryPlanningTracker
val sql = resourceToString(queryPath)
val prev = System.nanoTime()
val df = spark.sql(sql)
val rows = df.collect()
val rows = QueryPlanningTracker.withTracker(tracker) {
df.collect()
}
if (explain) {
df.explain(extended = true)
}
val planMillis =
df.queryExecution.tracker.phases.values.map(p => p.endTimeMs - p.startTimeMs).sum
val sparkTracker = df.queryExecution.tracker
val sparkRulesMillis =
sparkTracker.rules.map(_._2.totalTimeNs).sum / 1000000L
val otherRulesMillis =
tracker.rules.map(_._2.totalTimeNs).sum / 1000000L
val planMillis = sparkRulesMillis + otherRulesMillis
val totalMillis = (System.nanoTime() - prev) / 1000000L
val collectedMetrics = metrics.map(name => (name, em.getMetricValue(name))).toMap
RunResult(rows, planMillis, totalMillis - planMillis, collectedMetrics)
} finally {
sc.removeSparkListener(metricsListener)
killTaskListener.foreach(l => {
sc.removeSparkListener(l)
println(s"Successful kill rate ${"%.2f%%"
.format(100 * l.successfulKillRate())} during execution of app: ${sc.applicationId}")
})
killTaskListener.foreach(
l => {
sc.removeSparkListener(l)
println(s"Successful kill rate ${"%.2f%%"
.format(100 * l.successfulKillRate())} during execution of app: ${sc.applicationId}")
})
sc.setJobDescription(null)
}
}
@@ -166,7 +171,8 @@ class KillTaskListener(val sc: SparkContext) extends SparkListener {
val total = Math.min(
stageKillMaxWaitTimeLookup.computeIfAbsent(taskStart.stageId, _ => Long.MaxValue),
stageKillWaitTimeLookup
.computeIfAbsent(taskStart.stageId, _ => INIT_WAIT_TIME_MS))
.computeIfAbsent(taskStart.stageId, _ => INIT_WAIT_TIME_MS)
)
val elapsed = System.currentTimeMillis() - startMs
val remaining = total - elapsed
if (remaining <= 0L) {
@@ -180,6 +186,7 @@
}
throw new IllegalStateException()
}

val elapsed = wait()

// We have a 50% chance of killing the task. FIXME: make it configurable?
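
The new plan-time measurement in SparkQueryRunner can be read as the following sketch (illustrative only; `RuleTimeSketch` is a hypothetical name): the query's own tracker accounts for analysis, optimization, and planning rules, while a caller-supplied tracker registered through `QueryPlanningTracker.withTracker` captures rules applied while the query runs, for example during adaptive re-planning.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.QueryPlanningTracker

object RuleTimeSketch {
  // Returns total rule time in milliseconds: rule time recorded on the query's own tracker
  // plus rule time captured by a caller-provided tracker during execution.
  def planMillis(df: DataFrame): Long = {
    val execTracker = new QueryPlanningTracker
    QueryPlanningTracker.withTracker(execTracker) {
      df.collect()
    }
    val analysisAndPlanningNs = df.queryExecution.tracker.rules.values.map(_.totalTimeNs).sum
    val executionPlanningNs = execTracker.rules.values.map(_.totalTimeNs).sum
    (analysisAndPlanningNs + executionPlanningNs) / 1000000L
  }
}
```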