diff --git a/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenSQLTestsBaseTrait.scala b/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenSQLTestsBaseTrait.scala index 8c55b823a06c..4c06b02a1fb4 100644 --- a/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenSQLTestsBaseTrait.scala +++ b/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenSQLTestsBaseTrait.scala @@ -20,36 +20,13 @@ import org.apache.gluten.GlutenConfig import org.apache.gluten.utils.{BackendTestUtils, SystemParameters} import org.apache.spark.SparkConf -import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, ShuffleQueryStageExec} import org.apache.spark.sql.test.SharedSparkSession -import org.scalactic.source.Position -import org.scalatest.Tag - /** Basic trait for Gluten SQL test cases. */ trait GlutenSQLTestsBaseTrait extends SharedSparkSession with GlutenTestsBaseTrait { - protected def testGluten(testName: String, testTag: Tag*)(testFun: => Any)(implicit - pos: Position): Unit = { - test(GLUTEN_TEST + testName, testTag: _*)(testFun) - } - - protected def ignoreGluten(testName: String, testTag: Tag*)(testFun: => Any)(implicit - pos: Position): Unit = { - super.ignore(GLUTEN_TEST + testName, testTag: _*)(testFun) - } - - override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit - pos: Position): Unit = { - if (shouldRun(testName)) { - super.test(testName, testTags: _*)(testFun) - } else { - super.ignore(testName, testTags: _*)(testFun) - } - } - override def sparkConf: SparkConf = { GlutenSQLTestsBaseTrait.nativeSparkConf(super.sparkConf, warehouse) } diff --git a/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenTestsBaseTrait.scala b/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenTestsBaseTrait.scala index 7c6dcbbee83d..a0ab97306166 100644 --- a/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenTestsBaseTrait.scala +++ b/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenTestsBaseTrait.scala @@ -18,7 +18,13 @@ package org.apache.spark.sql import org.apache.gluten.utils.BackendTestSettings -trait GlutenTestsBaseTrait { +import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST + +import org.scalactic.source.Position +import org.scalatest.Tag +import org.scalatest.funsuite.AnyFunSuiteLike + +trait GlutenTestsBaseTrait extends AnyFunSuiteLike { protected val rootPath: String = getClass.getResource("/").getPath protected val basePath: String = rootPath + "unit-tests-working-home" @@ -30,7 +36,7 @@ trait GlutenTestsBaseTrait { // list will never be run with no regard to backend test settings. def testNameBlackList: Seq[String] = Seq() - def shouldRun(testName: String): Boolean = { + protected def shouldRun(testName: String): Boolean = { if (testNameBlackList.exists(_.equalsIgnoreCase(GlutenTestConstants.IGNORE_ALL))) { return false } @@ -39,4 +45,24 @@ trait GlutenTestsBaseTrait { } BackendTestSettings.shouldRun(getClass.getCanonicalName, testName) } + + protected def testGluten(testName: String, testTag: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + test(GLUTEN_TEST + testName, testTag: _*)(testFun) + } + + protected def ignoreGluten(testName: String, testTag: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + super.ignore(GLUTEN_TEST + testName, testTag: _*)(testFun) + } + + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + if (shouldRun(testName)) { + super.test(testName, testTags: _*)(testFun) + } else { + super.ignore(testName, testTags: _*)(testFun) + } + } + } diff --git a/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenTestsCommonTrait.scala b/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenTestsCommonTrait.scala index 06b9fca67bf7..b9ee199eb1af 100644 --- a/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenTestsCommonTrait.scala +++ b/gluten-ut/common/src/test/scala/org/apache/spark/sql/GlutenTestsCommonTrait.scala @@ -19,11 +19,9 @@ package org.apache.spark.sql import org.apache.gluten.test.TestStats import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST import org.apache.spark.sql.catalyst.expressions._ -import org.scalactic.source.Position -import org.scalatest.{Args, Status, Tag} +import org.scalatest.{Args, Status} trait GlutenTestsCommonTrait extends SparkFunSuite @@ -48,23 +46,4 @@ trait GlutenTestsCommonTrait TestStats.endCase(status.succeeds()); status } - - protected def testGluten(testName: String, testTag: Tag*)(testFun: => Any)(implicit - pos: Position): Unit = { - test(GLUTEN_TEST + testName, testTag: _*)(testFun) - } - - protected def ignoreGluten(testName: String, testTag: Tag*)(testFun: => Any)(implicit - pos: Position): Unit = { - super.ignore(GLUTEN_TEST + testName, testTag: _*)(testFun) - } - - override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit - pos: Position): Unit = { - if (shouldRun(testName)) { - super.test(testName, testTags: _*)(testFun) - } else { - super.ignore(testName, testTags: _*)(testFun) - } - } } diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala index 338d7992e38d..70579c886248 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala @@ -16,7 +16,9 @@ */ package org.apache.spark -class GlutenSortShuffleSuite extends SortShuffleSuite { +import org.apache.spark.sql.GlutenTestsBaseTrait + +class GlutenSortShuffleSuite extends SortShuffleSuite with GlutenTestsBaseTrait { override def beforeAll(): Unit = { super.beforeAll() conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala index 338d7992e38d..70579c886248 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala @@ -16,7 +16,9 @@ */ package org.apache.spark -class GlutenSortShuffleSuite extends SortShuffleSuite { +import org.apache.spark.sql.GlutenTestsBaseTrait + +class GlutenSortShuffleSuite extends SortShuffleSuite with GlutenTestsBaseTrait { override def beforeAll(): Unit = { super.beforeAll() conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala index 338d7992e38d..70579c886248 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala @@ -16,7 +16,9 @@ */ package org.apache.spark -class GlutenSortShuffleSuite extends SortShuffleSuite { +import org.apache.spark.sql.GlutenTestsBaseTrait + +class GlutenSortShuffleSuite extends SortShuffleSuite with GlutenTestsBaseTrait { override def beforeAll(): Unit = { super.beforeAll() conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala index 338d7992e38d..70579c886248 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/GlutenSortShuffleSuite.scala @@ -16,7 +16,9 @@ */ package org.apache.spark -class GlutenSortShuffleSuite extends SortShuffleSuite { +import org.apache.spark.sql.GlutenTestsBaseTrait + +class GlutenSortShuffleSuite extends SortShuffleSuite with GlutenTestsBaseTrait { override def beforeAll(): Unit = { super.beforeAll() conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")