diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index 158be10f486c..82b45f2d4394 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -303,6 +303,8 @@ object VeloxBackendSettings extends BackendSettingsApi { override def supportNativeRowIndexColumn(): Boolean = true + override def supportNativeInputFileRelatedExpr(): Boolean = true + override def supportExpandExec(): Boolean = true override def supportSortExec(): Boolean = true diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala index 22862156c6b2..613e539456ec 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala @@ -19,7 +19,6 @@ package org.apache.gluten.backendsapi.velox import org.apache.gluten.GlutenNumaBindingInfo import org.apache.gluten.backendsapi.IteratorApi import org.apache.gluten.execution._ -import org.apache.gluten.extension.InputFileNameReplaceRule import org.apache.gluten.metrics.IMetrics import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.substrait.plan.PlanNode @@ -134,13 +133,6 @@ class VeloxIteratorApi extends IteratorApi with Logging { } val metadataColumn = SparkShimLoader.getSparkShims.generateMetadataColumns(file, metadataColumnNames) - metadataColumn.put(InputFileNameReplaceRule.replacedInputFileName, file.filePath.toString) - metadataColumn.put( - InputFileNameReplaceRule.replacedInputFileBlockStart, - file.start.toString) - metadataColumn.put( - InputFileNameReplaceRule.replacedInputFileBlockLength, - file.length.toString) metadataColumns.add(metadataColumn) val partitionColumn = new JHashMap[String, String]() for (i <- 0 until file.partitionValues.numFields) { diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index b48da15683e8..ed69a5893c25 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -806,12 +806,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { * @return */ override def genExtendedColumnarValidationRules(): List[SparkSession => Rule[SparkPlan]] = { - val buf: ListBuffer[SparkSession => Rule[SparkPlan]] = - ListBuffer(BloomFilterMightContainJointRewriteRule.apply, ArrowScanReplaceRule.apply) - if (GlutenConfig.getConf.enableInputFileNameReplaceRule) { - buf += InputFileNameReplaceRule.apply - } - buf.result + List(BloomFilterMightContainJointRewriteRule.apply, ArrowScanReplaceRule.apply) } /** diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/InputFileNameReplaceRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/InputFileNameReplaceRule.scala deleted file mode 100644 index cd3f50d8e77f..000000000000 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/InputFileNameReplaceRule.scala +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation 
(ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.extension - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{FileSourceScanExec, ProjectExec, SparkPlan} -import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.datasources.v2.BatchScanExec -import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan -import org.apache.spark.sql.types.{LongType, StringType} - -object InputFileNameReplaceRule { - val replacedInputFileName = "$input_file_name$" - val replacedInputFileBlockStart = "$input_file_block_start$" - val replacedInputFileBlockLength = "$input_file_block_length$" -} - -case class InputFileNameReplaceRule(spark: SparkSession) extends Rule[SparkPlan] { - import InputFileNameReplaceRule._ - - private def isInputFileName(expr: Expression): Boolean = { - expr match { - case _: InputFileName => true - case _ => false - } - } - - private def isInputFileBlockStart(expr: Expression): Boolean = { - expr match { - case _: InputFileBlockStart => true - case _ => false - } - } - - private def isInputFileBlockLength(expr: Expression): Boolean = { - expr match { - case _: InputFileBlockLength => true - case _ => false - } - } - - override def apply(plan: SparkPlan): SparkPlan = { - val replacedExprs = scala.collection.mutable.Map[String, AttributeReference]() - - def hasParquetScan(plan: SparkPlan): Boolean = { - plan match { - case fileScan: FileSourceScanExec - if fileScan.relation.fileFormat.isInstanceOf[ParquetFileFormat] => - true - case batchScan: BatchScanExec => - batchScan.scan match { - case _: ParquetScan => true - case _ => false - } - case _ => plan.children.exists(hasParquetScan) - } - } - - def mayNeedConvert(expr: Expression): Boolean = { - expr match { - case e if isInputFileName(e) => true - case s if isInputFileBlockStart(s) => true - case l if isInputFileBlockLength(l) => true - case other => other.children.exists(mayNeedConvert) - } - } - - def doConvert(expr: Expression): Expression = { - expr match { - case e if isInputFileName(e) => - replacedExprs.getOrElseUpdate( - replacedInputFileName, - AttributeReference(replacedInputFileName, StringType, true)()) - case s if isInputFileBlockStart(s) => - replacedExprs.getOrElseUpdate( - replacedInputFileBlockStart, - AttributeReference(replacedInputFileBlockStart, LongType, true)() - ) - case l if isInputFileBlockLength(l) => - replacedExprs.getOrElseUpdate( - replacedInputFileBlockLength, - AttributeReference(replacedInputFileBlockLength, LongType, true)() - ) - case other => - 
other.withNewChildren(other.children.map(child => doConvert(child))) - } - } - - def ensureChildOutputHasNewAttrs(plan: SparkPlan): SparkPlan = { - plan match { - case _ @ProjectExec(projectList, child) => - var newProjectList = projectList - for ((_, newAttr) <- replacedExprs) { - if (!newProjectList.exists(attr => attr.exprId == newAttr.exprId)) { - newProjectList = newProjectList :+ newAttr.toAttribute - } - } - val newChild = ensureChildOutputHasNewAttrs(child) - ProjectExec(newProjectList, newChild) - case f: FileSourceScanExec => - var newOutput = f.output - for ((_, newAttr) <- replacedExprs) { - if (!newOutput.exists(attr => attr.exprId == newAttr.exprId)) { - newOutput = newOutput :+ newAttr.toAttribute - } - } - f.copy(output = newOutput) - - case b: BatchScanExec => - var newOutput = b.output - for ((_, newAttr) <- replacedExprs) { - if (!newOutput.exists(attr => attr.exprId == newAttr.exprId)) { - newOutput = newOutput :+ newAttr - } - } - b.copy(output = newOutput) - case other => - val newChildren = other.children.map(ensureChildOutputHasNewAttrs) - other.withNewChildren(newChildren) - } - } - - def replaceInputFileNameInProject(plan: SparkPlan): SparkPlan = { - plan match { - case _ @ProjectExec(projectList, child) - if projectList.exists(mayNeedConvert) && hasParquetScan(plan) => - val newProjectList = projectList.map { - expr => doConvert(expr).asInstanceOf[NamedExpression] - } - val newChild = replaceInputFileNameInProject(ensureChildOutputHasNewAttrs(child)) - ProjectExec(newProjectList, newChild) - case other => - val newChildren = other.children.map(replaceInputFileNameInProject) - other.withNewChildren(newChildren) - } - } - replaceInputFileNameInProject(plan) - } -} diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala index a2baf95ecdc0..bd32a799c3ac 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala @@ -645,13 +645,9 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest { } test("Test input_file_name function") { - withSQLConf( - "spark.gluten.sql.enableInputFileNameReplaceRule" -> "true" - ) { - runQueryAndCompare("""SELECT input_file_name(), l_orderkey - | from lineitem limit 100""".stripMargin) { - checkGlutenOperatorMatch[ProjectExecTransformer] - } + runQueryAndCompare("""SELECT input_file_name(), l_orderkey + | from lineitem limit 100""".stripMargin) { + checkGlutenOperatorMatch[ProjectExecTransformer] } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index b132366e6e1d..50292839b684 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala @@ -44,6 +44,7 @@ trait BackendSettingsApi { def supportNativeWrite(fields: Array[StructField]): Boolean = true def supportNativeMetadataColumns(): Boolean = false def supportNativeRowIndexColumn(): Boolean = false + def supportNativeInputFileRelatedExpr(): Boolean = false def supportExpandExec(): Boolean = false def supportSortExec(): Boolean = false diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala index 8ed2137f4489..15fc8bea7054 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala @@ -30,7 +30,7 @@ object MiscColumnarRules { object TransformPreOverrides { def apply(): TransformPreOverrides = { TransformPreOverrides( - List(OffloadFilter()), + List(OffloadProject(), OffloadFilter()), List( OffloadOthers(), OffloadAggregate(), diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala index 75da28e30d39..11db0bc1faf1 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala @@ -26,15 +26,19 @@ import org.apache.gluten.utils.{LogLevelUtil, PlanUtil} import org.apache.spark.api.python.EvalPythonExecTransformer import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.datasources.WriteFilesExec -import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanExecBase} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BatchEvalPythonExec} import org.apache.spark.sql.execution.window.{WindowExec, WindowGroupLimitExecShim} import org.apache.spark.sql.hive.HiveTableScanExecTransformer +import org.apache.spark.sql.types.{LongType, StringType} + +import scala.collection.mutable.Map /** * Converts a vanilla Spark plan node into Gluten plan node. 
Gluten plan is supposed to be executed @@ -181,7 +185,138 @@ case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil { case other => other } } +} + +case class OffloadProject() extends OffloadSingleNode with LogLevelUtil { + private def containsInputFileRelatedExpr(expr: Expression): Boolean = { + expr match { + case _: InputFileName | _: InputFileBlockStart | _: InputFileBlockLength => true + case _ => expr.children.exists(containsInputFileRelatedExpr) + } + } + + private def rewriteExpr( + expr: Expression, + replacedExprs: Map[String, AttributeReference]): Expression = { + expr match { + case _: InputFileName => + replacedExprs.getOrElseUpdate( + expr.prettyName, + AttributeReference(expr.prettyName, StringType, false)()) + case _: InputFileBlockStart => + replacedExprs.getOrElseUpdate( + expr.prettyName, + AttributeReference(expr.prettyName, LongType, false)()) + case _: InputFileBlockLength => + replacedExprs.getOrElseUpdate( + expr.prettyName, + AttributeReference(expr.prettyName, LongType, false)()) + case other => + other.withNewChildren(other.children.map(child => rewriteExpr(child, replacedExprs))) + } + } + + private def addMetadataCol( + plan: SparkPlan, + replacedExprs: Map[String, AttributeReference]): SparkPlan = { + def genNewOutput(output: Seq[Attribute]): Seq[Attribute] = { + var newOutput = output + for ((_, newAttr) <- replacedExprs) { + if (!newOutput.exists(attr => attr.exprId == newAttr.exprId)) { + newOutput = newOutput :+ newAttr + } + } + newOutput + } + def genNewProjectList(projectList: Seq[NamedExpression]): Seq[NamedExpression] = { + var newProjectList = projectList + for ((_, newAttr) <- replacedExprs) { + if (!newProjectList.exists(attr => attr.exprId == newAttr.exprId)) { + newProjectList = newProjectList :+ newAttr.toAttribute + } + } + newProjectList + } + + plan match { + case f: FileSourceScanExec => + f.copy(output = genNewOutput(f.output)) + case f: FileSourceScanExecTransformer => + f.copy(output = genNewOutput(f.output)) + case b: BatchScanExec => + b.copy(output = genNewOutput(b.output).asInstanceOf[Seq[AttributeReference]]) + case b: BatchScanExecTransformer => + b.copy(output = genNewOutput(b.output).asInstanceOf[Seq[AttributeReference]]) + case p @ ProjectExec(projectList, child) => + p.copy(genNewProjectList(projectList), addMetadataCol(child, replacedExprs)) + case p @ ProjectExecTransformer(projectList, child) => + p.copy(genNewProjectList(projectList), addMetadataCol(child, replacedExprs)) + case _ => plan.withNewChildren(plan.children.map(addMetadataCol(_, replacedExprs))) + } + } + + private def tryOffloadProjectExecWithInputFileRelatedExprs( + projectExec: ProjectExec): SparkPlan = { + def findScanNodes(plan: SparkPlan): Seq[SparkPlan] = { + plan.collect { + case f @ (_: FileSourceScanExec | _: AbstractFileSourceScanExec | + _: DataSourceV2ScanExecBase) => + f + } + } + val addHint = AddTransformHintRule() + val newProjectList = projectExec.projectList.filterNot(containsInputFileRelatedExpr) + val newProjectExec = ProjectExec(newProjectList, projectExec.child) + addHint.apply(newProjectExec) + if (TransformHints.isNotTransformable(newProjectExec)) { + // Project is still not transformable after removing `input_file_name` expressions.
+ projectExec + } else { + // The project with `input_file_name` expression should have at most + // one data source. Reference: + // https://github.com/apache/spark/blob/e459674127e7b21e2767cc62d10ea6f1f941936c + // /sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala#L506 + val leafScans = findScanNodes(projectExec) + assert(leafScans.size <= 1) + if (leafScans.isEmpty || TransformHints.isNotTransformable(leafScans(0))) { + // It means either + // 1. projectExec has `input_file_name` but no scan child, or + // 2. it has a scan child node but the scan node falls back. + projectExec + } else { + val replacedExprs = scala.collection.mutable.Map[String, AttributeReference]() + val newProjectList = projectExec.projectList.map { + expr => rewriteExpr(expr, replacedExprs).asInstanceOf[NamedExpression] + } + val newChild = addMetadataCol(projectExec.child, replacedExprs) + logDebug( + s"Columnar Processing for ${projectExec.getClass} with " + + s"ProjectList ${projectExec.projectList} is currently supported.") + ProjectExecTransformer(newProjectList, newChild) + } + } + } + private def genProjectExec(projectExec: ProjectExec): SparkPlan = { + if ( + TransformHints.isNotTransformable(projectExec) && + BackendsApiManager.getSettings.supportNativeInputFileRelatedExpr() && + projectExec.projectList.exists(containsInputFileRelatedExpr) + ) { + tryOffloadProjectExecWithInputFileRelatedExprs(projectExec) + } else if (TransformHints.isNotTransformable(projectExec)) { + projectExec + } else { + logDebug(s"Columnar Processing for ${projectExec.getClass} is currently supported.") + ProjectExecTransformer(projectExec.projectList, projectExec.child) + } + } + + override def offload(plan: SparkPlan): SparkPlan = plan match { + case p: ProjectExec => + genProjectExec(p) + case other => other + } } // Filter transformation.
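Note on the OffloadProject hunk above: it rewrites input_file_name(), input_file_block_start() and input_file_block_length() in a ProjectExec into plain attribute references and pushes those attributes down into the scan's output as metadata columns, so the whole project can be offloaded. The minimal usage sketch below is not part of this patch; the local session setup and the Parquet path are assumptions for illustration. With Gluten and the Velox backend enabled, the projection is expected to stay in ProjectExecTransformer instead of falling back.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, input_file_name}

object InputFileNameOffloadSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical local session; in practice Gluten is enabled through its plugin jars and confs.
    val spark = SparkSession.builder().master("local[2]").getOrCreate()

    // Hypothetical Parquet table path.
    val df = spark.read
      .parquet("/tmp/lineitem")
      .select(input_file_name().as("file"), col("l_orderkey"))

    // Inspect whether the project was offloaded: the executed plan should show a
    // ProjectExecTransformer above a native scan that also emits the file-name metadata column.
    df.explain()
    df.show(5, truncate = false)
    spark.stop()
  }
}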
@@ -261,10 +396,6 @@ object OffloadOthers { case plan: CoalesceExec => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarCoalesceExec(plan.numPartitions, plan.child) - case plan: ProjectExec => - val columnarChild = plan.child - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ProjectExecTransformer(plan.projectList, columnarChild) case plan: SortAggregateExec => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") HashAggregateExecBaseTransformer.from(plan) { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala index 9a54a101453f..30e4c0a79823 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala @@ -54,6 +54,7 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean) RasOffload.from[BaseJoinExec](OffloadJoin()).toRule, RasOffloadHashAggregate.toRule, RasOffloadFilter.toRule, + RasOffloadProject.toRule, RasOffload.from[DataSourceV2ScanExecBase](OffloadOthers()).toRule, RasOffload.from[DataSourceScanExec](OffloadOthers()).toRule, RasOffload diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadProject.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadProject.scala new file mode 100644 index 000000000000..0bbf57499b73 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadProject.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.extension.columnar.enumerated + +import org.apache.gluten.execution.ProjectExecTransformer + +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} + +object RasOffloadProject extends RasOffload { + override def offload(node: SparkPlan): SparkPlan = node match { + case ProjectExec(projectList, child) => + ProjectExecTransformer(projectList, child) + case other => + other + } + + override def typeIdentifier(): RasOffload.TypeIdentifier = + RasOffload.TypeIdentifier.of[ProjectExec] +} diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala index edd2a5a9672d..a4b530e637af 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala @@ -16,4 +16,40 @@ */ package org.apache.spark.sql -class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait {} +import org.apache.spark.sql.execution.ProjectExec +import org.apache.spark.sql.functions.{expr, input_file_name} +import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} + +class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait { + testGluten("input_file_name with scan is fallback") { + withTempPath { + dir => + val rawData = Seq( + Row(1, "Alice", Seq(Row(Seq(1, 2, 3)))), + Row(2, "Bob", Seq(Row(Seq(4, 5)))), + Row(3, "Charlie", Seq(Row(Seq(6, 7, 8, 9)))) + ) + val schema = StructType( + Array( + StructField("id", IntegerType, nullable = false), + StructField("name", StringType, nullable = false), + StructField( + "nested_column", + ArrayType( + StructType(Array( + StructField("array_in_struct", ArrayType(IntegerType), nullable = true) + ))), + nullable = true) + )) + val data: DataFrame = spark.createDataFrame(sparkContext.parallelize(rawData), schema) + data.write.parquet(dir.getCanonicalPath) + + val q = + spark.read.parquet(dir.getCanonicalPath).select(input_file_name(), expr("nested_column")) + val firstRow = q.head() + assert(firstRow.getString(0).contains(dir.toURI.getPath)) + val project = q.queryExecution.executedPlan.collect { case p: ProjectExec => p } + assert(project.size == 1) + } + } +} diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala index edd2a5a9672d..a4b530e637af 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala @@ -16,4 +16,40 @@ */ package org.apache.spark.sql -class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait {} +import org.apache.spark.sql.execution.ProjectExec +import org.apache.spark.sql.functions.{expr, input_file_name} +import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} + +class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait { + testGluten("input_file_name with scan is fallback") { + withTempPath { + dir => + val rawData = Seq( + Row(1, "Alice", Seq(Row(Seq(1, 2, 3)))), + Row(2, "Bob", Seq(Row(Seq(4, 5)))), + Row(3, "Charlie", Seq(Row(Seq(6, 7, 8, 9)))) + ) + val schema = StructType( + Array( + StructField("id", 
IntegerType, nullable = false), + StructField("name", StringType, nullable = false), + StructField( + "nested_column", + ArrayType( + StructType(Array( + StructField("array_in_struct", ArrayType(IntegerType), nullable = true) + ))), + nullable = true) + )) + val data: DataFrame = spark.createDataFrame(sparkContext.parallelize(rawData), schema) + data.write.parquet(dir.getCanonicalPath) + + val q = + spark.read.parquet(dir.getCanonicalPath).select(input_file_name(), expr("nested_column")) + val firstRow = q.head() + assert(firstRow.getString(0).contains(dir.toURI.getPath)) + val project = q.queryExecution.executedPlan.collect { case p: ProjectExec => p } + assert(project.size == 1) + } + } +} diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala index edd2a5a9672d..a4b530e637af 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala @@ -16,4 +16,40 @@ */ package org.apache.spark.sql -class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait {} +import org.apache.spark.sql.execution.ProjectExec +import org.apache.spark.sql.functions.{expr, input_file_name} +import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} + +class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait { + testGluten("input_file_name with scan is fallback") { + withTempPath { + dir => + val rawData = Seq( + Row(1, "Alice", Seq(Row(Seq(1, 2, 3)))), + Row(2, "Bob", Seq(Row(Seq(4, 5)))), + Row(3, "Charlie", Seq(Row(Seq(6, 7, 8, 9)))) + ) + val schema = StructType( + Array( + StructField("id", IntegerType, nullable = false), + StructField("name", StringType, nullable = false), + StructField( + "nested_column", + ArrayType( + StructType(Array( + StructField("array_in_struct", ArrayType(IntegerType), nullable = true) + ))), + nullable = true) + )) + val data: DataFrame = spark.createDataFrame(sparkContext.parallelize(rawData), schema) + data.write.parquet(dir.getCanonicalPath) + + val q = + spark.read.parquet(dir.getCanonicalPath).select(input_file_name(), expr("nested_column")) + val firstRow = q.head() + assert(firstRow.getString(0).contains(dir.toURI.getPath)) + val project = q.queryExecution.executedPlan.collect { case p: ProjectExec => p } + assert(project.size == 1) + } + } +} diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala index edd2a5a9672d..8a28c4e98a26 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenColumnExpressionSuite.scala @@ -16,4 +16,40 @@ */ package org.apache.spark.sql -class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait {} +import org.apache.spark.sql.execution.ProjectExec +import org.apache.spark.sql.functions.{expr, input_file_name} +import org.apache.spark.sql.types._ + +class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait { + testGluten("input_file_name with scan is fallback") { + withTempPath { + dir => + val rawData = Seq( + Row(1, "Alice", Seq(Row(Seq(1, 2, 3)))), + Row(2, "Bob", Seq(Row(Seq(4, 5)))), + 
Row(3, "Charlie", Seq(Row(Seq(6, 7, 8, 9)))) + ) + val schema = StructType( + Array( + StructField("id", IntegerType, nullable = false), + StructField("name", StringType, nullable = false), + StructField( + "nested_column", + ArrayType( + StructType(Array( + StructField("array_in_struct", ArrayType(IntegerType), nullable = true) + ))), + nullable = true) + )) + val data: DataFrame = spark.createDataFrame(sparkContext.parallelize(rawData), schema) + data.write.parquet(dir.getCanonicalPath) + + val q = + spark.read.parquet(dir.getCanonicalPath).select(input_file_name(), expr("nested_column")) + val firstRow = q.head() + assert(firstRow.getString(0).contains(dir.toURI.getPath)) + val project = q.queryExecution.executedPlan.collect { case p: ProjectExec => p } + assert(project.size == 1) + } + } +} diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index 58b99a7f3064..ec80ba86a7b9 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -37,7 +37,6 @@ case class GlutenNumaBindingInfo( class GlutenConfig(conf: SQLConf) extends Logging { import GlutenConfig._ - def enableInputFileNameReplaceRule: Boolean = conf.getConf(INPUT_FILE_NAME_REPLACE_RULE_ENABLED) def enableAnsiMode: Boolean = conf.ansiEnabled def enableGluten: Boolean = conf.getConf(GLUTEN_ENABLED) @@ -767,16 +766,6 @@ object GlutenConfig { .booleanConf .createWithDefault(GLUTEN_ENABLE_BY_DEFAULT) - val INPUT_FILE_NAME_REPLACE_RULE_ENABLED = - buildConf("spark.gluten.sql.enableInputFileNameReplaceRule") - .internal() - .doc( - "Experimental: This config apply for velox backend to specify whether to enable " + - "inputFileNameReplaceRule to support offload input_file_name " + - "expression to native.") - .booleanConf - .createWithDefault(false) - // FIXME the option currently controls both JVM and native validation against a Substrait plan. 
val NATIVE_VALIDATION_ENABLED = buildConf("spark.gluten.sql.enable.native.validation") diff --git a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala index b9c37ef3d730..b036d6dd9a41 100644 --- a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala +++ b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.csv.CSVOptions -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution} @@ -236,8 +236,13 @@ class Spark32Shims extends SparkShims { override def generateMetadataColumns( file: PartitionedFile, - metadataColumnNames: Seq[String]): JMap[String, String] = - new JHashMap[String, String]() + metadataColumnNames: Seq[String]): JMap[String, String] = { + val metadataColumn = new JHashMap[String, String]() + metadataColumn.put(InputFileName().prettyName, file.filePath) + metadataColumn.put(InputFileBlockStart().prettyName, file.start.toString) + metadataColumn.put(InputFileBlockLength().prettyName, file.length.toString) + metadataColumn + } def getAnalysisExceptionPlan(ae: AnalysisException): Option[LogicalPlan] = { ae.plan diff --git a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala index d6292b46c261..8b12c2642c55 100644 --- a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala +++ b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala @@ -231,6 +231,9 @@ class Spark33Shims extends SparkShims { case _ => } } + metadataColumn.put(InputFileName().prettyName, file.filePath) + metadataColumn.put(InputFileBlockStart().prettyName, file.start.toString) + metadataColumn.put(InputFileBlockLength().prettyName, file.length.toString) metadataColumn } diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala index c718f4ed25d6..420be8511937 100644 --- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala +++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala @@ -240,8 +240,9 @@ class Spark34Shims extends SparkShims { case _ => } } - - // TODO: row_index metadata support + metadataColumn.put(InputFileName().prettyName, file.filePath.toString) + metadataColumn.put(InputFileBlockStart().prettyName, file.start.toString) + metadataColumn.put(InputFileBlockLength().prettyName, file.length.toString) metadataColumn } diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala 
b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala index f6feae01a8b2..8ac8d323efd6 100644 --- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala +++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala @@ -238,8 +238,9 @@ class Spark35Shims extends SparkShims { case _ => } } - - // TODO row_index metadata support + metadataColumn.put(InputFileName().prettyName, file.filePath.toString) + metadataColumn.put(InputFileBlockStart().prettyName, file.start.toString) + metadataColumn.put(InputFileBlockLength().prettyName, file.length.toString) metadataColumn }
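The shim changes above all populate the same per-file metadata-column map consumed by the native scan, keyed by the Spark expressions' prettyName values; that is also why OffloadProject names its replacement AttributeReferences with prettyName. A minimal sketch mirroring the shim code follows (the helper object and method names are ours, for illustration only):

import java.util.{HashMap => JHashMap, Map => JMap}

import org.apache.spark.sql.catalyst.expressions.{InputFileBlockLength, InputFileBlockStart, InputFileName}

object MetadataColumnSketch {
  // Builds the per-file metadata map handed to the native reader, keyed by prettyName
  // (i.e. "input_file_name", "input_file_block_start", "input_file_block_length").
  def metadataFor(filePath: String, start: Long, length: Long): JMap[String, String] = {
    val metadataColumn = new JHashMap[String, String]()
    metadataColumn.put(InputFileName().prettyName, filePath)
    metadataColumn.put(InputFileBlockStart().prettyName, start.toString)
    metadataColumn.put(InputFileBlockLength().prettyName, length.toString)
    metadataColumn
  }
}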