diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala index 459a7886ea23..880e1e56b852 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala @@ -19,6 +19,7 @@ package org.apache.gluten.backendsapi.velox import org.apache.gluten.GlutenNumaBindingInfo import org.apache.gluten.backendsapi.IteratorApi import org.apache.gluten.execution._ +import org.apache.gluten.extension.InputFileNameReplaceRule import org.apache.gluten.metrics.IMetrics import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.substrait.plan.PlanNode @@ -112,7 +113,7 @@ class VeloxIteratorApi extends IteratorApi with Logging { val fileSizes = new JArrayList[JLong]() val modificationTimes = new JArrayList[JLong]() val partitionColumns = new JArrayList[JMap[String, String]] - var metadataColumns = new JArrayList[JMap[String, String]] + val metadataColumns = new JArrayList[JMap[String, String]] files.foreach { file => // The "file.filePath" in PartitionedFile is not the original encoded path, so the decoded @@ -132,6 +133,13 @@ class VeloxIteratorApi extends IteratorApi with Logging { } val metadataColumn = SparkShimLoader.getSparkShims.generateMetadataColumns(file, metadataColumnNames) + metadataColumn.put(InputFileNameReplaceRule.replacedInputFileName, file.filePath.toString) + metadataColumn.put( + InputFileNameReplaceRule.replacedInputFileBlockStart, + file.start.toString) + metadataColumn.put( + InputFileNameReplaceRule.replacedInputFileBlockLength, + file.length.toString) metadataColumns.add(metadataColumn) val partitionColumn = new JHashMap[String, String]() for (i <- 0 until file.partitionValues.numFields) { diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index ebf82ea767a4..71930d7e0f47 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -807,7 +807,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { */ override def genExtendedColumnarValidationRules(): List[SparkSession => Rule[SparkPlan]] = List( BloomFilterMightContainJointRewriteRule.apply, - ArrowScanReplaceRule.apply + ArrowScanReplaceRule.apply, + InputFileNameReplaceRule.apply ) /** diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/InputFileNameReplaceRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/InputFileNameReplaceRule.scala new file mode 100644 index 000000000000..cd3f50d8e77f --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/InputFileNameReplaceRule.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{FileSourceScanExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan +import org.apache.spark.sql.types.{LongType, StringType} + +object InputFileNameReplaceRule { + val replacedInputFileName = "$input_file_name$" + val replacedInputFileBlockStart = "$input_file_block_start$" + val replacedInputFileBlockLength = "$input_file_block_length$" +} + +case class InputFileNameReplaceRule(spark: SparkSession) extends Rule[SparkPlan] { + import InputFileNameReplaceRule._ + + private def isInputFileName(expr: Expression): Boolean = { + expr match { + case _: InputFileName => true + case _ => false + } + } + + private def isInputFileBlockStart(expr: Expression): Boolean = { + expr match { + case _: InputFileBlockStart => true + case _ => false + } + } + + private def isInputFileBlockLength(expr: Expression): Boolean = { + expr match { + case _: InputFileBlockLength => true + case _ => false + } + } + + override def apply(plan: SparkPlan): SparkPlan = { + val replacedExprs = scala.collection.mutable.Map[String, AttributeReference]() + + def hasParquetScan(plan: SparkPlan): Boolean = { + plan match { + case fileScan: FileSourceScanExec + if fileScan.relation.fileFormat.isInstanceOf[ParquetFileFormat] => + true + case batchScan: BatchScanExec => + batchScan.scan match { + case _: ParquetScan => true + case _ => false + } + case _ => plan.children.exists(hasParquetScan) + } + } + + def mayNeedConvert(expr: Expression): Boolean = { + expr match { + case e if isInputFileName(e) => true + case s if isInputFileBlockStart(s) => true + case l if isInputFileBlockLength(l) => true + case other => other.children.exists(mayNeedConvert) + } + } + + def doConvert(expr: Expression): Expression = { + expr match { + case e if isInputFileName(e) => + replacedExprs.getOrElseUpdate( + replacedInputFileName, + AttributeReference(replacedInputFileName, StringType, true)()) + case s if isInputFileBlockStart(s) => + replacedExprs.getOrElseUpdate( + replacedInputFileBlockStart, + AttributeReference(replacedInputFileBlockStart, LongType, true)() + ) + case l if isInputFileBlockLength(l) => + replacedExprs.getOrElseUpdate( + replacedInputFileBlockLength, + AttributeReference(replacedInputFileBlockLength, LongType, true)() + ) + case other => + other.withNewChildren(other.children.map(child => doConvert(child))) + } + } + + def ensureChildOutputHasNewAttrs(plan: SparkPlan): SparkPlan = { + plan match { + case _ @ProjectExec(projectList, child) => + var newProjectList = projectList + for ((_, newAttr) <- replacedExprs) { + if (!newProjectList.exists(attr => attr.exprId == newAttr.exprId)) { + newProjectList = newProjectList :+ newAttr.toAttribute + } + } + val newChild = ensureChildOutputHasNewAttrs(child) + ProjectExec(newProjectList, newChild) + case f: FileSourceScanExec => + var newOutput = f.output + for ((_, newAttr) <- replacedExprs) { + if (!newOutput.exists(attr => attr.exprId == newAttr.exprId)) { + newOutput = newOutput :+ newAttr.toAttribute + } + } + f.copy(output = newOutput) + + case b: BatchScanExec => + var newOutput = b.output + for ((_, newAttr) <- replacedExprs) { + if (!newOutput.exists(attr => attr.exprId == newAttr.exprId)) { + newOutput = newOutput :+ newAttr + } + } + b.copy(output = newOutput) + case other => + val newChildren = other.children.map(ensureChildOutputHasNewAttrs) + other.withNewChildren(newChildren) + } + } + + def replaceInputFileNameInProject(plan: SparkPlan): SparkPlan = { + plan match { + case _ @ProjectExec(projectList, child) + if projectList.exists(mayNeedConvert) && hasParquetScan(plan) => + val newProjectList = projectList.map { + expr => doConvert(expr).asInstanceOf[NamedExpression] + } + val newChild = replaceInputFileNameInProject(ensureChildOutputHasNewAttrs(child)) + ProjectExec(newProjectList, newChild) + case other => + val newChildren = other.children.map(replaceInputFileNameInProject) + other.withNewChildren(newChildren) + } + } + replaceInputFileNameInProject(plan) + } +} diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala index 9718b8e7358e..d08ba11ee787 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala @@ -623,6 +623,13 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest { } } + test("Test input_file_name function") { + runQueryAndCompare("""SELECT input_file_name(), l_orderkey + | from lineitem limit 100""".stripMargin) { + checkGlutenOperatorMatch[ProjectExecTransformer] + } + } + test("Test spark_partition_id function") { runQueryAndCompare("""SELECT spark_partition_id(), l_orderkey | from lineitem limit 100""".stripMargin) { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala index ad68786e6579..d925bc231cd9 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala @@ -96,11 +96,11 @@ class HeuristicApplier(session: SparkSession) (spark: SparkSession) => FallbackOnANSIMode(spark), (spark: SparkSession) => FallbackMultiCodegens(spark), (spark: SparkSession) => PlanOneRowRelation(spark), - (_: SparkSession) => FallbackEmptySchemaRelation(), (_: SparkSession) => RewriteSubqueryBroadcast() ) ::: BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules() ::: List( + (_: SparkSession) => FallbackEmptySchemaRelation(), (spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark), (_: SparkSession) => RewriteSparkPlanRulesManager(), (_: SparkSession) => AddTransformHintRule()