Skip to content

Commit

Permalink
[VL] Fix offload input_file_name assert error (#6390)
Browse files Browse the repository at this point in the history
  • Loading branch information
zml1206 committed Jul 17, 2024
1 parent 4c83e34 commit f2995e7
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,15 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
| from lineitem limit 100""".stripMargin) {
checkGlutenOperatorMatch[ProjectExecTransformer]
}

runQueryAndCompare("""SELECT input_file_name(), l_orderkey
| from
| (select l_orderkey from lineitem
| union all
| select o_orderkey as l_orderkey from orders)
| limit 100""".stripMargin) {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}

test("Test spark_partition_id function") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,17 @@ case class OffloadProject() extends OffloadSingleNode with LogLevelUtil {
p.copy(genNewProjectList(projectList), addMetadataCol(child, replacedExprs))
case p @ ProjectExecTransformer(projectList, child) =>
p.copy(genNewProjectList(projectList), addMetadataCol(child, replacedExprs))
case u @ UnionExec(children) =>
val newFirstChild = addMetadataCol(children.head, replacedExprs)
val newOtherChildren = children.tail.map {
child =>
// Make sure exprId is unique in each child of Union.
val newReplacedExprs = replacedExprs.map {
expr => (expr._1, AttributeReference(expr._2.name, expr._2.dataType, false)())
}
addMetadataCol(child, newReplacedExprs)
}
u.copy(children = newFirstChild +: newOtherChildren)
case _ => plan.withNewChildren(plan.children.map(addMetadataCol(_, replacedExprs)))
}
}
Expand All @@ -299,16 +310,15 @@ case class OffloadProject() extends OffloadSingleNode with LogLevelUtil {
// Project is still not transformable after remove `input_file_name` expressions.
projectExec
} else {
// the project with `input_file_name` expression should have at most
// one data source, reference:
// the project with `input_file_name` expression may have multiple data source
// by union all, reference:
// https://github.com/apache/spark/blob/e459674127e7b21e2767cc62d10ea6f1f941936c
// /sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala#L506
// /sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala#L519
val leafScans = findScanNodes(projectExec)
assert(leafScans.size <= 1)
if (leafScans.isEmpty || FallbackTags.nonEmpty(leafScans(0))) {
if (leafScans.isEmpty || leafScans.exists(FallbackTags.nonEmpty)) {
// It means
// 1. projectExec has `input_file_name` but no scan child.
// 2. It has scan child node but the scan node fallback.
// 2. It has scan children node but the scan node fallback.
projectExec
} else {
val replacedExprs = scala.collection.mutable.Map[String, AttributeReference]()
Expand Down

0 comments on commit f2995e7

Please sign in to comment.