Skip to content

Commit

Permalink
[GLUTEN-7144][VL][RAS] Spark input file function support
Browse files Browse the repository at this point in the history
  • Loading branch information
zml1206 committed Sep 6, 2024
1 parent 8078f24 commit 21314c2
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ private object VeloxRuleApi {
def injectRas(injector: RasInjector): Unit = {
// Gluten RAS: Pre rules.
injector.inject(_ => RemoveTransitions)
injector.inject(_ => PushDownInputFileExpression.PreOffload)
injector.inject(c => FallbackOnANSIMode.apply(c.session))
injector.inject(c => PlanOneRowRelation.apply(c.session))
injector.inject(_ => FallbackEmptySchemaRelation())
Expand All @@ -106,6 +107,7 @@ private object VeloxRuleApi {
injector.inject(_ => RemoveTransitions)
injector.inject(_ => RemoveNativeWriteFilesSortAndProject())
injector.inject(c => RewriteTransformer.apply(c.session))
injector.inject(_ => PushDownInputFileExpression.PostOffload)
injector.inject(_ => EnsureLocalSortRequirements)
injector.inject(_ => EliminateLocalSort)
injector.inject(_ => CollapseProjectExecTransformer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,58 +28,13 @@ class ScalarFunctionsValidateSuiteRasOff extends ScalarFunctionsValidateSuite {
super.sparkConf
.set("spark.gluten.ras.enabled", "false")
}
import testImplicits._
// Since https://github.com/apache/incubator-gluten/pull/6200.
// "Test input_file_name function" is registered once in the shared base class
// ScalarFunctionsValidateSuite and is inherited by this suite. It must not be registered again
// here: re-registering an already-used test name in a subclass makes ScalaTest reject the suite
// with a DuplicateTestNameException.
}

/**
 * Runs every test inherited from [[ScalarFunctionsValidateSuite]] with
 * `spark.gluten.ras.enabled` set to `"true"`.
 *
 * The `input_file_name` test is intentionally not overridden or ignored here:
 * `PushDownInputFileExpression` is injected into the RAS rule list, so the test defined in the
 * base suite is expected to run under RAS as well.
 */
class ScalarFunctionsValidateSuiteRasOn extends ScalarFunctionsValidateSuite {
  override protected def sparkConf: SparkConf = {
    super.sparkConf
      .set("spark.gluten.ras.enabled", "true")
  }
}

abstract class ScalarFunctionsValidateSuite extends FunctionsValidateSuite {
Expand Down Expand Up @@ -1381,6 +1336,38 @@ abstract class ScalarFunctionsValidateSuite extends FunctionsValidateSuite {
}
}

test("Test input_file_name function") {
  // Case 1: input_file_name() projected directly over a single table.
  // NOTE(review): checkGlutenOperatorMatch presumably asserts that the executed plan contains a
  // ProjectExecTransformer - confirm against its definition.
  runQueryAndCompare("""SELECT input_file_name(), l_orderkey
                       | from lineitem limit 100""".stripMargin) {
    checkGlutenOperatorMatch[ProjectExecTransformer]
  }

  // Case 2: input_file_name() projected over a UNION ALL of two tables.
  runQueryAndCompare("""SELECT input_file_name(), l_orderkey
                       | from
                       | (select l_orderkey from lineitem
                       | union all
                       | select o_orderkey as l_orderkey from orders)
                       | limit 100""".stripMargin) {
    checkGlutenOperatorMatch[ProjectExecTransformer]
  }

  // Case 3: UNION ALL of a JSON-backed temp view and lineitem. Only the results are compared
  // against vanilla Spark; the plan callback is a no-op, so no operator is asserted.
  withTempPath {
    path =>
      Seq(1, 2, 3).toDF("a").write.json(path.getCanonicalPath)
      spark.read.json(path.getCanonicalPath).createOrReplaceTempView("json_table")
      try {
        val sql =
          """
            |SELECT input_file_name(), a
            |FROM
            |(SELECT a FROM json_table
            |UNION ALL
            |SELECT l_orderkey as a FROM lineitem)
            |LIMIT 100
            |""".stripMargin
        compareResultsAgainstVanillaSpark(sql, true, { _ => })
      } finally {
        // Drop the view even if the comparison fails, so "json_table" (which points at a
        // temp path that is about to be deleted) does not outlive this test.
        spark.catalog.dropTempView("json_table")
      }
  }
}

testWithSpecifiedSparkVersion("array insert", Some("3.4")) {
withTempPath {
path =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import org.apache.gluten.execution.ProjectExecTransformer

import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}

/** TODO: Map [[org.apache.gluten.extension.columnar.OffloadProject]] to RAS. */
object RasOffloadProject extends RasOffload {
override def offload(node: SparkPlan): SparkPlan = node match {
case ProjectExec(projectList, child) =>
Expand Down

0 comments on commit 21314c2

Please sign in to comment.