Merge branch 'main' into wip-test-naming
zhztheplayer authored Aug 6, 2024
2 parents 5ec13ad + 43d0ff9 commit 4642c52
Showing 29 changed files with 591 additions and 198 deletions.
@@ -307,8 +307,7 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
       }
 
       wExpression.windowFunction match {
-        case _: RowNumber | _: AggregateExpression | _: Rank | _: DenseRank | _: PercentRank |
-            _: NTile =>
+        case _: RowNumber | _: AggregateExpression | _: Rank | _: DenseRank | _: NTile =>
           allSupported = allSupported
         case l: Lag =>
           checkLagOrLead(l.third)
@@ -978,7 +978,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
     compareResultsAgainstVanillaSpark(sql, true, { _ => })
   }
 
-  test("window percent_rank") {
+  ignore("window percent_rank") {
     val sql =
       """
        |select n_regionkey, n_nationkey,
@@ -66,6 +66,7 @@ object VeloxBackendSettings extends BackendSettingsApi {
   val GLUTEN_VELOX_UDF_LIB_PATHS = getBackendConfigPrefix() + ".udfLibraryPaths"
   val GLUTEN_VELOX_DRIVER_UDF_LIB_PATHS = getBackendConfigPrefix() + ".driver.udfLibraryPaths"
   val GLUTEN_VELOX_INTERNAL_UDF_LIB_PATHS = getBackendConfigPrefix() + ".internal.udfLibraryPaths"
+  val GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION = getBackendConfigPrefix() + ".udfAllowTypeConversion"
 
   val MAXIMUM_BATCH_SIZE: Int = 32768
 
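The new key resolves to getBackendConfigPrefix() + ".udfAllowTypeConversion" and is read through SQLConf when UDF calls are resolved (see the UDFResolver changes below). A minimal usage sketch in Scala, assuming the Velox backend prefix spark.gluten.sql.columnar.backend.velox (the same prefix the resizeBatches keys later in this diff use) and the myudf1 test UDF exercised in VeloxUdfSuite below:

  // Hedged sketch: allow Gluten to cast UDF arguments to the registered signature types.
  spark.conf.set("spark.gluten.sql.columnar.backend.velox.udfAllowTypeConversion", "true")
  // With the flag on, the string literal can be cast to myudf1's registered bigint argument.
  spark.sql("""select myudf1("100")""").show()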
@@ -27,11 +27,12 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, ExpressionInfo, Unevaluable}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.Utils
 
@@ -74,18 +75,21 @@ trait UDFSignatureBase {
   val expressionType: ExpressionType
   val children: Seq[DataType]
   val variableArity: Boolean
+  val allowTypeConversion: Boolean
 }
 
 case class UDFSignature(
     expressionType: ExpressionType,
     children: Seq[DataType],
-    variableArity: Boolean)
+    variableArity: Boolean,
+    allowTypeConversion: Boolean)
   extends UDFSignatureBase
 
 case class UDAFSignature(
     expressionType: ExpressionType,
     children: Seq[DataType],
     variableArity: Boolean,
+    allowTypeConversion: Boolean,
     intermediateAttrs: Seq[AttributeReference])
   extends UDFSignatureBase
 
@@ -130,26 +134,30 @@ object UDFResolver extends Logging {
       name: String,
       returnType: Array[Byte],
       argTypes: Array[Byte],
-      variableArity: Boolean): Unit = {
+      variableArity: Boolean,
+      allowTypeConversion: Boolean): Unit = {
     registerUDF(
       name,
       ConverterUtils.parseFromBytes(returnType),
       ConverterUtils.parseFromBytes(argTypes),
-      variableArity)
+      variableArity,
+      allowTypeConversion)
   }
 
   private def registerUDF(
       name: String,
       returnType: ExpressionType,
       argTypes: ExpressionType,
-      variableArity: Boolean): Unit = {
+      variableArity: Boolean,
+      allowTypeConversion: Boolean): Unit = {
     assert(argTypes.dataType.isInstanceOf[StructType])
     val v =
       UDFMap.getOrElseUpdate(name, mutable.MutableList[UDFSignature]())
     v += UDFSignature(
       returnType,
       argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType),
-      variableArity)
+      variableArity,
+      allowTypeConversion)
     UDFNames += name
     logInfo(s"Registered UDF: $name($argTypes) -> $returnType")
   }
@@ -159,13 +167,15 @@ object UDFResolver extends Logging {
       returnType: Array[Byte],
       argTypes: Array[Byte],
       intermediateTypes: Array[Byte],
-      variableArity: Boolean): Unit = {
+      variableArity: Boolean,
+      enableTypeConversion: Boolean): Unit = {
     registerUDAF(
       name,
       ConverterUtils.parseFromBytes(returnType),
       ConverterUtils.parseFromBytes(argTypes),
       ConverterUtils.parseFromBytes(intermediateTypes),
-      variableArity
+      variableArity,
+      enableTypeConversion
     )
   }
 
@@ -174,7 +184,8 @@ object UDFResolver extends Logging {
       returnType: ExpressionType,
       argTypes: ExpressionType,
       intermediateTypes: ExpressionType,
-      variableArity: Boolean): Unit = {
+      variableArity: Boolean,
+      allowTypeConversion: Boolean): Unit = {
     assert(argTypes.dataType.isInstanceOf[StructType])
 
     val aggBufferAttributes: Seq[AttributeReference] =
@@ -194,6 +205,7 @@ object UDFResolver extends Logging {
       returnType,
       argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType),
       variableArity,
+      allowTypeConversion,
       aggBufferAttributes)
     UDAFNames += name
     logInfo(s"Registered UDAF: $name($argTypes) -> $returnType")
@@ -346,16 +358,27 @@ object UDFResolver extends Logging {
     }
   }
 
+  private def checkAllowTypeConversion: Boolean = {
+    SQLConf.get
+      .getConfString(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION, "false")
+      .toBoolean
+  }
+
   private def getUdfExpression(name: String)(children: Seq[Expression]) = {
     def errorMessage: String =
       s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered."
 
+    val allowTypeConversion = checkAllowTypeConversion
     val signatures =
       UDFMap.getOrElse(name, throw new UnsupportedOperationException(errorMessage));
 
-    signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
+    signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
       case Some(sig) =>
-        UDFExpression(name, sig.expressionType.dataType, sig.expressionType.nullable, children)
+        UDFExpression(
+          name,
+          sig.expressionType.dataType,
+          sig.expressionType.nullable,
+          if (!allowTypeConversion && !sig.allowTypeConversion) children
+          else applyCast(children, sig))
       case None =>
         throw new UnsupportedOperationException(errorMessage)
     }
@@ -365,62 +388,116 @@ object UDFResolver extends Logging {
     def errorMessage: String =
       s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered."
 
+    val allowTypeConversion = checkAllowTypeConversion
     val signatures =
       UDAFMap.getOrElse(
         name,
         throw new UnsupportedOperationException(errorMessage)
       )
 
-    signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
+    signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
       case Some(sig) =>
         UserDefinedAggregateFunction(
           name,
           sig.expressionType.dataType,
           sig.expressionType.nullable,
-          children,
-          sig.intermediateAttrs)
+          if (!allowTypeConversion && !sig.allowTypeConversion) children
+          else applyCast(children, sig),
+          sig.intermediateAttrs
+        )
       case None =>
         throw new UnsupportedOperationException(errorMessage)
     }
   }
 
+  private def tryBind(
+      sig: UDFSignatureBase,
+      requiredDataTypes: Seq[DataType],
+      allowTypeConversion: Boolean): Boolean = {
+    if (
+      !tryBindStrict(sig, requiredDataTypes) && (allowTypeConversion || sig.allowTypeConversion)
+    ) {
+      tryBindWithTypeConversion(sig, requiredDataTypes)
+    } else {
+      true
+    }
+  }
+
   // Returns true if required data types match the function signature.
   // If the function signature is variable arity, the number of the last argument can be zero
   // or more.
-  private def tryBind(sig: UDFSignatureBase, requiredDataTypes: Seq[DataType]): Boolean = {
+  private def tryBindWithTypeConversion(
+      sig: UDFSignatureBase,
+      requiredDataTypes: Seq[DataType]): Boolean = {
+    tryBind0(sig, requiredDataTypes, Cast.canCast)
+  }
+
+  private def tryBindStrict(sig: UDFSignatureBase, requiredDataTypes: Seq[DataType]): Boolean = {
+    tryBind0(sig, requiredDataTypes, DataTypeUtils.sameType)
+  }
+
+  private def tryBind0(
+      sig: UDFSignatureBase,
+      requiredDataTypes: Seq[DataType],
+      checkType: (DataType, DataType) => Boolean): Boolean = {
     if (!sig.variableArity) {
       sig.children.size == requiredDataTypes.size &&
-      sig.children
-        .zip(requiredDataTypes)
-        .forall { case (candidate, required) => DataTypeUtils.sameType(candidate, required) }
+      requiredDataTypes
        .zip(sig.children)
        .forall { case (required, candidate) => checkType(required, candidate) }
     } else {
       // If variableArity is true, there must be at least one argument in the signature.
       if (requiredDataTypes.size < sig.children.size - 1) {
         false
       } else if (requiredDataTypes.size == sig.children.size - 1) {
-        sig.children
-          .dropRight(1)
-          .zip(requiredDataTypes)
-          .forall { case (candidate, required) => DataTypeUtils.sameType(candidate, required) }
+        requiredDataTypes
          .zip(sig.children.dropRight(1))
          .forall { case (required, candidate) => checkType(required, candidate) }
       } else {
         val varArgStartIndex = sig.children.size - 1
         // First check all var args has the same type with the last argument of the signature.
         if (
           !requiredDataTypes
             .drop(varArgStartIndex)
-            .forall(argType => DataTypeUtils.sameType(sig.children.last, argType))
+            .forall(argType => checkType(argType, sig.children.last))
         ) {
           false
         } else if (varArgStartIndex == 0) {
           // No fixed args.
           true
         } else {
           // Whether fixed args matches.
-          sig.children
-            .dropRight(1)
-            .zip(requiredDataTypes.dropRight(1 + requiredDataTypes.size - sig.children.size))
-            .forall { case (candidate, required) => DataTypeUtils.sameType(candidate, required) }
+          requiredDataTypes
            .dropRight(1 + requiredDataTypes.size - sig.children.size)
            .zip(sig.children.dropRight(1))
            .forall { case (required, candidate) => checkType(required, candidate) }
         }
       }
     }
   }
 
+  private def applyCast(children: Seq[Expression], sig: UDFSignatureBase): Seq[Expression] = {
+    def maybeCast(expr: Expression, toType: DataType): Expression = {
+      if (!expr.dataType.sameType(toType)) {
+        Cast(expr, toType)
+      } else {
+        expr
+      }
+    }
+
+    if (!sig.variableArity) {
+      children.zip(sig.children).map { case (expr, toType) => maybeCast(expr, toType) }
+    } else {
+      val fixedArgs = Math.min(children.size, sig.children.size)
+      val newChildren = children.take(fixedArgs).zip(sig.children.take(fixedArgs)).map {
+        case (expr, toType) => maybeCast(expr, toType)
+      }
+      if (children.size > sig.children.size) {
+        val varArgType = sig.children.last
+        newChildren ++ children.takeRight(children.size - sig.children.size).map {
+          expr => maybeCast(expr, varArgType)
+        }
+      } else {
+        newChildren
+      }
+    }
+  }
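For context on the refactor above: binding now tries a strict match first (DataTypeUtils.sameType) and falls back to Cast.canCast only when conversion is allowed, after which applyCast wraps mismatched arguments in Cast. Below is a simplified, self-contained Scala sketch of the tryBind0 arity/type check, using plain strings in place of Spark DataTypes; the names and the toy canCast rule are illustrative only, not part of the commit.

object BindSketch {
  // Simplified tryBind0: exact argument count for fixed arity; for variable arity the
  // last signature type may repeat zero or more times.
  def tryBind0(
      sigChildren: Seq[String],
      variableArity: Boolean,
      required: Seq[String],
      checkType: (String, String) => Boolean): Boolean = {
    if (!variableArity) {
      sigChildren.size == required.size &&
      required.zip(sigChildren).forall { case (r, c) => checkType(r, c) }
    } else {
      val fixed = sigChildren.dropRight(1)
      required.size >= fixed.size &&
      required.take(fixed.size).zip(fixed).forall { case (r, c) => checkType(r, c) } &&
      required.drop(fixed.size).forall(r => checkType(r, sigChildren.last))
    }
  }

  def main(args: Array[String]): Unit = {
    val strict = (r: String, c: String) => r == c // stand-in for DataTypeUtils.sameType
    val withCast = (r: String, c: String) => r == c || c == "bigint" // toy stand-in for Cast.canCast
    // A UDF registered for (bigint) called with a string argument:
    println(tryBind0(Seq("bigint"), variableArity = false, Seq("string"), strict)) // false
    println(tryBind0(Seq("bigint"), variableArity = false, Seq("string"), withCast)) // true
  }
}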
@@ -900,12 +900,11 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
 
   test("combine small batches before shuffle") {
     val minBatchSize = 15
-    val maxBatchSize = 100
     withSQLConf(
       "spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput" -> "true",
       "spark.gluten.sql.columnar.maxBatchSize" -> "2",
-      "spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput.range" ->
-        s"$minBatchSize~$maxBatchSize"
+      "spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput.minSize" ->
+        s"$minBatchSize"
     ) {
       val df = runQueryAndCompare(
         "select l_orderkey, sum(l_partkey) as sum from lineitem " +
@@ -921,16 +920,10 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
       assert(metrics("numOutputRows").value == 27)
       assert(metrics("numOutputBatches").value == 2)
     }
-  }
 
-  test("split small batches before shuffle") {
-    val minBatchSize = 1
-    val maxBatchSize = 4
     withSQLConf(
       "spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput" -> "true",
-      "spark.gluten.sql.columnar.maxBatchSize" -> "100",
-      "spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput.range" ->
-        s"$minBatchSize~$maxBatchSize"
+      "spark.gluten.sql.columnar.maxBatchSize" -> "2"
     ) {
       val df = runQueryAndCompare(
         "select l_orderkey, sum(l_partkey) as sum from lineitem " +
@@ -939,12 +932,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
       val ops = collect(df.queryExecution.executedPlan) { case p: VeloxResizeBatchesExec => p }
       assert(ops.size == 1)
       val op = ops.head
-      assert(op.minOutputBatchSize == minBatchSize)
+      assert(op.minOutputBatchSize == 1)
       val metrics = op.metrics
       assert(metrics("numInputRows").value == 27)
-      assert(metrics("numInputBatches").value == 1)
+      assert(metrics("numInputBatches").value == 14)
       assert(metrics("numOutputRows").value == 27)
-      assert(metrics("numOutputBatches").value == 7)
+      assert(metrics("numOutputBatches").value == 14)
     }
   }
 
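The test above now drives VeloxResizeBatchesExec through the shuffleInput.minSize key rather than the old shuffleInput.range key. A hedged sketch of the equivalent session-level settings, reusing the exact keys from the test (values are illustrative):

  // Sketch only: merge shuffle-input batches smaller than 15 rows; producer batches are capped at 2 rows.
  spark.conf.set("spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput", "true")
  spark.conf.set("spark.gluten.sql.columnar.maxBatchSize", "2")
  spark.conf.set("spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput.minSize", "15")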
@@ -16,6 +16,7 @@
  */
 package org.apache.gluten.expression
 
+import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
 import org.apache.gluten.tags.{SkipTestTags, UDFTest}
 
 import org.apache.spark.SparkConf
@@ -88,6 +89,23 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
         .sameElements(Array(Row(105L, 6, 6L, 5, 6, 11, 6L, 11L, Date.valueOf("2024-03-30")))))
   }
 
+  test("test udf allow type conversion") {
+    withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "true") {
+      val df = spark.sql("""select myudf1("100"), myudf1(1), mydate('2024-03-25', 5)""")
+      assert(
+        df.collect()
+          .sameElements(Array(Row(105L, 6L, Date.valueOf("2024-03-30")))))
+    }
+
+    withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "false") {
+      assert(
+        spark
+          .sql("select mydate2('2024-03-25', 5)")
+          .collect()
+          .sameElements(Array(Row(Date.valueOf("2024-03-30")))))
+    }
+  }
+
   test("test udaf") {
     val df = spark.sql("""select
                          | myavg(1),
@@ -101,6 +119,15 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
       df.collect()
         .sameElements(Array(Row(1.0, 1.0, 1.0, 1.0, 1L))))
   }
+
+  test("test udaf allow type conversion") {
+    withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "true") {
+      val df = spark.sql("""select myavg("1"), myavg("1.0"), mycount_if("true")""")
+      assert(
+        df.collect()
+          .sameElements(Array(Row(1.0, 1.0, 1L))))
+    }
+  }
 }
 
 @UDFTest
(The remaining changed files in this commit are not shown here.)
