Skip to content

Commit

Permalink
[UT] Test input_file_name, input_file_block_start & input_file_block_…
Browse files Browse the repository at this point in the history
…length when scan falls back (#6318)
  • Loading branch information
gaoyangxiaozhu authored Jul 11, 2024
1 parent e8e93e7 commit 8e0d56e
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,32 @@ package org.apache.spark.sql

import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.functions.{expr, input_file_name}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}

class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait {
testGluten("input_file_name with scan is fallback") {
withTempPath {
dir =>
val rawData = Seq(
Row(1, "Alice", Seq(Row(Seq(1, 2, 3)))),
Row(2, "Bob", Seq(Row(Seq(4, 5)))),
Row(3, "Charlie", Seq(Row(Seq(6, 7, 8, 9))))
)
val schema = StructType(
Array(
StructField("id", IntegerType, nullable = false),
StructField("name", StringType, nullable = false),
StructField(
"nested_column",
ArrayType(
StructType(Array(
StructField("array_in_struct", ArrayType(IntegerType), nullable = true)
))),
nullable = true)
))
val data: DataFrame = spark.createDataFrame(sparkContext.parallelize(rawData), schema)
data.write.parquet(dir.getCanonicalPath)
import testImplicits._
testGluten(
"input_file_name, input_file_block_start and input_file_block_length " +
"should fall back if scan falls back") {
withSQLConf(("spark.gluten.sql.columnar.filescan", "false")) {
withTempPath {
dir =>
val data = sparkContext.parallelize(0 to 10).toDF("id")
data.write.parquet(dir.getCanonicalPath)

val q =
spark.read.parquet(dir.getCanonicalPath).select(input_file_name(), expr("nested_column"))
val firstRow = q.head()
assert(firstRow.getString(0).contains(dir.toURI.getPath))
val project = q.queryExecution.executedPlan.collect { case p: ProjectExec => p }
assert(project.size == 1)
val q =
spark.read
.parquet(dir.getCanonicalPath)
.select(
input_file_name(),
expr("input_file_block_start()"),
expr("input_file_block_length()"))
val firstRow = q.head()
assert(firstRow.getString(0).contains(dir.toURI.getPath))
assert(firstRow.getLong(1) == 0)
assert(firstRow.getLong(2) > 0)
val project = q.queryExecution.executedPlan.collect { case p: ProjectExec => p }
assert(project.size == 1)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,32 @@ package org.apache.spark.sql

import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.functions.{expr, input_file_name}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}

class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait {
testGluten("input_file_name with scan is fallback") {
withTempPath {
dir =>
val rawData = Seq(
Row(1, "Alice", Seq(Row(Seq(1, 2, 3)))),
Row(2, "Bob", Seq(Row(Seq(4, 5)))),
Row(3, "Charlie", Seq(Row(Seq(6, 7, 8, 9))))
)
val schema = StructType(
Array(
StructField("id", IntegerType, nullable = false),
StructField("name", StringType, nullable = false),
StructField(
"nested_column",
ArrayType(
StructType(Array(
StructField("array_in_struct", ArrayType(IntegerType), nullable = true)
))),
nullable = true)
))
val data: DataFrame = spark.createDataFrame(sparkContext.parallelize(rawData), schema)
data.write.parquet(dir.getCanonicalPath)
import testImplicits._
testGluten(
"input_file_name, input_file_block_start and input_file_block_length " +
"should fall back if scan falls back") {
withSQLConf(("spark.gluten.sql.columnar.filescan", "false")) {
withTempPath {
dir =>
val data = sparkContext.parallelize(0 to 10).toDF("id")
data.write.parquet(dir.getCanonicalPath)

val q =
spark.read.parquet(dir.getCanonicalPath).select(input_file_name(), expr("nested_column"))
val firstRow = q.head()
assert(firstRow.getString(0).contains(dir.toURI.getPath))
val project = q.queryExecution.executedPlan.collect { case p: ProjectExec => p }
assert(project.size == 1)
val q =
spark.read
.parquet(dir.getCanonicalPath)
.select(
input_file_name(),
expr("input_file_block_start()"),
expr("input_file_block_length()"))
val firstRow = q.head()
assert(firstRow.getString(0).contains(dir.toURI.getPath))
assert(firstRow.getLong(1) == 0)
assert(firstRow.getLong(2) > 0)
val project = q.queryExecution.executedPlan.collect { case p: ProjectExec => p }
assert(project.size == 1)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,32 @@ package org.apache.spark.sql

import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.functions.{expr, input_file_name}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}

class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait {
testGluten("input_file_name with scan is fallback") {
withTempPath {
dir =>
val rawData = Seq(
Row(1, "Alice", Seq(Row(Seq(1, 2, 3)))),
Row(2, "Bob", Seq(Row(Seq(4, 5)))),
Row(3, "Charlie", Seq(Row(Seq(6, 7, 8, 9))))
)
val schema = StructType(
Array(
StructField("id", IntegerType, nullable = false),
StructField("name", StringType, nullable = false),
StructField(
"nested_column",
ArrayType(
StructType(Array(
StructField("array_in_struct", ArrayType(IntegerType), nullable = true)
))),
nullable = true)
))
val data: DataFrame = spark.createDataFrame(sparkContext.parallelize(rawData), schema)
data.write.parquet(dir.getCanonicalPath)
import testImplicits._
testGluten(
"input_file_name, input_file_block_start and input_file_block_length " +
"should fall back if scan falls back") {
withSQLConf(("spark.gluten.sql.columnar.filescan", "false")) {
withTempPath {
dir =>
val data = sparkContext.parallelize(0 to 10).toDF("id")
data.write.parquet(dir.getCanonicalPath)

val q =
spark.read.parquet(dir.getCanonicalPath).select(input_file_name(), expr("nested_column"))
val firstRow = q.head()
assert(firstRow.getString(0).contains(dir.toURI.getPath))
val project = q.queryExecution.executedPlan.collect { case p: ProjectExec => p }
assert(project.size == 1)
val q =
spark.read
.parquet(dir.getCanonicalPath)
.select(
input_file_name(),
expr("input_file_block_start()"),
expr("input_file_block_length()"))
val firstRow = q.head()
assert(firstRow.getString(0).contains(dir.toURI.getPath))
assert(firstRow.getLong(1) == 0)
assert(firstRow.getLong(2) > 0)
val project = q.queryExecution.executedPlan.collect { case p: ProjectExec => p }
assert(project.size == 1)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,32 @@ package org.apache.spark.sql

import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.functions.{expr, input_file_name}
import org.apache.spark.sql.types._

class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait {
testGluten("input_file_name with scan is fallback") {
withTempPath {
dir =>
val rawData = Seq(
Row(1, "Alice", Seq(Row(Seq(1, 2, 3)))),
Row(2, "Bob", Seq(Row(Seq(4, 5)))),
Row(3, "Charlie", Seq(Row(Seq(6, 7, 8, 9))))
)
val schema = StructType(
Array(
StructField("id", IntegerType, nullable = false),
StructField("name", StringType, nullable = false),
StructField(
"nested_column",
ArrayType(
StructType(Array(
StructField("array_in_struct", ArrayType(IntegerType), nullable = true)
))),
nullable = true)
))
val data: DataFrame = spark.createDataFrame(sparkContext.parallelize(rawData), schema)
data.write.parquet(dir.getCanonicalPath)
import testImplicits._
testGluten(
"input_file_name, input_file_block_start and input_file_block_length " +
"should fall back if scan falls back") {
withSQLConf(("spark.gluten.sql.columnar.filescan", "false")) {
withTempPath {
dir =>
val data = sparkContext.parallelize(0 to 10).toDF("id")
data.write.parquet(dir.getCanonicalPath)

val q =
spark.read.parquet(dir.getCanonicalPath).select(input_file_name(), expr("nested_column"))
val firstRow = q.head()
assert(firstRow.getString(0).contains(dir.toURI.getPath))
val project = q.queryExecution.executedPlan.collect { case p: ProjectExec => p }
assert(project.size == 1)
val q =
spark.read
.parquet(dir.getCanonicalPath)
.select(
input_file_name(),
expr("input_file_block_start()"),
expr("input_file_block_length()"))
val firstRow = q.head()
assert(firstRow.getString(0).contains(dir.toURI.getPath))
assert(firstRow.getLong(1) == 0)
assert(firstRow.getLong(2) > 0)
val project = q.queryExecution.executedPlan.collect { case p: ProjectExec => p }
assert(project.size == 1)
}
}
}
}

0 comments on commit 8e0d56e

Please sign in to comment.