From f82c298268de4dfb5a31ff8c88518054cceff417 Mon Sep 17 00:00:00 2001 From: zml1206 Date: Thu, 5 Sep 2024 13:33:31 +0800 Subject: [PATCH 1/4] [VL] Fix input_file_name results in empty string --- .../backendsapi/velox/VeloxBackend.scala | 2 - .../backendsapi/velox/VeloxRuleApi.scala | 2 + .../ScalarFunctionsValidateSuite.scala | 17 +- .../backendsapi/BackendSettingsApi.scala | 1 - .../columnar/MiscColumnarRules.scala | 2 +- .../columnar/OffloadSingleNode.scala | 152 +----------------- .../PushDownInputFileExpression.scala | 112 +++++++++++++ 7 files changed, 136 insertions(+), 152 deletions(-) create mode 100644 gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index 611e9c15bf48..ecd053a632f3 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -282,8 +282,6 @@ object VeloxBackendSettings extends BackendSettingsApi { override def supportNativeRowIndexColumn(): Boolean = true - override def supportNativeInputFileRelatedExpr(): Boolean = true - override def supportExpandExec(): Boolean = true override def supportSortExec(): Boolean = true diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index 4ff7f0305d58..32b9dbc37e8d 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -52,6 +52,7 @@ private object VeloxRuleApi { def injectLegacy(injector: LegacyInjector): Unit = { // Gluten columnar: Transform rules. injector.injectTransform(_ => RemoveTransitions) + injector.injectTransform(_ => PushDownInputFileExpressionBeforeLeaf) injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) injector.injectTransform(c => FallbackMultiCodegens.apply(c.session)) injector.injectTransform(c => PlanOneRowRelation.apply(c.session)) @@ -64,6 +65,7 @@ private object VeloxRuleApi { injector.injectTransform(_ => TransformPreOverrides()) injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject()) injector.injectTransform(c => RewriteTransformer.apply(c.session)) + injector.injectTransform(_ => PushDownInputFileExpressionToScan) injector.injectTransform(_ => EnsureLocalSortRequirements) injector.injectTransform(_ => EliminateLocalSort) injector.injectTransform(_ => CollapseProjectExecTransformer) diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala index b8de30b1b06f..77213f0b86e0 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala @@ -28,7 +28,7 @@ class ScalarFunctionsValidateSuiteRasOff extends ScalarFunctionsValidateSuite { super.sparkConf .set("spark.gluten.ras.enabled", "false") } - + import testImplicits._ // Since https://github.com/apache/incubator-gluten/pull/6200. test("Test input_file_name function") { runQueryAndCompare("""SELECT input_file_name(), l_orderkey @@ -44,6 +44,21 @@ class ScalarFunctionsValidateSuiteRasOff extends ScalarFunctionsValidateSuite { | limit 100""".stripMargin) { checkGlutenOperatorMatch[ProjectExecTransformer] } + withTempPath { + path => + Seq(1, 2, 3).toDF("a").write.json(path.getCanonicalPath) + spark.read.json(path.getCanonicalPath).createOrReplaceTempView("json_table") + val sql = + """ + |SELECT input_file_name(), a + |FROM + |(SELECT a FROM json_table + |UNION ALL + |SELECT l_orderkey as a FROM lineitem) + |LIMIT 100 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + } } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index c9205bae9d8f..451cb2fd2568 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala @@ -45,7 +45,6 @@ trait BackendSettingsApi { def supportNativeWrite(fields: Array[StructField]): Boolean = true def supportNativeMetadataColumns(): Boolean = false def supportNativeRowIndexColumn(): Boolean = false - def supportNativeInputFileRelatedExpr(): Boolean = false def supportExpandExec(): Boolean = false def supportSortExec(): Boolean = false diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala index b7a30f7e177a..4e668d7f837a 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala @@ -34,7 +34,7 @@ object MiscColumnarRules { object TransformPreOverrides { def apply(): TransformPreOverrides = { TransformPreOverrides( - List(OffloadProject(), OffloadFilter()), + List(OffloadFilter()), List( OffloadOthers(), OffloadAggregate(), diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala index 70b85165c37b..88f78422ad1f 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala @@ -26,21 +26,17 @@ import org.apache.gluten.utils.{LogLevelUtil, PlanUtil} import org.apache.spark.api.python.EvalPythonExecTransformer import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.datasources.WriteFilesExec -import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanExecBase} +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BatchEvalPythonExec} import org.apache.spark.sql.execution.window.{WindowExec, WindowGroupLimitExecShim} import org.apache.spark.sql.hive.HiveTableScanExecTransformer -import org.apache.spark.sql.types.{LongType, StringType} - -import scala.collection.mutable.Map /** * Converts a vanilla Spark plan node into Gluten plan node. Gluten plan is supposed to be executed @@ -220,148 +216,6 @@ object OffloadJoin { } } -case class OffloadProject() extends OffloadSingleNode with LogLevelUtil { - private def containsInputFileRelatedExpr(expr: Expression): Boolean = { - expr match { - case _: InputFileName | _: InputFileBlockStart | _: InputFileBlockLength => true - case _ => expr.children.exists(containsInputFileRelatedExpr) - } - } - - private def rewriteExpr( - expr: Expression, - replacedExprs: Map[String, AttributeReference]): Expression = { - expr match { - case _: InputFileName => - replacedExprs.getOrElseUpdate( - expr.prettyName, - AttributeReference(expr.prettyName, StringType, false)()) - case _: InputFileBlockStart => - replacedExprs.getOrElseUpdate( - expr.prettyName, - AttributeReference(expr.prettyName, LongType, false)()) - case _: InputFileBlockLength => - replacedExprs.getOrElseUpdate( - expr.prettyName, - AttributeReference(expr.prettyName, LongType, false)()) - case other => - other.withNewChildren(other.children.map(child => rewriteExpr(child, replacedExprs))) - } - } - - private def addMetadataCol( - plan: SparkPlan, - replacedExprs: Map[String, AttributeReference]): SparkPlan = { - def genNewOutput(output: Seq[Attribute]): Seq[Attribute] = { - var newOutput = output - for ((_, newAttr) <- replacedExprs) { - if (!newOutput.exists(attr => attr.exprId == newAttr.exprId)) { - newOutput = newOutput :+ newAttr - } - } - newOutput - } - def genNewProjectList(projectList: Seq[NamedExpression]): Seq[NamedExpression] = { - var newProjectList = projectList - for ((_, newAttr) <- replacedExprs) { - if (!newProjectList.exists(attr => attr.exprId == newAttr.exprId)) { - newProjectList = newProjectList :+ newAttr.toAttribute - } - } - newProjectList - } - - plan match { - case f: FileSourceScanExec => - f.copy(output = genNewOutput(f.output)) - case f: FileSourceScanExecTransformer => - f.copy(output = genNewOutput(f.output)) - case b: BatchScanExec => - b.copy(output = genNewOutput(b.output).asInstanceOf[Seq[AttributeReference]]) - case b: BatchScanExecTransformer => - b.copy(output = genNewOutput(b.output).asInstanceOf[Seq[AttributeReference]]) - case p @ ProjectExec(projectList, child) => - p.copy(genNewProjectList(projectList), addMetadataCol(child, replacedExprs)) - case p @ ProjectExecTransformer(projectList, child) => - p.copy(genNewProjectList(projectList), addMetadataCol(child, replacedExprs)) - case u @ UnionExec(children) => - val newFirstChild = addMetadataCol(children.head, replacedExprs) - val newOtherChildren = children.tail.map { - child => - // Make sure exprId is unique in each child of Union. - val newReplacedExprs = replacedExprs.map { - expr => (expr._1, AttributeReference(expr._2.name, expr._2.dataType, false)()) - } - addMetadataCol(child, newReplacedExprs) - } - u.copy(children = newFirstChild +: newOtherChildren) - case _ => plan.withNewChildren(plan.children.map(addMetadataCol(_, replacedExprs))) - } - } - - private def tryOffloadProjectExecWithInputFileRelatedExprs( - projectExec: ProjectExec): SparkPlan = { - def findScanNodes(plan: SparkPlan): Seq[SparkPlan] = { - plan.collect { - case f @ (_: FileSourceScanExec | _: AbstractFileSourceScanExec | - _: DataSourceV2ScanExecBase) => - f - } - } - val addHint = AddFallbackTagRule() - val newProjectList = projectExec.projectList.filterNot(containsInputFileRelatedExpr) - val newProjectExec = ProjectExec(newProjectList, projectExec.child) - addHint.apply(newProjectExec) - if (FallbackTags.nonEmpty(newProjectExec)) { - // Project is still not transformable after remove `input_file_name` expressions. - projectExec - } else { - // the project with `input_file_name` expression may have multiple data source - // by union all, reference: - // https://github.com/apache/spark/blob/e459674127e7b21e2767cc62d10ea6f1f941936c - // /sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala#L519 - val leafScans = findScanNodes(projectExec) - if (leafScans.isEmpty || leafScans.exists(FallbackTags.nonEmpty)) { - // It means - // 1. projectExec has `input_file_name` but no scan child. - // 2. It has scan children node but the scan node fallback. - projectExec - } else { - val replacedExprs = scala.collection.mutable.Map[String, AttributeReference]() - val newProjectList = projectExec.projectList.map { - expr => rewriteExpr(expr, replacedExprs).asInstanceOf[NamedExpression] - } - val newChild = addMetadataCol(projectExec.child, replacedExprs) - logDebug( - s"Columnar Processing for ${projectExec.getClass} with " + - s"ProjectList ${projectExec.projectList} is currently supported.") - ProjectExecTransformer(newProjectList, newChild) - } - } - } - - private def genProjectExec(projectExec: ProjectExec): SparkPlan = { - if ( - FallbackTags.nonEmpty(projectExec) && - BackendsApiManager.getSettings.supportNativeInputFileRelatedExpr() && - projectExec.projectList.exists(containsInputFileRelatedExpr) - ) { - tryOffloadProjectExecWithInputFileRelatedExprs(projectExec) - } else if (FallbackTags.nonEmpty(projectExec)) { - projectExec - } else { - logDebug(s"Columnar Processing for ${projectExec.getClass} is currently supported.") - ProjectExecTransformer(projectExec.projectList, projectExec.child) - } - } - - override def offload(plan: SparkPlan): SparkPlan = plan match { - case p: ProjectExec => - genProjectExec(p) - case other => other - } -} - // Filter transformation. case class OffloadFilter() extends OffloadSingleNode with LogLevelUtil { import OffloadOthers._ @@ -439,6 +293,10 @@ object OffloadOthers { case plan: CoalesceExec => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarCoalesceExec(plan.numPartitions, plan.child) + case plan: ProjectExec => + val columnarChild = plan.child + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ProjectExecTransformer(plan.projectList, columnarChild) case plan: SortAggregateExec => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") HashAggregateExecBaseTransformer.from(plan) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala new file mode 100644 index 000000000000..b1b032455e3e --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.columnar + +import org.apache.gluten.execution.{BatchScanExecTransformer, FileSourceScanExecTransformer} + +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan, UnionExec} + +import scala.collection.mutable +import scala.collection.mutable.Map + +/** + * The Spark implementations of input_file_name/input_file_block_start/input_file_block_length uses + * a thread local to stash the file name and retrieve it from the function. If there is a + * transformer node between project input_file_function and scan, the result of input_file_name is + * an empty string. So we should push down input_file_function to transformer scan or add fallback + * project of input_file_function before fallback scan. + * + * Two rules are involved: + * - Before offload, add new project before leaf node and push down input file expression to the + * new project + * - After offload, if scan be offloaded, push down input file expression into scan and remove + * project + */ +object PushDownInputFileExpression { + def containsInputFileRelatedExpr(expr: Expression): Boolean = { + expr match { + case _: InputFileName | _: InputFileBlockStart | _: InputFileBlockLength => true + case _ => expr.children.exists(containsInputFileRelatedExpr) + } + } +} + +object PushDownInputFileExpressionBeforeLeaf extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case ProjectExec(projectList, child) + if projectList.exists(PushDownInputFileExpression.containsInputFileRelatedExpr) => + val replacedExprs = mutable.Map[String, Alias]() + val newProjectList = projectList.map { + expr => rewriteExpr(expr, replacedExprs).asInstanceOf[NamedExpression] + } + val newChild = addMetadataCol(child, replacedExprs) + ProjectExec(newProjectList, newChild) + } + + private def rewriteExpr(expr: Expression, replacedExprs: Map[String, Alias]): Expression = { + expr match { + case _: InputFileName => + replacedExprs + .getOrElseUpdate(expr.prettyName, Alias(InputFileName(), expr.prettyName)()) + .toAttribute + case _: InputFileBlockStart => + replacedExprs + .getOrElseUpdate(expr.prettyName, Alias(InputFileBlockStart(), expr.prettyName)()) + .toAttribute + case _: InputFileBlockLength => + replacedExprs + .getOrElseUpdate(expr.prettyName, Alias(InputFileBlockLength(), expr.prettyName)()) + .toAttribute + case other => + other.withNewChildren(other.children.map(child => rewriteExpr(child, replacedExprs))) + } + } + + def addMetadataCol(plan: SparkPlan, replacedExprs: Map[String, Alias]): SparkPlan = plan match { + case p if p.children.isEmpty => + ProjectExec(p.output ++ replacedExprs.values, p) + case p: ProjectExec => + p.copy( + projectList = p.projectList ++ replacedExprs.values.toSeq.map(_.toAttribute), + child = addMetadataCol(p.child, replacedExprs)) + case u @ UnionExec(children) => + val newFirstChild = addMetadataCol(children.head, replacedExprs) + val newOtherChildren = children.tail.map { + child => + // Make sure exprId is unique in each child of Union. + val newReplacedExprs = replacedExprs.map { + expr => (expr._1, Alias(expr._2.child, expr._2.name)()) + } + addMetadataCol(child, newReplacedExprs) + } + u.copy(children = newFirstChild +: newOtherChildren) + case p => p.withNewChildren(p.children.map(child => addMetadataCol(child, replacedExprs))) + } +} + +object PushDownInputFileExpressionToScan extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case p @ ProjectExec(projectList, child: FileSourceScanExecTransformer) + if projectList.exists(PushDownInputFileExpression.containsInputFileRelatedExpr) => + child.copy(output = p.output) + case p @ ProjectExec(projectList, child: BatchScanExecTransformer) + if projectList.exists(PushDownInputFileExpression.containsInputFileRelatedExpr) => + child.copy(output = p.output.asInstanceOf[Seq[AttributeReference]]) + } +} From 93e14cc3b155df95782c5013653417a846903a23 Mon Sep 17 00:00:00 2001 From: zml1206 Date: Thu, 5 Sep 2024 16:53:21 +0800 Subject: [PATCH 2/4] fix --- .../extension/columnar/PushDownInputFileExpression.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala index b1b032455e3e..64cad4580d21 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala @@ -20,7 +20,7 @@ import org.apache.gluten.execution.{BatchScanExecTransformer, FileSourceScanExec import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{ProjectExec, SparkPlan, UnionExec} +import org.apache.spark.sql.execution.{DeserializeToObjectExec, LeafExecNode, ProjectExec, SerializeFromObjectExec, SparkPlan, UnionExec} import scala.collection.mutable import scala.collection.mutable.Map @@ -79,7 +79,11 @@ object PushDownInputFileExpressionBeforeLeaf extends Rule[SparkPlan] { } def addMetadataCol(plan: SparkPlan, replacedExprs: Map[String, Alias]): SparkPlan = plan match { - case p if p.children.isEmpty => + case p: LeafExecNode => + ProjectExec(p.output ++ replacedExprs.values, p) + // Output of SerializeFromObjectExec's child and output of DeserializeToObjectExec must be a + // single-field row. + case p @ (_: SerializeFromObjectExec | _: DeserializeToObjectExec) => ProjectExec(p.output ++ replacedExprs.values, p) case p: ProjectExec => p.copy( From 50154962589c7def7be9d15b5b02d632bbb61948 Mon Sep 17 00:00:00 2001 From: zml1206 Date: Fri, 6 Sep 2024 14:43:50 +0800 Subject: [PATCH 3/4] update --- .../PushDownInputFileExpression.scala | 124 +++++++++--------- 1 file changed, 63 insertions(+), 61 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala index 64cad4580d21..e1219fead728 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownInputFileExpression.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{DeserializeToObjectExec, LeafExecNode, ProjectExec, SerializeFromObjectExec, SparkPlan, UnionExec} import scala.collection.mutable -import scala.collection.mutable.Map /** * The Spark implementations of input_file_name/input_file_block_start/input_file_block_length uses @@ -45,72 +44,75 @@ object PushDownInputFileExpression { case _ => expr.children.exists(containsInputFileRelatedExpr) } } -} -object PushDownInputFileExpressionBeforeLeaf extends Rule[SparkPlan] { - override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case ProjectExec(projectList, child) - if projectList.exists(PushDownInputFileExpression.containsInputFileRelatedExpr) => - val replacedExprs = mutable.Map[String, Alias]() - val newProjectList = projectList.map { - expr => rewriteExpr(expr, replacedExprs).asInstanceOf[NamedExpression] - } - val newChild = addMetadataCol(child, replacedExprs) - ProjectExec(newProjectList, newChild) - } - - private def rewriteExpr(expr: Expression, replacedExprs: Map[String, Alias]): Expression = { - expr match { - case _: InputFileName => - replacedExprs - .getOrElseUpdate(expr.prettyName, Alias(InputFileName(), expr.prettyName)()) - .toAttribute - case _: InputFileBlockStart => - replacedExprs - .getOrElseUpdate(expr.prettyName, Alias(InputFileBlockStart(), expr.prettyName)()) - .toAttribute - case _: InputFileBlockLength => - replacedExprs - .getOrElseUpdate(expr.prettyName, Alias(InputFileBlockLength(), expr.prettyName)()) - .toAttribute - case other => - other.withNewChildren(other.children.map(child => rewriteExpr(child, replacedExprs))) + object PreOffload extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case ProjectExec(projectList, child) if projectList.exists(containsInputFileRelatedExpr) => + val replacedExprs = mutable.Map[String, Alias]() + val newProjectList = projectList.map { + expr => rewriteExpr(expr, replacedExprs).asInstanceOf[NamedExpression] + } + val newChild = addMetadataCol(child, replacedExprs) + ProjectExec(newProjectList, newChild) } - } - def addMetadataCol(plan: SparkPlan, replacedExprs: Map[String, Alias]): SparkPlan = plan match { - case p: LeafExecNode => - ProjectExec(p.output ++ replacedExprs.values, p) - // Output of SerializeFromObjectExec's child and output of DeserializeToObjectExec must be a - // single-field row. - case p @ (_: SerializeFromObjectExec | _: DeserializeToObjectExec) => - ProjectExec(p.output ++ replacedExprs.values, p) - case p: ProjectExec => - p.copy( - projectList = p.projectList ++ replacedExprs.values.toSeq.map(_.toAttribute), - child = addMetadataCol(p.child, replacedExprs)) - case u @ UnionExec(children) => - val newFirstChild = addMetadataCol(children.head, replacedExprs) - val newOtherChildren = children.tail.map { - child => - // Make sure exprId is unique in each child of Union. - val newReplacedExprs = replacedExprs.map { - expr => (expr._1, Alias(expr._2.child, expr._2.name)()) + private def rewriteExpr( + expr: Expression, + replacedExprs: mutable.Map[String, Alias]): Expression = + expr match { + case _: InputFileName => + replacedExprs + .getOrElseUpdate(expr.prettyName, Alias(InputFileName(), expr.prettyName)()) + .toAttribute + case _: InputFileBlockStart => + replacedExprs + .getOrElseUpdate(expr.prettyName, Alias(InputFileBlockStart(), expr.prettyName)()) + .toAttribute + case _: InputFileBlockLength => + replacedExprs + .getOrElseUpdate(expr.prettyName, Alias(InputFileBlockLength(), expr.prettyName)()) + .toAttribute + case other => + other.withNewChildren(other.children.map(child => rewriteExpr(child, replacedExprs))) + } + + private def addMetadataCol( + plan: SparkPlan, + replacedExprs: mutable.Map[String, Alias]): SparkPlan = + plan match { + case p: LeafExecNode => + ProjectExec(p.output ++ replacedExprs.values, p) + // Output of SerializeFromObjectExec's child and output of DeserializeToObjectExec must be + // a single-field row. + case p @ (_: SerializeFromObjectExec | _: DeserializeToObjectExec) => + ProjectExec(p.output ++ replacedExprs.values, p) + case p: ProjectExec => + p.copy( + projectList = p.projectList ++ replacedExprs.values.toSeq.map(_.toAttribute), + child = addMetadataCol(p.child, replacedExprs)) + case u @ UnionExec(children) => + val newFirstChild = addMetadataCol(children.head, replacedExprs) + val newOtherChildren = children.tail.map { + child => + // Make sure exprId is unique in each child of Union. + val newReplacedExprs = replacedExprs.map { + expr => (expr._1, Alias(expr._2.child, expr._2.name)()) + } + addMetadataCol(child, newReplacedExprs) } - addMetadataCol(child, newReplacedExprs) + u.copy(children = newFirstChild +: newOtherChildren) + case p => p.withNewChildren(p.children.map(child => addMetadataCol(child, replacedExprs))) } - u.copy(children = newFirstChild +: newOtherChildren) - case p => p.withNewChildren(p.children.map(child => addMetadataCol(child, replacedExprs))) } -} -object PushDownInputFileExpressionToScan extends Rule[SparkPlan] { - override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case p @ ProjectExec(projectList, child: FileSourceScanExecTransformer) - if projectList.exists(PushDownInputFileExpression.containsInputFileRelatedExpr) => - child.copy(output = p.output) - case p @ ProjectExec(projectList, child: BatchScanExecTransformer) - if projectList.exists(PushDownInputFileExpression.containsInputFileRelatedExpr) => - child.copy(output = p.output.asInstanceOf[Seq[AttributeReference]]) + object PostOffload extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case p @ ProjectExec(projectList, child: FileSourceScanExecTransformer) + if projectList.exists(containsInputFileRelatedExpr) => + child.copy(output = p.output) + case p @ ProjectExec(projectList, child: BatchScanExecTransformer) + if projectList.exists(containsInputFileRelatedExpr) => + child.copy(output = p.output.asInstanceOf[Seq[AttributeReference]]) + } } } From c4dfcd4f11d45eb85c134bcc0b9516b6d68eedf7 Mon Sep 17 00:00:00 2001 From: zml1206 Date: Fri, 6 Sep 2024 15:02:15 +0800 Subject: [PATCH 4/4] update --- .../org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index 32b9dbc37e8d..b278224b2a82 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -52,7 +52,7 @@ private object VeloxRuleApi { def injectLegacy(injector: LegacyInjector): Unit = { // Gluten columnar: Transform rules. injector.injectTransform(_ => RemoveTransitions) - injector.injectTransform(_ => PushDownInputFileExpressionBeforeLeaf) + injector.injectTransform(_ => PushDownInputFileExpression.PreOffload) injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) injector.injectTransform(c => FallbackMultiCodegens.apply(c.session)) injector.injectTransform(c => PlanOneRowRelation.apply(c.session)) @@ -65,7 +65,7 @@ private object VeloxRuleApi { injector.injectTransform(_ => TransformPreOverrides()) injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject()) injector.injectTransform(c => RewriteTransformer.apply(c.session)) - injector.injectTransform(_ => PushDownInputFileExpressionToScan) + injector.injectTransform(_ => PushDownInputFileExpression.PostOffload) injector.injectTransform(_ => EnsureLocalSortRequirements) injector.injectTransform(_ => EliminateLocalSort) injector.injectTransform(_ => CollapseProjectExecTransformer)