From 24d650f71c38ac709689e6af37bf8eb78355794e Mon Sep 17 00:00:00 2001 From: Mingliang Zhu Date: Fri, 6 Sep 2024 16:57:38 +0800 Subject: [PATCH] [VL] Fix function `input_file_name()` outputs empty string in certain query plan patterns (#7124) --- .../backendsapi/velox/VeloxBackend.scala | 2 - .../backendsapi/velox/VeloxRuleApi.scala | 2 + .../ScalarFunctionsValidateSuite.scala | 17 +- .../backendsapi/BackendSettingsApi.scala | 1 - .../columnar/MiscColumnarRules.scala | 2 +- .../columnar/OffloadSingleNode.scala | 152 +----------------- .../PushDownInputFileExpression.scala | 118 ++++++++++++++ 7 files changed, 142 insertions(+), 152 deletions(-) create mode 100644 gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index 611e9c15bf48a..ecd053a632f39 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -282,8 +282,6 @@ object VeloxBackendSettings extends BackendSettingsApi { override def supportNativeRowIndexColumn(): Boolean = true - override def supportNativeInputFileRelatedExpr(): Boolean = true - override def supportExpandExec(): Boolean = true override def supportSortExec(): Boolean = true diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index 4ff7f0305d58c..b278224b2a822 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -52,6 +52,7 @@ private object VeloxRuleApi { def injectLegacy(injector: LegacyInjector): Unit = { // Gluten columnar: Transform rules. injector.injectTransform(_ => RemoveTransitions) + injector.injectTransform(_ => PushDownInputFileExpression.PreOffload) injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) injector.injectTransform(c => FallbackMultiCodegens.apply(c.session)) injector.injectTransform(c => PlanOneRowRelation.apply(c.session)) @@ -64,6 +65,7 @@ private object VeloxRuleApi { injector.injectTransform(_ => TransformPreOverrides()) injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject()) injector.injectTransform(c => RewriteTransformer.apply(c.session)) + injector.injectTransform(_ => PushDownInputFileExpression.PostOffload) injector.injectTransform(_ => EnsureLocalSortRequirements) injector.injectTransform(_ => EliminateLocalSort) injector.injectTransform(_ => CollapseProjectExecTransformer) diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala index 81da24f8ed47b..a376fd488dd73 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala @@ -28,7 +28,7 @@ class ScalarFunctionsValidateSuiteRasOff extends ScalarFunctionsValidateSuite { super.sparkConf .set("spark.gluten.ras.enabled", "false") } - + import testImplicits._ // Since https://github.com/apache/incubator-gluten/pull/6200. test("Test input_file_name function") { runQueryAndCompare("""SELECT input_file_name(), l_orderkey @@ -44,6 +44,21 @@ class ScalarFunctionsValidateSuiteRasOff extends ScalarFunctionsValidateSuite { | limit 100""".stripMargin) { checkGlutenOperatorMatch[ProjectExecTransformer] } + withTempPath { + path => + Seq(1, 2, 3).toDF("a").write.json(path.getCanonicalPath) + spark.read.json(path.getCanonicalPath).createOrReplaceTempView("json_table") + val sql = + """ + |SELECT input_file_name(), a + |FROM + |(SELECT a FROM json_table + |UNION ALL + |SELECT l_orderkey as a FROM lineitem) + |LIMIT 100 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + } } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index c9205bae9d8fe..451cb2fd25687 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala @@ -45,7 +45,6 @@ trait BackendSettingsApi { def supportNativeWrite(fields: Array[StructField]): Boolean = true def supportNativeMetadataColumns(): Boolean = false def supportNativeRowIndexColumn(): Boolean = false - def supportNativeInputFileRelatedExpr(): Boolean = false def supportExpandExec(): Boolean = false def supportSortExec(): Boolean = false diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala index b7a30f7e177a5..4e668d7f837a0 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala @@ -34,7 +34,7 @@ object MiscColumnarRules { object TransformPreOverrides { def apply(): TransformPreOverrides = { TransformPreOverrides( - List(OffloadProject(), OffloadFilter()), + List(OffloadFilter()), List( OffloadOthers(), OffloadAggregate(), diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala index 3cea5e76a8373..6047789e6abe9 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala @@ -26,21 +26,17 @@ import org.apache.gluten.utils.{LogLevelUtil, PlanUtil} import org.apache.spark.api.python.EvalPythonExecTransformer import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.datasources.WriteFilesExec -import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanExecBase} +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BatchEvalPythonExec} import org.apache.spark.sql.execution.window.{WindowExec, WindowGroupLimitExecShim} import org.apache.spark.sql.hive.HiveTableScanExecTransformer -import org.apache.spark.sql.types.{LongType, StringType} - -import scala.collection.mutable.Map /** * Converts a vanilla Spark plan node into Gluten plan node. Gluten plan is supposed to be executed @@ -224,148 +220,6 @@ object OffloadJoin { } } -case class OffloadProject() extends OffloadSingleNode with LogLevelUtil { - private def containsInputFileRelatedExpr(expr: Expression): Boolean = { - expr match { - case _: InputFileName | _: InputFileBlockStart | _: InputFileBlockLength => true - case _ => expr.children.exists(containsInputFileRelatedExpr) - } - } - - private def rewriteExpr( - expr: Expression, - replacedExprs: Map[String, AttributeReference]): Expression = { - expr match { - case _: InputFileName => - replacedExprs.getOrElseUpdate( - expr.prettyName, - AttributeReference(expr.prettyName, StringType, false)()) - case _: InputFileBlockStart => - replacedExprs.getOrElseUpdate( - expr.prettyName, - AttributeReference(expr.prettyName, LongType, false)()) - case _: InputFileBlockLength => - replacedExprs.getOrElseUpdate( - expr.prettyName, - AttributeReference(expr.prettyName, LongType, false)()) - case other => - other.withNewChildren(other.children.map(child => rewriteExpr(child, replacedExprs))) - } - } - - private def addMetadataCol( - plan: SparkPlan, - replacedExprs: Map[String, AttributeReference]): SparkPlan = { - def genNewOutput(output: Seq[Attribute]): Seq[Attribute] = { - var newOutput = output - for ((_, newAttr) <- replacedExprs) { - if (!newOutput.exists(attr => attr.exprId == newAttr.exprId)) { - newOutput = newOutput :+ newAttr - } - } - newOutput - } - def genNewProjectList(projectList: Seq[NamedExpression]): Seq[NamedExpression] = { - var newProjectList = projectList - for ((_, newAttr) <- replacedExprs) { - if (!newProjectList.exists(attr => attr.exprId == newAttr.exprId)) { - newProjectList = newProjectList :+ newAttr.toAttribute - } - } - newProjectList - } - - plan match { - case f: FileSourceScanExec => - f.copy(output = genNewOutput(f.output)) - case f: FileSourceScanExecTransformer => - f.copy(output = genNewOutput(f.output)) - case b: BatchScanExec => - b.copy(output = genNewOutput(b.output).asInstanceOf[Seq[AttributeReference]]) - case b: BatchScanExecTransformer => - b.copy(output = genNewOutput(b.output).asInstanceOf[Seq[AttributeReference]]) - case p @ ProjectExec(projectList, child) => - p.copy(genNewProjectList(projectList), addMetadataCol(child, replacedExprs)) - case p @ ProjectExecTransformer(projectList, child) => - p.copy(genNewProjectList(projectList), addMetadataCol(child, replacedExprs)) - case u @ UnionExec(children) => - val newFirstChild = addMetadataCol(children.head, replacedExprs) - val newOtherChildren = children.tail.map { - child => - // Make sure exprId is unique in each child of Union. - val newReplacedExprs = replacedExprs.map { - expr => (expr._1, AttributeReference(expr._2.name, expr._2.dataType, false)()) - } - addMetadataCol(child, newReplacedExprs) - } - u.copy(children = newFirstChild +: newOtherChildren) - case _ => plan.withNewChildren(plan.children.map(addMetadataCol(_, replacedExprs))) - } - } - - private def tryOffloadProjectExecWithInputFileRelatedExprs( - projectExec: ProjectExec): SparkPlan = { - def findScanNodes(plan: SparkPlan): Seq[SparkPlan] = { - plan.collect { - case f @ (_: FileSourceScanExec | _: AbstractFileSourceScanExec | - _: DataSourceV2ScanExecBase) => - f - } - } - val addHint = AddFallbackTagRule() - val newProjectList = projectExec.projectList.filterNot(containsInputFileRelatedExpr) - val newProjectExec = ProjectExec(newProjectList, projectExec.child) - addHint.apply(newProjectExec) - if (FallbackTags.nonEmpty(newProjectExec)) { - // Project is still not transformable after remove `input_file_name` expressions. - projectExec - } else { - // the project with `input_file_name` expression may have multiple data source - // by union all, reference: - // https://github.com/apache/spark/blob/e459674127e7b21e2767cc62d10ea6f1f941936c - // /sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala#L519 - val leafScans = findScanNodes(projectExec) - if (leafScans.isEmpty || leafScans.exists(FallbackTags.nonEmpty)) { - // It means - // 1. projectExec has `input_file_name` but no scan child. - // 2. It has scan children node but the scan node fallback. - projectExec - } else { - val replacedExprs = scala.collection.mutable.Map[String, AttributeReference]() - val newProjectList = projectExec.projectList.map { - expr => rewriteExpr(expr, replacedExprs).asInstanceOf[NamedExpression] - } - val newChild = addMetadataCol(projectExec.child, replacedExprs) - logDebug( - s"Columnar Processing for ${projectExec.getClass} with " + - s"ProjectList ${projectExec.projectList} is currently supported.") - ProjectExecTransformer(newProjectList, newChild) - } - } - } - - private def genProjectExec(projectExec: ProjectExec): SparkPlan = { - if ( - FallbackTags.nonEmpty(projectExec) && - BackendsApiManager.getSettings.supportNativeInputFileRelatedExpr() && - projectExec.projectList.exists(containsInputFileRelatedExpr) - ) { - tryOffloadProjectExecWithInputFileRelatedExprs(projectExec) - } else if (FallbackTags.nonEmpty(projectExec)) { - projectExec - } else { - logDebug(s"Columnar Processing for ${projectExec.getClass} is currently supported.") - ProjectExecTransformer(projectExec.projectList, projectExec.child) - } - } - - override def offload(plan: SparkPlan): SparkPlan = plan match { - case p: ProjectExec => - genProjectExec(p) - case other => other - } -} - // Filter transformation. case class OffloadFilter() extends OffloadSingleNode with LogLevelUtil { import OffloadOthers._ @@ -443,6 +297,10 @@ object OffloadOthers { case plan: CoalesceExec => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarCoalesceExec(plan.numPartitions, plan.child) + case plan: ProjectExec => + val columnarChild = plan.child + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ProjectExecTransformer(plan.projectList, columnarChild) case plan: SortAggregateExec => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") HashAggregateExecBaseTransformer.from(plan) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala new file mode 100644 index 0000000000000..e1219fead7282 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.columnar + +import org.apache.gluten.execution.{BatchScanExecTransformer, FileSourceScanExecTransformer} + +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{DeserializeToObjectExec, LeafExecNode, ProjectExec, SerializeFromObjectExec, SparkPlan, UnionExec} + +import scala.collection.mutable + +/** + * The Spark implementations of input_file_name/input_file_block_start/input_file_block_length uses + * a thread local to stash the file name and retrieve it from the function. If there is a + * transformer node between project input_file_function and scan, the result of input_file_name is + * an empty string. So we should push down input_file_function to transformer scan or add fallback + * project of input_file_function before fallback scan. + * + * Two rules are involved: + * - Before offload, add new project before leaf node and push down input file expression to the + * new project + * - After offload, if scan be offloaded, push down input file expression into scan and remove + * project + */ +object PushDownInputFileExpression { + def containsInputFileRelatedExpr(expr: Expression): Boolean = { + expr match { + case _: InputFileName | _: InputFileBlockStart | _: InputFileBlockLength => true + case _ => expr.children.exists(containsInputFileRelatedExpr) + } + } + + object PreOffload extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case ProjectExec(projectList, child) if projectList.exists(containsInputFileRelatedExpr) => + val replacedExprs = mutable.Map[String, Alias]() + val newProjectList = projectList.map { + expr => rewriteExpr(expr, replacedExprs).asInstanceOf[NamedExpression] + } + val newChild = addMetadataCol(child, replacedExprs) + ProjectExec(newProjectList, newChild) + } + + private def rewriteExpr( + expr: Expression, + replacedExprs: mutable.Map[String, Alias]): Expression = + expr match { + case _: InputFileName => + replacedExprs + .getOrElseUpdate(expr.prettyName, Alias(InputFileName(), expr.prettyName)()) + .toAttribute + case _: InputFileBlockStart => + replacedExprs + .getOrElseUpdate(expr.prettyName, Alias(InputFileBlockStart(), expr.prettyName)()) + .toAttribute + case _: InputFileBlockLength => + replacedExprs + .getOrElseUpdate(expr.prettyName, Alias(InputFileBlockLength(), expr.prettyName)()) + .toAttribute + case other => + other.withNewChildren(other.children.map(child => rewriteExpr(child, replacedExprs))) + } + + private def addMetadataCol( + plan: SparkPlan, + replacedExprs: mutable.Map[String, Alias]): SparkPlan = + plan match { + case p: LeafExecNode => + ProjectExec(p.output ++ replacedExprs.values, p) + // Output of SerializeFromObjectExec's child and output of DeserializeToObjectExec must be + // a single-field row. + case p @ (_: SerializeFromObjectExec | _: DeserializeToObjectExec) => + ProjectExec(p.output ++ replacedExprs.values, p) + case p: ProjectExec => + p.copy( + projectList = p.projectList ++ replacedExprs.values.toSeq.map(_.toAttribute), + child = addMetadataCol(p.child, replacedExprs)) + case u @ UnionExec(children) => + val newFirstChild = addMetadataCol(children.head, replacedExprs) + val newOtherChildren = children.tail.map { + child => + // Make sure exprId is unique in each child of Union. + val newReplacedExprs = replacedExprs.map { + expr => (expr._1, Alias(expr._2.child, expr._2.name)()) + } + addMetadataCol(child, newReplacedExprs) + } + u.copy(children = newFirstChild +: newOtherChildren) + case p => p.withNewChildren(p.children.map(child => addMetadataCol(child, replacedExprs))) + } + } + + object PostOffload extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case p @ ProjectExec(projectList, child: FileSourceScanExecTransformer) + if projectList.exists(containsInputFileRelatedExpr) => + child.copy(output = p.output) + case p @ ProjectExec(projectList, child: BatchScanExecTransformer) + if projectList.exists(containsInputFileRelatedExpr) => + child.copy(output = p.output.asInstanceOf[Seq[AttributeReference]]) + } + } +}