diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
index 9d0a926e3b0f..cd9819c3e8e7 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
@@ -661,6 +661,15 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
                      | from lineitem limit 100""".stripMargin) {
       checkGlutenOperatorMatch[ProjectExecTransformer]
     }
+
+    runQueryAndCompare("""SELECT input_file_name(), l_orderkey
+                         | from
+                         | (select l_orderkey from lineitem
+                         | union all
+                         | select o_orderkey as l_orderkey from orders)
+                         | limit 100""".stripMargin) {
+      checkGlutenOperatorMatch[ProjectExecTransformer]
+    }
   }
 
   test("Test spark_partition_id function") {
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
index 62c72af792e9..792968cd2ff5 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
@@ -278,6 +278,17 @@ case class OffloadProject() extends OffloadSingleNode with LogLevelUtil {
           p.copy(genNewProjectList(projectList), addMetadataCol(child, replacedExprs))
         case p @ ProjectExecTransformer(projectList, child) =>
           p.copy(genNewProjectList(projectList), addMetadataCol(child, replacedExprs))
+        case u @ UnionExec(children) =>
+          val newFirstChild = addMetadataCol(children.head, replacedExprs)
+          val newOtherChildren = children.tail.map {
+            child =>
+              // Make sure the exprId is unique in each child of the Union.
+              val newReplacedExprs = replacedExprs.map {
+                expr => (expr._1, AttributeReference(expr._2.name, expr._2.dataType, false)())
+              }
+              addMetadataCol(child, newReplacedExprs)
+          }
+          u.copy(children = newFirstChild +: newOtherChildren)
         case _ => plan.withNewChildren(plan.children.map(addMetadataCol(_, replacedExprs)))
       }
     }
@@ -299,16 +310,15 @@ case class OffloadProject() extends OffloadSingleNode with LogLevelUtil {
       // Project is still not transformable after remove `input_file_name` expressions.
       projectExec
     } else {
-      // the project with `input_file_name` expression should have at most
-      // one data source, reference:
+      // the project with `input_file_name` expression may have multiple data sources
+      // introduced by union all, reference:
       // https://github.com/apache/spark/blob/e459674127e7b21e2767cc62d10ea6f1f941936c
-      // /sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala#L506
+      // /sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala#L519
       val leafScans = findScanNodes(projectExec)
-      assert(leafScans.size <= 1)
-      if (leafScans.isEmpty || FallbackTags.nonEmpty(leafScans(0))) {
+      if (leafScans.isEmpty || leafScans.exists(FallbackTags.nonEmpty)) {
         // It means
         // 1. projectExec has `input_file_name` but no scan child.
-        // 2. It has scan child node but the scan node fallback.
+        // 2. It has scan child nodes but some scan node falls back.
         projectExec
       } else {
         val replacedExprs = scala.collection.mutable.Map[String, AttributeReference]()
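
Reviewer note, not part of the patch: the new UnionExec branch works because AttributeReference's second (curried) parameter list defaults the exprId to NamedExpression.newExprId, so the trailing no-arg call `()` mints a fresh, globally unique id for every Union child. A minimal standalone sketch of that behavior follows; it assumes only a spark-catalyst dependency on the classpath, and the object name ExprIdSketch is hypothetical.

// Standalone sketch, not part of the patch; assumes spark-catalyst on the classpath.
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.StringType

object ExprIdSketch { // hypothetical name, illustration only
  def main(args: Array[String]): Unit = {
    // The empty second parameter list defaults to NamedExpression.newExprId,
    // so each call produces an attribute with a fresh exprId.
    val forFirstChild = AttributeReference("input_file_name", StringType, nullable = false)()
    val forOtherChild =
      AttributeReference(forFirstChild.name, forFirstChild.dataType, nullable = false)()
    // Same name and type, distinct identity: this is what keeps the metadata
    // column added under each Union child from colliding on exprId.
    assert(forFirstChild.exprId != forOtherChild.exprId)
  }
}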