diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxColumnarCacheSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxColumnarCacheSuite.scala index cc94fd006841..224f9fc63642 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxColumnarCacheSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxColumnarCacheSuite.scala @@ -57,8 +57,8 @@ class VeloxColumnarCacheSuite extends VeloxWholeStageTransformerSuite with Adapt } test("input columnar batch") { - TPCHTables.foreach { - case (table, _) => + TPCHTables.map(_.name).foreach { + table => runQueryAndCompare(s"SELECT * FROM $table", cache = true) { df => checkColumnarTableCache(df.queryExecution.executedPlan) } diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxHashJoinSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxHashJoinSuite.scala index f543c92f02bf..e2cc4dc0b9c7 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxHashJoinSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxHashJoinSuite.scala @@ -21,8 +21,6 @@ import io.glutenproject.sql.shims.SparkShimLoader import org.apache.spark.SparkConf import org.apache.spark.sql.execution.InputIteratorTransformer -import java.io.File - class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite { override protected val backend: String = "velox" override protected val resourcePath: String = "/tpch-data-parquet-velox" @@ -35,59 +33,38 @@ class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite { override protected def sparkConf: SparkConf = super.sparkConf .set("spark.unsafe.exceptionOnMemoryLeak", "true") - override protected def createTPCHNotNullTables(): Unit = { - Seq("customer", "lineitem", "nation", "orders", "part", "partsupp", "region", "supplier") - .foreach { - table => - val tableDir = getClass.getResource(resourcePath).getFile - val tablePath = new File(tableDir, table).getAbsolutePath - val tableDF = spark.read.format(fileFormat).load(tablePath) - tableDF.createOrReplaceTempView(table) - } - } - test("generate hash join plan - v1") { withSQLConf( ("spark.sql.autoBroadcastJoinThreshold", "-1"), ("spark.sql.adaptive.enabled", "false"), ("spark.gluten.sql.columnar.forceShuffledHashJoin", "true")) { - withTable( - "customer", - "lineitem", - "nation", - "orders", - "part", - "partsupp", - "region", - "supplier") { - createTPCHNotNullTables() - val df = spark.sql("""select l_partkey from - | lineitem join part join partsupp - | on l_partkey = p_partkey - | and l_suppkey = ps_suppkey""".stripMargin) - val plan = df.queryExecution.executedPlan - val joins = plan.collect { case shj: ShuffledHashJoinExecTransformer => shj } - // scalastyle:off println - System.out.println(plan) - // scalastyle:on println line=68 column=19 - assert(joins.length == 2) + createTPCHNotNullTables() + val df = spark.sql("""select l_partkey from + | lineitem join part join partsupp + | on l_partkey = p_partkey + | and l_suppkey = ps_suppkey""".stripMargin) + val plan = df.queryExecution.executedPlan + val joins = plan.collect { case shj: ShuffledHashJoinExecTransformer => shj } + // scalastyle:off println + System.out.println(plan) + // scalastyle:on println line=68 column=19 + assert(joins.length == 2) - // Children of Join should be seperated into different `TransformContext`s. - assert(joins.forall(_.children.forall(_.isInstanceOf[InputIteratorTransformer]))) + // Children of Join should be seperated into different `TransformContext`s. + assert(joins.forall(_.children.forall(_.isInstanceOf[InputIteratorTransformer]))) - // WholeStageTransformer should be inserted for joins and its children separately. - val wholeStages = plan.collect { case wst: WholeStageTransformer => wst } - assert(wholeStages.length == 5) + // WholeStageTransformer should be inserted for joins and its children separately. + val wholeStages = plan.collect { case wst: WholeStageTransformer => wst } + assert(wholeStages.length == 5) - // Join should be in `TransformContext` - val countSHJ = wholeStages.map { - _.collectFirst { - case _: InputIteratorTransformer => 0 - case _: ShuffledHashJoinExecTransformer => 1 - }.getOrElse(0) - }.sum - assert(countSHJ == 2) - } + // Join should be in `TransformContext` + val countSHJ = wholeStages.map { + _.collectFirst { + case _: InputIteratorTransformer => 0 + case _: ShuffledHashJoinExecTransformer => 1 + }.getOrElse(0) + }.sum + assert(countSHJ == 2) } } diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxOrcDataTypeValidationSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxOrcDataTypeValidationSuite.scala index 2fa9db65af66..4c8a13a51fc5 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxOrcDataTypeValidationSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxOrcDataTypeValidationSuite.scala @@ -32,7 +32,7 @@ class VeloxOrcDataTypeValidationSuite extends VeloxWholeStageTransformerSuite { } protected def createDataTypeTable(): Unit = { - TPCHTables = Seq( + TPCHTableDataFrames = Seq( "type1", "type2" ).map { diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxParquetDataTypeValidationSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxParquetDataTypeValidationSuite.scala index e30fd6b7b643..8cd5bcdb8cc8 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxParquetDataTypeValidationSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxParquetDataTypeValidationSuite.scala @@ -32,7 +32,7 @@ class VeloxParquetDataTypeValidationSuite extends VeloxWholeStageTransformerSuit } protected def createDataTypeTable(): Unit = { - TPCHTables = Seq( + TPCHTableDataFrames = Seq( "type1", "type2" ).map { diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxTPCHSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxTPCHSuite.scala index 37dad8412c2b..fdf06f9d1863 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxTPCHSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxTPCHSuite.scala @@ -244,7 +244,7 @@ class VeloxTPCHV2BhjSuite extends VeloxTPCHSuite { class VeloxPartitionedTableTPCHSuite extends VeloxTPCHSuite { override protected def createTPCHNotNullTables(): Unit = { - TPCHTables = TPCHTable.map { + TPCHTableDataFrames = TPCHTables.map { table => val tableDir = getClass.getResource(resourcePath).getFile val tablePath = new File(tableDir, table.name).getAbsolutePath @@ -260,8 +260,8 @@ class VeloxPartitionedTableTPCHSuite extends VeloxTPCHSuite { } override protected def afterAll(): Unit = { - if (TPCHTables != null) { - TPCHTables.keys.foreach(v => spark.sql(s"DROP TABLE IF EXISTS $v")) + if (TPCHTableDataFrames != null) { + TPCHTableDataFrames.keys.foreach(table => spark.sql(s"DROP TABLE IF EXISTS $table")) } super.afterAll() } diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/VeloxParquetWriteSuite.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/VeloxParquetWriteSuite.scala index 3e37d6c2c714..57711152ad22 100644 --- a/backends-velox/src/test/scala/org/apache/spark/sql/execution/VeloxParquetWriteSuite.scala +++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/VeloxParquetWriteSuite.scala @@ -50,7 +50,7 @@ class VeloxParquetWriteSuite extends VeloxWholeStageTransformerSuite { case _ => codec } - TPCHTables.foreach { + TPCHTableDataFrames.foreach { case (_, df) => withTempPath { f => diff --git a/gluten-core/src/test/scala/io/glutenproject/execution/WholeStageTransformerSuite.scala b/gluten-core/src/test/scala/io/glutenproject/execution/WholeStageTransformerSuite.scala index b720608f26fa..9f4653451dfb 100644 --- a/gluten-core/src/test/scala/io/glutenproject/execution/WholeStageTransformerSuite.scala +++ b/gluten-core/src/test/scala/io/glutenproject/execution/WholeStageTransformerSuite.scala @@ -41,7 +41,7 @@ abstract class WholeStageTransformerSuite extends GlutenQueryTest with SharedSpa protected val fileFormat: String protected val logLevel: String = "WARN" - protected val TPCHTable = Seq( + protected val TPCHTables = Seq( Table("part", partitionColumns = "p_brand" :: Nil), Table("supplier", partitionColumns = Nil), Table("partsupp", partitionColumns = Nil), @@ -52,7 +52,7 @@ abstract class WholeStageTransformerSuite extends GlutenQueryTest with SharedSpa Table("region", partitionColumns = Nil) ) - protected var TPCHTables: Map[String, DataFrame] = _ + protected var TPCHTableDataFrames: Map[String, DataFrame] = _ private val isFallbackCheckDisabled0 = new AtomicBoolean(false) @@ -67,14 +67,14 @@ abstract class WholeStageTransformerSuite extends GlutenQueryTest with SharedSpa } override protected def afterAll(): Unit = { - if (TPCHTables != null) { - TPCHTables.keys.foreach(v => spark.sessionState.catalog.dropTempView(v)) + if (TPCHTableDataFrames != null) { + TPCHTableDataFrames.keys.foreach(v => spark.sessionState.catalog.dropTempView(v)) } super.afterAll() } protected def createTPCHNotNullTables(): Unit = { - TPCHTables = TPCHTable + TPCHTableDataFrames = TPCHTables .map(_.name) .map { table => diff --git a/gluten-delta/src/test/scala/io/glutenproject/execution/VeloxTPCHDeltaSuite.scala b/gluten-delta/src/test/scala/io/glutenproject/execution/VeloxTPCHDeltaSuite.scala index 2c7fce454122..0a6e3083ae6b 100644 --- a/gluten-delta/src/test/scala/io/glutenproject/execution/VeloxTPCHDeltaSuite.scala +++ b/gluten-delta/src/test/scala/io/glutenproject/execution/VeloxTPCHDeltaSuite.scala @@ -41,7 +41,7 @@ class VeloxTPCHDeltaSuite extends VeloxTPCHSuite { } override protected def createTPCHNotNullTables(): Unit = { - TPCHTables = TPCHTable + TPCHTables .map(_.name) .map { table => @@ -52,4 +52,9 @@ class VeloxTPCHDeltaSuite extends VeloxTPCHSuite { } .toMap } + + override protected def afterAll(): Unit = { + TPCHTables.map(_.name).foreach(table => spark.sql(s"DROP TABLE IF EXISTS $table")) + super.afterAll() + } } diff --git a/gluten-iceberg/src/test/scala/io/glutenproject/execution/VeloxTPCHIcebergSuite.scala b/gluten-iceberg/src/test/scala/io/glutenproject/execution/VeloxTPCHIcebergSuite.scala index 1c87851c311e..6d4c0e566311 100644 --- a/gluten-iceberg/src/test/scala/io/glutenproject/execution/VeloxTPCHIcebergSuite.scala +++ b/gluten-iceberg/src/test/scala/io/glutenproject/execution/VeloxTPCHIcebergSuite.scala @@ -47,7 +47,7 @@ class VeloxTPCHIcebergSuite extends VeloxTPCHSuite { } override protected def createTPCHNotNullTables(): Unit = { - TPCHTables = TPCHTable + TPCHTables .map(_.name) .map { table => @@ -60,9 +60,7 @@ class VeloxTPCHIcebergSuite extends VeloxTPCHSuite { } override protected def afterAll(): Unit = { - if (TPCHTables != null) { - TPCHTables.keys.foreach(v => spark.sql(s"DROP TABLE IF EXISTS $v")) - } + TPCHTables.map(_.name).foreach(table => spark.sql(s"DROP TABLE IF EXISTS $table")) super.afterAll() } @@ -96,7 +94,7 @@ class VeloxTPCHIcebergSuite extends VeloxTPCHSuite { class VeloxPartitionedTableTPCHIcebergSuite extends VeloxTPCHIcebergSuite { override protected def createTPCHNotNullTables(): Unit = { - TPCHTables = TPCHTable.map { + TPCHTables.map { table => val tablePath = new File(resourcePath, table.name).getAbsolutePath val tableDF = spark.read.format(fileFormat).load(tablePath)