Skip to content

Commit

Permalink
pass spark3.5
Browse files Browse the repository at this point in the history
  • Loading branch information
baibaichen committed Jul 30, 2024
1 parent 3190bc9 commit 2ff0fb5
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ class GlutenClickHouseDecimalSuite
private val decimalTable: String = "decimal_table"
private val decimalTPCHTables: Seq[(DecimalType, Seq[Int])] = Seq.apply(
(DecimalType.apply(9, 4), Seq()),
// 1: ch decimal avg is float
(DecimalType.apply(18, 8), Seq()),
// 1: ch decimal avg is float, 3/10: all value is null and compare with limit
(DecimalType.apply(38, 19), Seq(3, 10))
// 3/10: all value is null and compare with limit
// 1 Spark 3.5
(DecimalType.apply(38, 19), if (isSparkVersionLE("3.3")) Seq(3, 10) else Seq(1, 3, 10))
)

private def createDecimalTables(dataType: DecimalType): Unit = {
Expand Down Expand Up @@ -343,27 +343,22 @@ class GlutenClickHouseDecimalSuite
decimalTPCHTables.foreach {
dt =>
{
val fallBack = (sql_num == 16 || sql_num == 21)
val compareResult = !dt._2.contains(sql_num)
val native = if (fallBack) "fallback" else "native"
val compare = if (compareResult) "compare" else "noCompare"
val PrecisionLoss = s"allowPrecisionLoss=$allowPrecisionLoss"
val decimalType = dt._1
test(s"""TPCH Decimal(${decimalType.precision},${decimalType.scale})
| Q$sql_num[allowPrecisionLoss=$allowPrecisionLoss]""".stripMargin) {
var noFallBack = true
var compareResult = true
if (sql_num == 16 || sql_num == 21) {
noFallBack = false
}

if (dt._2.contains(sql_num)) {
compareResult = false
}

| Q$sql_num[$PrecisionLoss,$native,$compare]""".stripMargin) {
spark.sql(s"use decimal_${decimalType.precision}_${decimalType.scale}")
withSQLConf(
(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key, allowPrecisionLoss)) {
runTPCHQuery(
sql_num,
tpchQueries,
compareResult = compareResult,
noFallBack = noFallBack) { _ => {} }
noFallBack = !fallBack) { _ => {} }
}
spark.sql(s"use default")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1051,8 +1051,12 @@ class GlutenClickHouseHiveTableSuite
spark.sql(
s"CREATE FUNCTION my_add as " +
s"'org.apache.hadoop.hive.contrib.udf.example.UDFExampleAdd2' USING JAR '$jarUrl'")
runQueryAndCompare("select MY_ADD(id, id+1) from range(10)")(
checkGlutenOperatorMatch[ProjectExecTransformer])
if (isSparkVersionLE("3.3")) {
runQueryAndCompare("select MY_ADD(id, id+1) from range(10)")(
checkGlutenOperatorMatch[ProjectExecTransformer])
} else {
runQueryAndCompare("select MY_ADD(id, id+1) from range(10)", noFallBack = false)(_ => {})
}
}

test("GLUTEN-4333: fix CSE in aggregate operator") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,11 @@ class GlutenClickHouseTPCHBucketSuite
val plans = collect(df.queryExecution.executedPlan) {
case scanExec: BasicScanExecTransformer => scanExec
}
assert(!(plans(0).asInstanceOf[FileSourceScanExecTransformer].bucketedScan))
assert(plans(0).metrics("numFiles").value === 2)
assert(plans(0).metrics("pruningTime").value === -1)
assert(plans(0).metrics("numOutputRows").value === 591673)
assert(!plans.head.asInstanceOf[FileSourceScanExecTransformer].bucketedScan)
assert(plans.head.metrics("numFiles").value === 2)
val pruningTimeValue = if (isSparkVersionGE("3.4")) 0 else -1
assert(plans.head.metrics("pruningTime").value === pruningTimeValue)
assert(plans.head.metrics("numOutputRows").value === 591673)
})
}

Expand Down Expand Up @@ -409,10 +410,11 @@ class GlutenClickHouseTPCHBucketSuite
val plans = collect(df.queryExecution.executedPlan) {
case scanExec: BasicScanExecTransformer => scanExec
}
assert(!(plans(0).asInstanceOf[FileSourceScanExecTransformer].bucketedScan))
assert(plans(0).metrics("numFiles").value === 2)
assert(plans(0).metrics("pruningTime").value === -1)
assert(plans(0).metrics("numOutputRows").value === 11618)
assert(!plans.head.asInstanceOf[FileSourceScanExecTransformer].bucketedScan)
assert(plans.head.metrics("numFiles").value === 2)
val pruningTimeValue = if (isSparkVersionGE("3.4")) 0 else -1
assert(plans.head.metrics("pruningTime").value === pruningTimeValue)
assert(plans.head.metrics("numOutputRows").value === 11618)
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class GlutenClickhouseCountDistinctSuite extends GlutenClickHouseWholeStageTrans
values (0, null,1), (0,null,2), (1, 1,4) as data(a,b,c) group by try_add(c,b)
""";
val df = spark.sql(sql)
WholeStageTransformerSuite.checkFallBack(df, noFallback = false)
WholeStageTransformerSuite.checkFallBack(df, noFallback = isSparkVersionGE("3.4"))
}

test("check count distinct with filter") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,16 @@ class GlutenClickHouseTPCHMetricsSuite extends GlutenClickHouseTPCHAbstractSuite
assert(plans.size == 3)

assert(plans(2).metrics("numFiles").value === 1)
assert(plans(2).metrics("pruningTime").value === -1)
val pruningTimeValue = if (isSparkVersionGE("3.4")) 0 else -1
assert(plans(2).metrics("pruningTime").value === pruningTimeValue)
assert(plans(2).metrics("filesSize").value === 19230111)

assert(plans(1).metrics("numOutputRows").value === 4)
assert(plans(1).metrics("outputVectors").value === 1)

// Execute Sort operator, it will read the data twice.
assert(plans(0).metrics("numOutputRows").value === 4)
assert(plans(0).metrics("outputVectors").value === 1)
assert(plans.head.metrics("numOutputRows").value === 4)
assert(plans.head.metrics("outputVectors").value === 1)
}
}

Expand Down Expand Up @@ -139,7 +140,8 @@ class GlutenClickHouseTPCHMetricsSuite extends GlutenClickHouseTPCHAbstractSuite
assert(plans.size == 3)

assert(plans(2).metrics("numFiles").value === 1)
assert(plans(2).metrics("pruningTime").value === -1)
val pruningTimeValue = if (isSparkVersionGE("3.4")) 0 else -1
assert(plans(2).metrics("pruningTime").value === pruningTimeValue)
assert(plans(2).metrics("filesSize").value === 19230111)

assert(plans(1).metrics("numOutputRows").value === 4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,26 +460,34 @@ class GlutenParquetFilterSuite
"orders1" -> Nil)
)

def runTest(i: Int): Unit = withDataFrame(tpchSQL(i + 1, tpchQueriesResourceFolder)) {
df =>
val scans = df.queryExecution.executedPlan
.collect { case scan: FileSourceScanExecTransformer => scan }
assertResult(result(i).size)(scans.size)
scans.zipWithIndex
.foreach {
case (scan, fileIndex) =>
val tableName = scan.tableIdentifier
.map(_.table)
.getOrElse(scan.relation.options("path").split("/").last)
val predicates = scan.filterExprs()
val expected = result(i)(s"$tableName$fileIndex")
assertResult(expected.size)(predicates.size)
if (expected.isEmpty) assert(predicates.isEmpty)
else compareExpressions(expected.reduceLeft(And), predicates.reduceLeft(And))
}
}

tpchQueries.zipWithIndex.foreach {
case (q, i) =>
test(q) {
withDataFrame(tpchSQL(i + 1, tpchQueriesResourceFolder)) {
df =>
val scans = df.queryExecution.executedPlan
.collect { case scan: FileSourceScanExecTransformer => scan }
assertResult(result(i).size)(scans.size)
scans.zipWithIndex
.foreach {
case (scan, fileIndex) =>
val tableName = scan.tableIdentifier
.map(_.table)
.getOrElse(scan.relation.options("path").split("/").last)
val predicates = scan.filterExprs()
val expected = result(i)(s"$tableName$fileIndex")
assertResult(expected.size)(predicates.size)
if (expected.isEmpty) assert(predicates.isEmpty)
else compareExpressions(expected.reduceLeft(And), predicates.reduceLeft(And))
}
if (q == "q2" || q == "q9") {
testSparkVersionLE33(q) {
runTest(i)
}
} else {
test(q) {
runTest(i)
}
}
}
Expand Down

0 comments on commit 2ff0fb5

Please sign in to comment.