[VL] [Core] Spark Input_file_name Support (#6021)
gaoyangxiaozhu authored Jun 19, 2024
1 parent 800cadd commit cf3a98e
Showing 5 changed files with 174 additions and 3 deletions.
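
For context, input_file_name() (together with input_file_block_start() and input_file_block_length()) exposes per-row file metadata in Spark SQL; this commit teaches the Velox backend to serve those expressions so projections using them can be offloaded instead of falling back to vanilla Spark (the new test below checks for ProjectExecTransformer). A minimal usage sketch, not part of the commit, assuming a SparkSession and a Parquet-backed lineitem table like the one used in the test suite:

import org.apache.spark.sql.SparkSession

object InputFileNameExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("input-file-name-example").getOrCreate()

    // Each row reports the file it was read from and the split offsets within that file.
    spark
      .sql("""SELECT input_file_name(), input_file_block_start(),
             |       input_file_block_length(), l_orderkey
             |FROM lineitem LIMIT 10""".stripMargin)
      .show(truncate = false)

    spark.stop()
  }
}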
@@ -19,6 +19,7 @@ package org.apache.gluten.backendsapi.velox
import org.apache.gluten.GlutenNumaBindingInfo
import org.apache.gluten.backendsapi.IteratorApi
import org.apache.gluten.execution._
import org.apache.gluten.extension.InputFileNameReplaceRule
import org.apache.gluten.metrics.IMetrics
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.plan.PlanNode
@@ -112,7 +113,7 @@ class VeloxIteratorApi extends IteratorApi with Logging {
val fileSizes = new JArrayList[JLong]()
val modificationTimes = new JArrayList[JLong]()
val partitionColumns = new JArrayList[JMap[String, String]]
var metadataColumns = new JArrayList[JMap[String, String]]
val metadataColumns = new JArrayList[JMap[String, String]]
files.foreach {
file =>
// The "file.filePath" in PartitionedFile is not the original encoded path, so the decoded
@@ -132,6 +133,13 @@
}
val metadataColumn =
SparkShimLoader.getSparkShims.generateMetadataColumns(file, metadataColumnNames)
metadataColumn.put(InputFileNameReplaceRule.replacedInputFileName, file.filePath.toString)
metadataColumn.put(
InputFileNameReplaceRule.replacedInputFileBlockStart,
file.start.toString)
metadataColumn.put(
InputFileNameReplaceRule.replacedInputFileBlockLength,
file.length.toString)
metadataColumns.add(metadataColumn)
val partitionColumn = new JHashMap[String, String]()
for (i <- 0 until file.partitionValues.numFields) {
@@ -807,7 +807,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
*/
override def genExtendedColumnarValidationRules(): List[SparkSession => Rule[SparkPlan]] = List(
BloomFilterMightContainJointRewriteRule.apply,
ArrowScanReplaceRule.apply
ArrowScanReplaceRule.apply,
InputFileNameReplaceRule.apply
)

/**
@@ -0,0 +1,155 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{FileSourceScanExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.types.{LongType, StringType}

object InputFileNameReplaceRule {
val replacedInputFileName = "$input_file_name$"
val replacedInputFileBlockStart = "$input_file_block_start$"
val replacedInputFileBlockLength = "$input_file_block_length$"
}

case class InputFileNameReplaceRule(spark: SparkSession) extends Rule[SparkPlan] {
import InputFileNameReplaceRule._

private def isInputFileName(expr: Expression): Boolean = {
expr match {
case _: InputFileName => true
case _ => false
}
}

private def isInputFileBlockStart(expr: Expression): Boolean = {
expr match {
case _: InputFileBlockStart => true
case _ => false
}
}

private def isInputFileBlockLength(expr: Expression): Boolean = {
expr match {
case _: InputFileBlockLength => true
case _ => false
}
}

override def apply(plan: SparkPlan): SparkPlan = {
val replacedExprs = scala.collection.mutable.Map[String, AttributeReference]()

def hasParquetScan(plan: SparkPlan): Boolean = {
plan match {
case fileScan: FileSourceScanExec
if fileScan.relation.fileFormat.isInstanceOf[ParquetFileFormat] =>
true
case batchScan: BatchScanExec =>
batchScan.scan match {
case _: ParquetScan => true
case _ => false
}
case _ => plan.children.exists(hasParquetScan)
}
}

def mayNeedConvert(expr: Expression): Boolean = {
expr match {
case e if isInputFileName(e) => true
case s if isInputFileBlockStart(s) => true
case l if isInputFileBlockLength(l) => true
case other => other.children.exists(mayNeedConvert)
}
}

def doConvert(expr: Expression): Expression = {
expr match {
case e if isInputFileName(e) =>
replacedExprs.getOrElseUpdate(
replacedInputFileName,
AttributeReference(replacedInputFileName, StringType, true)())
case s if isInputFileBlockStart(s) =>
replacedExprs.getOrElseUpdate(
replacedInputFileBlockStart,
AttributeReference(replacedInputFileBlockStart, LongType, true)()
)
case l if isInputFileBlockLength(l) =>
replacedExprs.getOrElseUpdate(
replacedInputFileBlockLength,
AttributeReference(replacedInputFileBlockLength, LongType, true)()
)
case other =>
other.withNewChildren(other.children.map(child => doConvert(child)))
}
}

def ensureChildOutputHasNewAttrs(plan: SparkPlan): SparkPlan = {
plan match {
case _ @ProjectExec(projectList, child) =>
var newProjectList = projectList
for ((_, newAttr) <- replacedExprs) {
if (!newProjectList.exists(attr => attr.exprId == newAttr.exprId)) {
newProjectList = newProjectList :+ newAttr.toAttribute
}
}
val newChild = ensureChildOutputHasNewAttrs(child)
ProjectExec(newProjectList, newChild)
case f: FileSourceScanExec =>
var newOutput = f.output
for ((_, newAttr) <- replacedExprs) {
if (!newOutput.exists(attr => attr.exprId == newAttr.exprId)) {
newOutput = newOutput :+ newAttr.toAttribute
}
}
f.copy(output = newOutput)

case b: BatchScanExec =>
var newOutput = b.output
for ((_, newAttr) <- replacedExprs) {
if (!newOutput.exists(attr => attr.exprId == newAttr.exprId)) {
newOutput = newOutput :+ newAttr
}
}
b.copy(output = newOutput)
case other =>
val newChildren = other.children.map(ensureChildOutputHasNewAttrs)
other.withNewChildren(newChildren)
}
}

def replaceInputFileNameInProject(plan: SparkPlan): SparkPlan = {
plan match {
case _ @ProjectExec(projectList, child)
if projectList.exists(mayNeedConvert) && hasParquetScan(plan) =>
val newProjectList = projectList.map {
expr => doConvert(expr).asInstanceOf[NamedExpression]
}
val newChild = replaceInputFileNameInProject(ensureChildOutputHasNewAttrs(child))
ProjectExec(newProjectList, newChild)
case other =>
val newChildren = other.children.map(replaceInputFileNameInProject)
other.withNewChildren(newChildren)
}
}
replaceInputFileNameInProject(plan)
}
}
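
The handshake between this new rule and the VeloxIteratorApi change above runs through the three reserved column names: the rule rewrites InputFileName, InputFileBlockStart and InputFileBlockLength expressions found above a Parquet scan into AttributeReferences named by those strings and appends them to the scan output, while the scan side fills a per-file metadata map under the same keys so the values are available at runtime. A standalone sketch of that keying, using an illustrative file path and split offsets rather than values from the commit:

import java.util.{HashMap => JHashMap}

object MetadataColumnSketch {
  // Reserved keys, mirroring InputFileNameReplaceRule.
  val replacedInputFileName = "$input_file_name$"
  val replacedInputFileBlockStart = "$input_file_block_start$"
  val replacedInputFileBlockLength = "$input_file_block_length$"

  def main(args: Array[String]): Unit = {
    // Stand-ins for the PartitionedFile fields used in VeloxIteratorApi.
    val filePath = "file:///tmp/warehouse/lineitem/part-00000.parquet"
    val start = 0L
    val length = 128L * 1024 * 1024

    // Same shape as the metadataColumn map built per file on the scan side.
    val metadataColumn = new JHashMap[String, String]()
    metadataColumn.put(replacedInputFileName, filePath)
    metadataColumn.put(replacedInputFileBlockStart, start.toString)
    metadataColumn.put(replacedInputFileBlockLength, length.toString)
    println(metadataColumn)
  }
}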
@@ -623,6 +623,13 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
}
}

test("Test input_file_name function") {
runQueryAndCompare("""SELECT input_file_name(), l_orderkey
| from lineitem limit 100""".stripMargin) {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}

test("Test spark_partition_id function") {
runQueryAndCompare("""SELECT spark_partition_id(), l_orderkey
| from lineitem limit 100""".stripMargin) {
@@ -96,11 +96,11 @@ class HeuristicApplier(session: SparkSession)
(spark: SparkSession) => FallbackOnANSIMode(spark),
(spark: SparkSession) => FallbackMultiCodegens(spark),
(spark: SparkSession) => PlanOneRowRelation(spark),
(_: SparkSession) => FallbackEmptySchemaRelation(),
(_: SparkSession) => RewriteSubqueryBroadcast()
) :::
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules() :::
List(
(_: SparkSession) => FallbackEmptySchemaRelation(),
(spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark),
(_: SparkSession) => RewriteSparkPlanRulesManager(),
(_: SparkSession) => AddTransformHintRule()