Use Spark API to build expression
PHILO-HE committed Jul 10, 2024 · 1 parent 87bc57e · commit 22a5574
Showing 1 changed file with 94 additions and 20 deletions.
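In brief (a reader's summary, not part of the commit message): the suite previously validated an expression by writing a temporary Parquet table and running SQL text through runQueryAndCompare; this commit instead builds Catalyst expressions directly through Spark's API, wraps them in a ProjectExecTransformer over a dummy child plan, and asks Gluten's doValidate() whether the backend supports them. A minimal sketch of that flow, assuming it runs inside the test class in the diff below (Upper is an arbitrary built-in unary expression picked for illustration and is covered by the file's wildcard expression import; generateChildExpression and generateGlutenProjectPlan are the helpers this commit adds):

// Sketch only: validate a single expression the way the new tests do.
val child = generateChildExpression(StringType) // yields Literal("123") for StringType
val expr = Upper(child)                         // any unary Catalyst expression
val plan = generateGlutenProjectPlan(expr)      // ProjectExecTransformer over DummyPlan()
assert(plan.doValidate().isValid)               // backend support check; nothing is executed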
@@ -18,17 +18,21 @@ package org.apache.spark.sql

 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.execution.{ProjectExecTransformer, WholeStageTransformerSuite}
+import org.apache.gluten.extension.GlutenPlan
 import org.apache.gluten.utils.{BackendTestUtils, SystemParameters}

 import org.apache.spark.SparkConf
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.LeafExecNode
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types._

 class GlutenExpressionDataTypesValidation extends WholeStageTransformerSuite {
   protected val resourcePath: String = null
   protected val fileFormat: String = null

-  import testImplicits._
+  override protected val logLevel: String = "INFO"

   override protected def sparkConf: SparkConf = {
     val conf = super.sparkConf
@@ -46,29 +50,99 @@ class GlutenExpressionDataTypesValidation extends WholeStageTransformerSuite {
     conf
   }

-  def doTest(sql: String): Unit = {
-    withTempPath {
-      path =>
-        Seq[String]("abc", null, "123", "aaa", "bbb", "ccc", "ddd", "a", "b", "c", null)
-          .toDF("col")
-          .write
-          .parquet(path.getCanonicalPath)
-        spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("tbl")
-        runQueryAndCompare(sql, false) {
-          checkGlutenOperatorMatch[ProjectExecTransformer]
-        }
-    }
-  }
+  private case class DummyPlan() extends LeafExecNode {
+    override def output: Seq[Attribute] = Seq()
+
+    override val metrics: Map[String, SQLMetric] = Map.empty
+
+    override def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException(
+      "Just a dummy plan.")
+  }
+
+  private val allDataTypes: Seq[DataType] =
+    Seq(
+      ByteType,
+      ShortType,
+      IntegerType,
+      LongType,
+      FloatType,
+      DoubleType,
+      DecimalType(5, 1),
+      StringType,
+      BinaryType,
+      DateType,
+      TimestampType)
+
+  def generateChildExpression(t: DataType): Expression = {
+    t match {
+      case _: IntegralType => Literal(null, t)
+      case _: FractionalType => Literal(null, t)
+      case StringType | BinaryType => Literal("123")
+      case DateType => Literal(null, t)
+      case TimestampType => Literal(null, t)
+      case _ => throw new UnsupportedOperationException
+    }
+  }
+
+  def generateGlutenProjectPlan(expr: Expression): GlutenPlan = {
+    val namedExpr = Seq(Alias(expr, "r")())
+    ProjectExecTransformer(namedExpr, DummyPlan())
+  }

-  def generateSimpleSQL(functionName: String): String = {
-    String.format("SELECT %s(col) from tbl", functionName)
-  }
+  test("cast") {
+    for (from <- allDataTypes) {
+      for (to <- allDataTypes) {
+        if (to != from) {
+          val castExpr = Cast(generateChildExpression(from), to)
+          if (castExpr.checkInputDataTypes().isSuccess) {
+            val glutenProject = generateGlutenProjectPlan(castExpr)
+            if (glutenProject.doValidate().isValid) {
+              logInfo("## cast validation passes: cast from " + from + " to " + to)
+            } else {
+              // scalastyle:off
+              logInfo("!! cast validation fails: cast from " + from + " to " + to)
+            }
+          }
+        }
+      }
+    }
+  }

-  test("ascii") {
-    val inputTypes = Ascii(null).inputTypes
-    Ascii(null).prettyName
-    if (inputTypes.size == 1 && inputTypes.head.acceptsType(StringType)) {
-      doTest(generateSimpleSQL(Ascii(null).prettyName))
-    }
-  }
+  test("unary expressions with expected input types") {
+    val functionRegistry = spark.sessionState.functionRegistry
+    val sparkBuiltInFunctions = functionRegistry.listFunction()
+    for (func <- sparkBuiltInFunctions) {
+      val builder = functionRegistry.lookupFunctionBuilder(func).get
+      var expr: Expression = null
+      try {
+        // Instantiate an expression with a null input, just to obtain an instance for
+        // checking its allowed input types.
+        expr = builder(Seq(null))
+      } catch {
+        // Ignore the exception, as some expression builders require more than one input.
+        case _: Throwable =>
+      }
+      if (
+        expr != null && expr.isInstanceOf[ExpectsInputTypes] && expr.isInstanceOf[UnaryExpression]
+      ) {
+        val acceptedTypes =
+          allDataTypes.filter(expr.asInstanceOf[ExpectsInputTypes].inputTypes.head.acceptsType(_))
+        if (acceptedTypes.isEmpty) {
+          logWarning("None of the given types is supported for " + expr.getClass.getSimpleName)
+        }
+        acceptedTypes.foreach(
+          t => {
+            val child = generateChildExpression(t)
+            // Build an expression whose child's type is actually supported in Spark.
+            val targetExpr = builder(Seq(child))
+            val glutenProject = generateGlutenProjectPlan(targetExpr)
+            if (glutenProject.doValidate().isValid) {
+              logInfo("## validation passes: " + targetExpr.getClass.getSimpleName + "(" + t + ")")
+            } else {
+              logInfo("!! validation fails: " + targetExpr.getClass.getSimpleName + "(" + t + ")")
+            }
+          })
+      }
+    }
+  }
 }
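Editorial note, not part of the commit: DummyPlan only has to expose an output schema for the ProjectExecTransformer's child slot; doValidate() checks backend support without executing the plan, which is presumably why doExecute() can simply throw. Compared with the removed SQL-based doTest/test("ascii") path, the new approach sweeps every registered built-in unary expression across all candidate input types, rather than requiring one hand-written SQL test per function.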
