[VL] Fix input_file_name results empty string
zml1206 authored and zhztheplayer committed Sep 4, 2024
1 parent bdf3421 commit 52bdaf1
Showing 3 changed files with 60 additions and 19 deletions.
@@ -28,7 +28,7 @@ class ScalarFunctionsValidateSuiteRasOff extends ScalarFunctionsValidateSuite {
     super.sparkConf
       .set("spark.gluten.ras.enabled", "false")
   }

   import testImplicits._
   // Since https://github.com/apache/incubator-gluten/pull/6200.
   test("Test input_file_name function") {
     runQueryAndCompare("""SELECT input_file_name(), l_orderkey
@@ -44,6 +44,22 @@ class ScalarFunctionsValidateSuiteRasOff extends ScalarFunctionsValidateSuite {
                          | limit 100""".stripMargin) {
       checkGlutenOperatorMatch[ProjectExecTransformer]
     }

     withTempPath {
       path =>
         Seq(1, 2, 3).toDF("a").write.json(path.getCanonicalPath)
         spark.read.json(path.getCanonicalPath).createOrReplaceTempView("json_table")
         val sql =
           """
             |SELECT input_file_name(), a
             |FROM
             |(SELECT a FROM json_table
             |UNION ALL
             |SELECT l_orderkey as a FROM lineitem)
             |LIMIT 100
             |""".stripMargin
         compareResultsAgainstVanillaSpark(sql, true, { _ => })
     }
}
}
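
For context, the new test exercises input_file_name() over a UNION ALL of two different sources: a JSON temp view (which the Velox backend typically does not offload, so that scan falls back to vanilla Spark) and the Parquet lineitem table. Below is a standalone sketch of the same scenario; the session setup, paths, and table contents are illustrative, not part of the commit.

// Standalone reproduction sketch (assumes a SparkSession with Gluten enabled;
// paths and table contents are illustrative).
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("input-file-name-repro").getOrCreate()
import spark.implicits._

// JSON source: Velox has no native JSON scan, so this read typically falls back.
Seq(1, 2, 3).toDF("a").write.mode("overwrite").json("/tmp/json_table")
spark.read.json("/tmp/json_table").createOrReplaceTempView("json_table")

// Parquet source: this scan is offloaded to Velox.
Seq((1L, "x")).toDF("l_orderkey", "l_comment").write.mode("overwrite").parquet("/tmp/lineitem")
spark.read.parquet("/tmp/lineitem").createOrReplaceTempView("lineitem")

// Before this fix, rows from the fallen-back JSON scan could report an empty
// string for input_file_name() instead of the real file path.
spark.sql(
  """SELECT input_file_name(), a FROM
    |(SELECT a FROM json_table UNION ALL SELECT l_orderkey AS a FROM lineitem)
    |LIMIT 100""".stripMargin).show(false)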

@@ -34,8 +34,9 @@ object MiscColumnarRules {
   object TransformPreOverrides {
     def apply(): TransformPreOverrides = {
       TransformPreOverrides(
-        List(OffloadProject(), OffloadFilter()),
+        List(OffloadFilter()),
         List(
+          OffloadProject(),
           OffloadOthers(),
           OffloadAggregate(),
           OffloadExchange(),
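
The substantive change in this hunk is purely ordering: OffloadProject() moves out of the first rule list and into the second, ahead of OffloadOthers(), so project offloading now runs in the later batch. A rough sketch of what a two-batch rule pipeline like this constructor implies (the pipeline name and traversal details are assumptions for illustration, not Gluten's actual implementation):

// Illustrative sketch only: two ordered batches of single-node offload rules
// applied over a SparkPlan; later rules see the rewrites of earlier ones.
import org.apache.spark.sql.execution.SparkPlan

trait OffloadRule { def offload(plan: SparkPlan): SparkPlan }

case class TwoBatchPipeline(firstBatch: Seq[OffloadRule], secondBatch: Seq[OffloadRule]) {
  def apply(plan: SparkPlan): SparkPlan = {
    val afterFirst = firstBatch.foldLeft(plan) {
      (p, rule) => p.transformUp { case node => rule.offload(node) }
    }
    secondBatch.foldLeft(afterFirst) {
      (p, rule) => p.transformUp { case node => rule.offload(node) }
    }
  }
}

With this commit, OffloadProject observes the plan only after OffloadFilter's batch has run, and immediately before OffloadOthers.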
@@ -26,7 +26,7 @@ import org.apache.gluten.utils.{LogLevelUtil, PlanUtil}

import org.apache.spark.api.python.EvalPythonExecTransformer
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution._
@@ -271,13 +271,35 @@ case class OffloadProject() extends OffloadSingleNode with LogLevelUtil {
newProjectList
}

+    def replaceInputFileAttr(namedExpr: NamedExpression): NamedExpression = namedExpr.name match {
+      case "input_file_name" =>
+        Alias(InputFileName(), namedExpr.name)(namedExpr.exprId)
+      case "input_file_block_start" =>
+        Alias(InputFileBlockStart(), namedExpr.name)(namedExpr.exprId)
+      case "input_file_block_length" =>
+        Alias(InputFileBlockLength(), namedExpr.name)(namedExpr.exprId)
+      case _ => namedExpr
+    }

     plan match {
       case f: FileSourceScanExec =>
-        f.copy(output = genNewOutput(f.output))
+        if (FallbackTags.nonEmpty(f)) {
+          val p = ProjectExec(genNewOutput(f.output).map(replaceInputFileAttr), f)
+          FallbackTags.add(p, "fallback input file expression")
+          p
+        } else {
+          f.copy(output = genNewOutput(f.output))
+        }
case f: FileSourceScanExecTransformer =>
f.copy(output = genNewOutput(f.output))
       case b: BatchScanExec =>
-        b.copy(output = genNewOutput(b.output).asInstanceOf[Seq[AttributeReference]])
+        if (FallbackTags.nonEmpty(b)) {
+          val p = ProjectExec(genNewOutput(b.output).map(replaceInputFileAttr), b)
+          FallbackTags.add(p, "fallback input file expression")
+          p
+        } else {
+          b.copy(output = genNewOutput(b.output).asInstanceOf[Seq[AttributeReference]])
+        }
case b: BatchScanExecTransformer =>
b.copy(output = genNewOutput(b.output).asInstanceOf[Seq[AttributeReference]])
case p @ ProjectExec(projectList, child) =>
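
The pattern added above: when a scan carries a fallback tag, its rewritten input-file attributes cannot be populated by an offloaded scan, so the rule wraps the vanilla scan in a ProjectExec whose project list re-expands those attributes into Spark's native expressions, then tags that project to stay on the fallback path. A small sketch of the re-expansion step (the attribute construction here is illustrative):

// Sketch of the replaceInputFileAttr idea: an attribute named after an
// input-file function is turned back into an Alias over the corresponding
// Spark expression, preserving the exprId so upstream references still bind.
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, InputFileName, NamedExpression}
import org.apache.spark.sql.types.StringType

// Illustrative attribute, as genNewOutput might produce for input_file_name().
val attr: NamedExpression = AttributeReference("input_file_name", StringType)()

val restored: NamedExpression = attr.name match {
  case "input_file_name" => Alias(InputFileName(), attr.name)(attr.exprId)
  case _ => attr
}
// `restored` now evaluates input_file_name() row by row in vanilla Spark, so
// the fallback scan's ProjectExec can produce real file paths instead of
// empty strings.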
@@ -308,34 +330,36 @@ case class OffloadProject() extends OffloadSingleNode with LogLevelUtil {
f
}
}
val addHint = AddFallbackTagRule()
val newProjectList = projectExec.projectList.filterNot(containsInputFileRelatedExpr)
val newProjectExec = ProjectExec(newProjectList, projectExec.child)
addHint.apply(newProjectExec)
if (FallbackTags.nonEmpty(newProjectExec)) {
        // Project is still not transformable after removing `input_file_name` expressions.
val transformNodes =
projectExec.collect { case p if !FallbackTags.nonEmpty(p) => p }
if (transformNodes.size == 0) {
projectExec
} else {
          // The project with an `input_file_name` expression may read from multiple data
          // sources combined by UNION ALL; reference:
// https://github.com/apache/spark/blob/e459674127e7b21e2767cc62d10ea6f1f941936c
// /sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala#L519
val leafScans = findScanNodes(projectExec)
-         if (leafScans.isEmpty || leafScans.exists(FallbackTags.nonEmpty)) {
-           // It means
-           // 1. projectExec has `input_file_name` but no scan child, or
-           // 2. it has scan child nodes but a scan node falls back.
+         if (leafScans.isEmpty) {
+           // It means projectExec has `input_file_name` but no scan child.
projectExec
} else {
val replacedExprs = scala.collection.mutable.Map[String, AttributeReference]()
val newProjectList = projectExec.projectList.map {
expr => rewriteExpr(expr, replacedExprs).asInstanceOf[NamedExpression]
}
val newChild = addMetadataCol(projectExec.child, replacedExprs)
-           logDebug(
-             s"Columnar Processing for ${projectExec.getClass} with " +
-               s"ProjectList ${projectExec.projectList} is currently supported.")
-           ProjectExecTransformer(newProjectList, newChild)
+           val newProject = ProjectExec(newProjectList, newChild)
+           val addHint = AddFallbackTagRule()
+           addHint.apply(newProject)
+           if (FallbackTags.nonEmpty(newProject)) {
+             newProject
+           } else {
+             logDebug(
+               s"Columnar Processing for ${projectExec.getClass} with " +
+                 s"ProjectList ${projectExec.projectList} is currently supported.")
+             ProjectExecTransformer(newProjectList, newChild)
+           }
}
}
}
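
Net effect of this last hunk: the rewritten project is no longer converted straight into ProjectExecTransformer. It is first re-validated with AddFallbackTagRule, and only promoted when validation leaves no fallback tag; the old bail-out on leafScans.exists(FallbackTags.nonEmpty) is dropped because fallen-back scans are now handled by the wrapping ProjectExec from the previous hunk. A condensed sketch of the validate-then-promote step as it would sit inside this same file (the helper name is hypothetical; FallbackTags, AddFallbackTagRule, ProjectExec, and ProjectExecTransformer are the symbols already used in the diff above):

// Condensed sketch of the new validate-then-promote step.
private def promoteIfTransformable(
    newProjectList: Seq[NamedExpression],
    newChild: SparkPlan): SparkPlan = {
  val newProject = ProjectExec(newProjectList, newChild)
  AddFallbackTagRule().apply(newProject) // re-run validation on the rewritten project
  if (FallbackTags.nonEmpty(newProject)) {
    newProject // still not offloadable: keep the vanilla ProjectExec
  } else {
    ProjectExecTransformer(newProjectList, newChild) // validated: offload it
  }
}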
