Commit

[VL] Code refactoring of variable names related to TPCH in Velox UTs (a…

liujiayi771 authored Dec 28, 2023
1 parent 7e7c2a0 commit 7b31df1
Showing 9 changed files with 46 additions and 66 deletions.
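In short, the commit disentangles two similarly named members of WholeStageTransformerSuite: the static sequence of TPC-H table definitions (previously TPCHTable) and the name-to-DataFrame map populated when those tables are loaded (previously TPCHTables). A minimal before/after sketch, with the types as they appear in the hunks below:

    // Before
    protected val TPCHTable: Seq[Table] = ...              // static table definitions
    protected var TPCHTables: Map[String, DataFrame] = _   // loaded data, keyed by table name

    // After
    protected val TPCHTables: Seq[Table] = ...
    protected var TPCHTableDataFrames: Map[String, DataFrame] = _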
@@ -57,8 +57,8 @@ class VeloxColumnarCacheSuite extends VeloxWholeStageTransformerSuite with Adapt
   }
 
   test("input columnar batch") {
-    TPCHTables.foreach {
-      case (table, _) =>
+    TPCHTables.map(_.name).foreach {
+      table =>
         runQueryAndCompare(s"SELECT * FROM $table", cache = true) {
           df => checkColumnarTableCache(df.queryExecution.executedPlan)
         }
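Because TPCHTables now holds Table definitions rather than (name, DataFrame) pairs, call sites that only need names go through .map(_.name). A hedged, standalone sketch of the shape involved (the real Table lives in Gluten's test utilities; the fields here are inferred from its usage in the hunks below):

    // Assumed shape of the table-definition class:
    case class Table(name: String, partitionColumns: Seq[String])

    val TPCHTables: Seq[Table] = Seq(
      Table("part", partitionColumns = "p_brand" :: Nil),
      Table("region", partitionColumns = Nil))

    // Iterate names only, as the updated test does:
    TPCHTables.map(_.name).foreach(table => println(s"SELECT * FROM $table"))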
@@ -21,8 +21,6 @@ import io.glutenproject.sql.shims.SparkShimLoader
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.execution.InputIteratorTransformer
 
-import java.io.File
-
 class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
   override protected val backend: String = "velox"
   override protected val resourcePath: String = "/tpch-data-parquet-velox"
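(java.io.File was only used by the createTPCHNotNullTables override deleted in the next hunk, so the import goes away with it.)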
@@ -35,59 +33,38 @@ class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite
   override protected def sparkConf: SparkConf = super.sparkConf
     .set("spark.unsafe.exceptionOnMemoryLeak", "true")
 
-  override protected def createTPCHNotNullTables(): Unit = {
-    Seq("customer", "lineitem", "nation", "orders", "part", "partsupp", "region", "supplier")
-      .foreach {
-        table =>
-          val tableDir = getClass.getResource(resourcePath).getFile
-          val tablePath = new File(tableDir, table).getAbsolutePath
-          val tableDF = spark.read.format(fileFormat).load(tablePath)
-          tableDF.createOrReplaceTempView(table)
-      }
-  }
-
   test("generate hash join plan - v1") {
     withSQLConf(
       ("spark.sql.autoBroadcastJoinThreshold", "-1"),
       ("spark.sql.adaptive.enabled", "false"),
       ("spark.gluten.sql.columnar.forceShuffledHashJoin", "true")) {
-      withTable(
-        "customer",
-        "lineitem",
-        "nation",
-        "orders",
-        "part",
-        "partsupp",
-        "region",
-        "supplier") {
-        createTPCHNotNullTables()
-        val df = spark.sql("""select l_partkey from
-                             | lineitem join part join partsupp
-                             | on l_partkey = p_partkey
-                             | and l_suppkey = ps_suppkey""".stripMargin)
-        val plan = df.queryExecution.executedPlan
-        val joins = plan.collect { case shj: ShuffledHashJoinExecTransformer => shj }
-        // scalastyle:off println
-        System.out.println(plan)
-        // scalastyle:on println
-        assert(joins.length == 2)
+      createTPCHNotNullTables()
+      val df = spark.sql("""select l_partkey from
+                           | lineitem join part join partsupp
+                           | on l_partkey = p_partkey
+                           | and l_suppkey = ps_suppkey""".stripMargin)
+      val plan = df.queryExecution.executedPlan
+      val joins = plan.collect { case shj: ShuffledHashJoinExecTransformer => shj }
+      // scalastyle:off println
+      System.out.println(plan)
+      // scalastyle:on println
+      assert(joins.length == 2)
 
-        // Children of Join should be separated into different `TransformContext`s.
-        assert(joins.forall(_.children.forall(_.isInstanceOf[InputIteratorTransformer])))
+      // Children of Join should be separated into different `TransformContext`s.
+      assert(joins.forall(_.children.forall(_.isInstanceOf[InputIteratorTransformer])))
 
-        // WholeStageTransformer should be inserted for joins and their children separately.
-        val wholeStages = plan.collect { case wst: WholeStageTransformer => wst }
-        assert(wholeStages.length == 5)
+      // WholeStageTransformer should be inserted for joins and their children separately.
+      val wholeStages = plan.collect { case wst: WholeStageTransformer => wst }
+      assert(wholeStages.length == 5)
 
-        // Join should be in `TransformContext`
-        val countSHJ = wholeStages.map {
-          _.collectFirst {
-            case _: InputIteratorTransformer => 0
-            case _: ShuffledHashJoinExecTransformer => 1
-          }.getOrElse(0)
-        }.sum
-        assert(countSHJ == 2)
-      }
+      // Join should be in `TransformContext`
+      val countSHJ = wholeStages.map {
+        _.collectFirst {
+          case _: InputIteratorTransformer => 0
+          case _: ShuffledHashJoinExecTransformer => 1
+        }.getOrElse(0)
+      }.sum
+      assert(countSHJ == 2)
     }
   }
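The countSHJ assertion leans on collectFirst doing a pre-order traversal: in a stage that contains the join, the ShuffledHashJoinExecTransformer sits above its InputIteratorTransformer children and is matched first (yielding 1), while a stage that merely feeds input matches InputIteratorTransformer at its head (yielding 0). A minimal self-contained illustration of the idiom, using toy node types rather than Gluten's:

    // Toy stand-ins for Spark plan operators (hypothetical, not Gluten's classes):
    sealed trait Node {
      def children: Seq[Node] = Nil
      // Pre-order search: test this node first, then descend left to right.
      def collectFirst[T](pf: PartialFunction[Node, T]): Option[T] =
        pf.lift(this).orElse(children.flatMap(_.collectFirst(pf)).headOption)
    }
    case object Scan extends Node
    case class InputIterator(override val children: Seq[Node] = Nil) extends Node
    case class HashJoin(override val children: Seq[Node] = Nil) extends Node

    // One stage headed by a join over input iterators, one plain input pipeline.
    val stages = Seq(HashJoin(Seq(InputIterator(), InputIterator())), InputIterator(Seq(Scan)))
    val countSHJ = stages.map {
      _.collectFirst {
        case _: InputIterator => 0
        case _: HashJoin => 1
      }.getOrElse(0)
    }.sum
    assert(countSHJ == 1) // only the first stage has the join heading its context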
@@ -32,7 +32,7 @@ class VeloxOrcDataTypeValidationSuite extends VeloxWholeStageTransformerSuite {
   }
 
   protected def createDataTypeTable(): Unit = {
-    TPCHTables = Seq(
+    TPCHTableDataFrames = Seq(
       "type1",
       "type2"
     ).map {
@@ -32,7 +32,7 @@ class VeloxParquetDataTypeValidationSuite extends VeloxWholeStageTransformerSuit
   }
 
   protected def createDataTypeTable(): Unit = {
-    TPCHTables = Seq(
+    TPCHTableDataFrames = Seq(
       "type1",
       "type2"
     ).map {
@@ -244,7 +244,7 @@ class VeloxTPCHV2BhjSuite extends VeloxTPCHSuite {
 
 class VeloxPartitionedTableTPCHSuite extends VeloxTPCHSuite {
   override protected def createTPCHNotNullTables(): Unit = {
-    TPCHTables = TPCHTable.map {
+    TPCHTableDataFrames = TPCHTables.map {
       table =>
         val tableDir = getClass.getResource(resourcePath).getFile
         val tablePath = new File(tableDir, table.name).getAbsolutePath
@@ -260,8 +260,8 @@ class VeloxPartitionedTableTPCHSuite extends VeloxTPCHSuite
   }
 
   override protected def afterAll(): Unit = {
-    if (TPCHTables != null) {
-      TPCHTables.keys.foreach(v => spark.sql(s"DROP TABLE IF EXISTS $v"))
+    if (TPCHTableDataFrames != null) {
+      TPCHTableDataFrames.keys.foreach(table => spark.sql(s"DROP TABLE IF EXISTS $table"))
     }
     super.afterAll()
   }
@@ -50,7 +50,7 @@ class VeloxParquetWriteSuite extends VeloxWholeStageTransformerSuite {
         case _ => codec
       }
 
-      TPCHTables.foreach {
+      TPCHTableDataFrames.foreach {
         case (_, df) =>
           withTempPath {
             f =>
@@ -41,7 +41,7 @@ abstract class WholeStageTransformerSuite extends GlutenQueryTest with SharedSpa
   protected val fileFormat: String
   protected val logLevel: String = "WARN"
 
-  protected val TPCHTable = Seq(
+  protected val TPCHTables = Seq(
     Table("part", partitionColumns = "p_brand" :: Nil),
     Table("supplier", partitionColumns = Nil),
     Table("partsupp", partitionColumns = Nil),
@@ -52,7 +52,7 @@ abstract class WholeStageTransformerSuite extends GlutenQueryTest with SharedSpa
     Table("region", partitionColumns = Nil)
   )
 
-  protected var TPCHTables: Map[String, DataFrame] = _
+  protected var TPCHTableDataFrames: Map[String, DataFrame] = _
 
   private val isFallbackCheckDisabled0 = new AtomicBoolean(false)
 
@@ -67,14 +67,14 @@ abstract class WholeStageTransformerSuite extends GlutenQueryTest with SharedSpa
   }
 
   override protected def afterAll(): Unit = {
-    if (TPCHTables != null) {
-      TPCHTables.keys.foreach(v => spark.sessionState.catalog.dropTempView(v))
+    if (TPCHTableDataFrames != null) {
+      TPCHTableDataFrames.keys.foreach(v => spark.sessionState.catalog.dropTempView(v))
     }
     super.afterAll()
   }
 
   protected def createTPCHNotNullTables(): Unit = {
-    TPCHTables = TPCHTable
+    TPCHTableDataFrames = TPCHTables
       .map(_.name)
       .map {
         table =>
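Taken together, the base suite now separates the immutable table definitions from the mutable per-run state. A sketch of how the two members interact after this commit; the body of createTPCHNotNullTables is only partially visible in the diff, so the load-and-register steps below follow the pattern of the override removed from VeloxHashJoinSuite and are illustrative rather than verbatim:

    import java.io.File
    import org.apache.spark.sql.DataFrame

    // Inside WholeStageTransformerSuite:
    protected val TPCHTables: Seq[Table] =
      Seq(Table("part", partitionColumns = "p_brand" :: Nil) /* ... seven more ... */ )

    protected var TPCHTableDataFrames: Map[String, DataFrame] = _

    protected def createTPCHNotNullTables(): Unit = {
      TPCHTableDataFrames = TPCHTables
        .map(_.name)
        .map { table =>
          // Illustrative: load each table's files and register a temp view,
          // keeping the DataFrame keyed by table name for later cleanup.
          val tableDir = getClass.getResource(resourcePath).getFile
          val tablePath = new File(tableDir, table).getAbsolutePath
          val tableDF = spark.read.format(fileFormat).load(tablePath)
          tableDF.createOrReplaceTempView(table)
          table -> tableDF
        }
        .toMap
    }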
@@ -41,7 +41,7 @@ class VeloxTPCHDeltaSuite extends VeloxTPCHSuite {
   }
 
   override protected def createTPCHNotNullTables(): Unit = {
-    TPCHTables = TPCHTable
+    TPCHTables
       .map(_.name)
       .map {
         table =>
@@ -52,4 +52,9 @@
       }
       .toMap
   }
+
+  override protected def afterAll(): Unit = {
+    TPCHTables.map(_.name).foreach(table => spark.sql(s"DROP TABLE IF EXISTS $table"))
+    super.afterAll()
+  }
 }
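Note the cleanup split that emerges: the base suite registers temp views and clears them via dropTempView in afterAll, whereas the Delta suite above (and the Iceberg suite below) evidently create catalog tables, so their afterAll overrides run DROP TABLE IF EXISTS instead.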
@@ -47,7 +47,7 @@ class VeloxTPCHIcebergSuite extends VeloxTPCHSuite {
   }
 
   override protected def createTPCHNotNullTables(): Unit = {
-    TPCHTables = TPCHTable
+    TPCHTables
       .map(_.name)
       .map {
         table =>
@@ -60,9 +60,7 @@ class VeloxTPCHIcebergSuite extends VeloxTPCHSuite {
   }
 
   override protected def afterAll(): Unit = {
-    if (TPCHTables != null) {
-      TPCHTables.keys.foreach(v => spark.sql(s"DROP TABLE IF EXISTS $v"))
-    }
+    TPCHTables.map(_.name).foreach(table => spark.sql(s"DROP TABLE IF EXISTS $table"))
     super.afterAll()
   }
 
@@ -96,7 +94,7 @@
 
 class VeloxPartitionedTableTPCHIcebergSuite extends VeloxTPCHIcebergSuite {
   override protected def createTPCHNotNullTables(): Unit = {
-    TPCHTables = TPCHTable.map {
+    TPCHTables.map {
       table =>
         val tablePath = new File(resourcePath, table.name).getAbsolutePath
         val tableDF = spark.read.format(fileFormat).load(tablePath)
