From c7bce952900a60959c633007b06e55344ac22824 Mon Sep 17 00:00:00 2001
From: Zhen Li <10524738+zhli1142015@users.noreply.github.com>
Date: Wed, 22 Nov 2023 16:29:46 +0800
Subject: [PATCH] [GLUTEN-3801][VL] Fix isAdaptiveContext null value for
 ColumnarOverrideRules (#3795)

---
 .../extension/ColumnarOverrides.scala         | 32 ++++++++++++-------
 .../execution/VeloxDeltaSuite.scala           | 18 +++++++++++
 2 files changed, 39 insertions(+), 11 deletions(-)

diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
index 3c12d0c56e22..f6c2b2ac90a9 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
@@ -45,6 +45,8 @@ import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.hive.HiveTableScanExecTransformer
 import org.apache.spark.util.SparkRuleUtil
 
+import scala.collection.mutable.ListBuffer
+
 // This rule will conduct the conversion from Spark plan to the plan transformer.
 case class TransformPreOverrides(isAdaptiveContext: Boolean)
   extends Rule[SparkPlan]
@@ -749,7 +751,10 @@ case class ColumnarOverrideRules(session: SparkSession)
   private val aqeStackTraceIndex = 14
 
   // Holds the original plan for possible entire fallback.
-  private val localOriginalPlan = new ThreadLocal[SparkPlan]
+  private val localOriginalPlans: ThreadLocal[ListBuffer[SparkPlan]] =
+    ThreadLocal.withInitial(() => ListBuffer.empty[SparkPlan])
+  private val localIsAdaptiveContextFlags: ThreadLocal[ListBuffer[Boolean]] =
+    ThreadLocal.withInitial(() => ListBuffer.empty[Boolean])
 
   // Do not create rules in class initialization as we should access SQLConf
   // while creating the rules. At this time SQLConf may not be there yet.
@@ -760,7 +765,10 @@ case class ColumnarOverrideRules(session: SparkSession)
   }
 
   def isAdaptiveContext: Boolean =
-    session.sparkContext.getLocalProperty(GLUTEN_IS_ADAPTIVE_CONTEXT).toBoolean
+    Option(session.sparkContext.getLocalProperty(GLUTEN_IS_ADAPTIVE_CONTEXT))
+      .getOrElse("false")
+      .toBoolean ||
+      localIsAdaptiveContextFlags.get().head
 
   private def setAdaptiveContext(): Unit = {
     val traceElements = Thread.currentThread.getStackTrace
@@ -772,25 +780,27 @@ case class ColumnarOverrideRules(session: SparkSession)
     // columnar rule will be applied in adaptive execution context. This part of code
     // needs to be carefully checked when supporting higher versions of spark to make
     // sure the calling stack has not been changed.
-    session.sparkContext.setLocalProperty(
-      GLUTEN_IS_ADAPTIVE_CONTEXT,
-      traceElements(aqeStackTraceIndex).getClassName
-        .equals(AdaptiveSparkPlanExec.getClass.getName)
-        .toString)
+    localIsAdaptiveContextFlags
+      .get()
+      .prepend(
+        traceElements(aqeStackTraceIndex).getClassName
+          .equals(AdaptiveSparkPlanExec.getClass.getName))
   }
 
   private def resetAdaptiveContext(): Unit =
-    session.sparkContext.setLocalProperty(GLUTEN_IS_ADAPTIVE_CONTEXT, null)
+    localIsAdaptiveContextFlags.get().remove(0)
 
-  private def setOriginalPlan(plan: SparkPlan): Unit = localOriginalPlan.set(plan)
+  private def setOriginalPlan(plan: SparkPlan): Unit = {
+    localOriginalPlans.get.prepend(plan)
+  }
 
   private def originalPlan: SparkPlan = {
-    val plan = localOriginalPlan.get()
+    val plan = localOriginalPlans.get.head
     assert(plan != null)
     plan
   }
 
-  private def resetOriginalPlan(): Unit = localOriginalPlan.remove()
+  private def resetOriginalPlan(): Unit = localOriginalPlans.get.remove(0)
 
   private def preOverrides(): List[SparkSession => Rule[SparkPlan]] = {
     val tagBeforeTransformHitsRules =
diff --git a/gluten-delta/src/test/scala/io/glutenproject/execution/VeloxDeltaSuite.scala b/gluten-delta/src/test/scala/io/glutenproject/execution/VeloxDeltaSuite.scala
index 3f1f7ab255fa..7e096151b16c 100644
--- a/gluten-delta/src/test/scala/io/glutenproject/execution/VeloxDeltaSuite.scala
+++ b/gluten-delta/src/test/scala/io/glutenproject/execution/VeloxDeltaSuite.scala
@@ -55,4 +55,22 @@ class VeloxDeltaSuite extends WholeStageTransformerSuite {
     checkLengthAndPlan(df2, 1)
     checkAnswer(df2, Row("v2") :: Nil)
   }
+
+  test("basic test with stats.skipping disabled") {
+    withSQLConf("spark.databricks.delta.stats.skipping" -> "false") {
+      spark.sql(s"""
+                   |create table delta_test2 (id int, name string) using delta
+                   |""".stripMargin)
+      spark.sql(s"""
+                   |insert into delta_test2 values (1, "v1"), (2, "v2")
+                   |""".stripMargin)
+      val df1 = runQueryAndCompare("select * from delta_test2") { _ => }
+      checkLengthAndPlan(df1, 2)
+      checkAnswer(df1, Row(1, "v1") :: Row(2, "v2") :: Nil)
+
+      val df2 = runQueryAndCompare("select name from delta_test2 where id = 2") { _ => }
+      checkLengthAndPlan(df2, 1)
+      checkAnswer(df2, Row("v2") :: Nil)
+    }
+  }
 }
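
Note (not part of the patch): the old code read `getLocalProperty(GLUTEN_IS_ADAPTIVE_CONTEXT)`, which returns null when the property was never set, so `.toBoolean` could throw a NullPointerException. The patch null-guards that read with `Option(...).getOrElse("false")` and moves the per-query state into thread-local stacks, so that a nested application of the columnar rule (AQE re-entering it while an outer, non-adaptive application is still in flight on the same thread) does not clobber the outer context. Below is a minimal, self-contained sketch of that stack pattern; the names (ContextStack, push, pop, top, withContext) are illustrative only and do not appear in the PR.

import scala.collection.mutable.ListBuffer

// Sketch of the per-thread stack pattern the patch introduces.
// All names here are hypothetical, chosen for illustration.
object ContextStack {
  private val flags: ThreadLocal[ListBuffer[Boolean]] =
    ThreadLocal.withInitial(() => ListBuffer.empty[Boolean])

  // prepend/remove(0) turn the ListBuffer into a stack: the innermost
  // (most recently entered) context always sits at the head.
  def push(isAdaptive: Boolean): Unit = flags.get().prepend(isAdaptive)
  def pop(): Unit = flags.get().remove(0)
  def top: Boolean = flags.get().head

  // Enter a context, run the body, and restore the enclosing context
  // even if the body throws.
  def withContext[T](isAdaptive: Boolean)(body: => T): T = {
    push(isAdaptive)
    try body
    finally pop()
  }
}

With a single thread-local value (the old ThreadLocal[SparkPlan], or the SparkContext local property), a nested entry overwrites the outer state and the unwind leaves it stale or null; with the stack, popping on exit restores the enclosing context, and reading the head still works when the same thread plans several queries back to back.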