Skip to content

Commit

Permalink
[VL] Support skewness aggregate function (apache#4939)
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 authored and taiyang-li committed Oct 9, 2024
1 parent bfe632c commit b2e7805
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ object CHExpressionUtil {
ENCODE -> EncodeDecodeValidator(),
ARRAY_EXCEPT -> DefaultValidator(),
ARRAY_REPEAT -> DefaultValidator(),
DATE_FROM_UNIX_DATE -> DefaultValidator()
DATE_FROM_UNIX_DATE -> DefaultValidator(),
SKEWNESS -> DefaultValidator()
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -386,23 +386,38 @@ abstract class HashAggregateExecTransformer(
val adjustedOrders = veloxOrders.map(sparkOrders.indexOf(_))
veloxTypes.zipWithIndex.foreach {
case (veloxType, idx) =>
val sparkType = sparkTypes(adjustedOrders(idx))
val attr = rewrittenInputAttributes(adjustedOrders(idx))
val aggFuncInputAttrNode = ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
.doTransform(args)
val expressionNode = if (sparkType != veloxType) {
newInputAttributes +=
attr.copy(dataType = veloxType)(attr.exprId, attr.qualifier)
ExpressionBuilder.makeCast(
ConverterUtils.getTypeNode(veloxType, attr.nullable),
aggFuncInputAttrNode,
SQLConf.get.ansiEnabled)
val adjustedIdx = adjustedOrders(idx)
if (adjustedIdx == -1) {
// The Velox aggregate intermediate buffer column is not found in Spark.
// For example, skewness and kurtosis share the same aggregate buffer in Velox,
// and kurtosis additionally requires the m4 buffer column, which is
// always 0 for skewness. In Spark, the aggregate buffer of skewness does not
// have the m4 column, so a placeholder m4 with a value of 0 must be passed
// to Velox, and this value cannot be omitted. Velox will always read the m4 column
// when accessing the intermediate data.
val extraAttr = AttributeReference(veloxOrders(idx), veloxType)()
newInputAttributes += extraAttr
val lt = Literal.default(veloxType)
childNodes.add(ExpressionBuilder.makeLiteral(lt.value, lt.dataType, false))
} else {
newInputAttributes += attr
aggFuncInputAttrNode
val sparkType = sparkTypes(adjustedIdx)
val attr = rewrittenInputAttributes(adjustedIdx)
val aggFuncInputAttrNode = ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
.doTransform(args)
val expressionNode = if (sparkType != veloxType) {
newInputAttributes +=
attr.copy(dataType = veloxType)(attr.exprId, attr.qualifier)
ExpressionBuilder.makeCast(
ConverterUtils.getTypeNode(veloxType, attr.nullable),
aggFuncInputAttrNode,
SQLConf.get.ansiEnabled)
} else {
newInputAttributes += attr
aggFuncInputAttrNode
}
childNodes.add(expressionNode)
}
childNodes.add(expressionNode)
}
exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes, aggFunc))
case other =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,25 @@ import scala.collection.JavaConverters._
object VeloxIntermediateData {
// Agg functions with inconsistent ordering of intermediate data between Velox and Spark.
// Corr
val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
private val veloxCorrIntermediateDataOrder: Seq[String] =
Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
// CovPopulation, CovSample
val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg")
private val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg")
// Skewness
private val veloxSkewnessIntermediateDataOrder: Seq[String] = Seq("n", "avg", "m2", "m3", "m4")

// Agg functions with inconsistent types of intermediate data between Velox and Spark.
// StddevSamp, StddevPop, VarianceSamp, VariancePop
val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType)
private val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType)
// CovPopulation, CovSample
val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType)
private val veloxCovarIntermediateTypes: Seq[DataType] =
Seq(DoubleType, LongType, DoubleType, DoubleType)
// Corr
val veloxCorrIntermediateTypes: Seq[DataType] =
private val veloxCorrIntermediateTypes: Seq[DataType] =
Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType)
// Skewness
private val veloxSkewnessIntermediateTypes: Seq[DataType] =
Seq(LongType, DoubleType, DoubleType, DoubleType, DoubleType)

/**
* Return the intermediate columns order of Velox aggregation functions, with special matching
Expand All @@ -55,6 +62,8 @@ object VeloxIntermediateData {
veloxCorrIntermediateDataOrder
case _: CovPopulation | _: CovSample =>
veloxCovarIntermediateDataOrder
case _: Skewness =>
veloxSkewnessIntermediateDataOrder
case _ =>
aggFunc.aggBufferAttributes.map(_.name)
}
Expand Down Expand Up @@ -134,6 +143,8 @@ object VeloxIntermediateData {
Some(veloxCovarIntermediateTypes)
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
Some(veloxVarianceIntermediateTypes)
case _: Skewness =>
Some(veloxSkewnessIntermediateTypes)
case _ if aggFunc.aggBufferAttributes.size > 1 =>
Some(aggFunc.aggBufferAttributes.map(_.dataType))
case _ => None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,24 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
|""".stripMargin)(
df => assert(getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer]) == 2))
}

// Runs the skewness aggregate through runQueryAndCompare in two query shapes:
// on its own, and alongside a distinct count.
test("skewness") {
// Case 1: skewness is the only aggregate, computed over lineitem.l_partkey.
runQueryAndCompare("""
|select skewness(l_partkey) from lineitem;
|""".stripMargin) {
// NOTE(review): presumably verifies the plan contains a HashAggregateExecTransformer;
// confirm against the definition of checkOperatorMatch.
checkOperatorMatch[HashAggregateExecTransformer]
}
// Case 2: skewness combined with count(distinct l_orderkey) in the same select list.
runQueryAndCompare("select skewness(l_partkey), count(distinct l_orderkey) from lineitem") {
df =>
{
// The executed plan must contain exactly 4 HashAggregateExecTransformer nodes.
// NOTE(review): 4 presumably reflects the multi-stage plan built for a distinct
// aggregate; confirm against the planner output.
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 4)
}
}
}
}

class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite {
Expand Down
3 changes: 2 additions & 1 deletion cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1138,7 +1138,8 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag
"corr",
"covar_pop",
"covar_samp",
"approx_distinct"};
"approx_distinct",
"skewness"};

for (const auto& funcSpec : funcSpecs) {
auto funcName = SubstraitParser::getNameBeforeDelimiter(funcSpec);
Expand Down
8 changes: 4 additions & 4 deletions docs/velox-backend-support-progress.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ Gluten supports 28 operators (Draw to right to see all data types)
Gluten supports 199 functions. (Drag to the right to see all data types)

| Spark Functions | Velox/Presto Functions | Velox/Spark functions | Gluten | Restrictions | BOOLEAN | BYTE | SHORT | INT | LONG | FLOAT | DOUBLE | DATE | TIMESTAMP | STRING | DECIMAL | NULL | BINARY | CALENDAR | ARRAY | MAP | STRUCT | UDT |
|-------------------------------|------------------------|-----------------------|--------|------------------------|---------|------|-------|-----|------|-------|--------|------|-----------|--------|---------|------|--------| -------- |-------| ---- | ------ | ---- |
|-------------------------------|------------------------|-----------------------|--------|------------------------|---------|------|-------|-----|------|-------|--------|------|-----------|--------|---------|------|--------| -------- |-------| ---- |--------| ---- |
| ! | | not | S | | S | S | S | S | S | S | S | | | S | | | | | | | | |
| != | neq | | S | | S | S | S | S | S | S | S | | | S | | | | | | | | |
| % | mod | remainder | S | Ansi Off | | S | S | S | S | S | | | | | | | | | | | | |
Expand Down Expand Up @@ -372,7 +372,7 @@ Gluten supports 199 functions. (Draw to right to see all data types)
| mean | avg | | S | Ansi Off | | | | | | | | | | | | | | | | | | |
| min | min | | S | | | | S | S | S | S | S | | | | | | | | | | | |
| min_by | | | S | | | | | | | | | | | | | | | | | | | |
| skewness | | | | | | | | | | | | | | | | | | | | | | |
| skewness | skewness | skewness | S | | | | S | S | S | S | S | | | | | | | | | | | |
| some | | | | | | | | | | | | | | | | | | | | | | |
| std,stddev | stddev | | S | | | | S | S | S | S | S | | | | | | | | | | | |
| stddev,std | stddev | | S | | | | S | S | S | S | S | | | | | | | | | | | |
Expand All @@ -387,7 +387,7 @@ Gluten supports 199 functions. (Draw to right to see all data types)
| lag | | | | | | | | | | | | | | | | | | | | | | |
| lead | | | | | | | | | | | | | | | | | | | | | | |
| nth_value | nth_value | nth_value | PS | | | | | | | | | | | | | | | | | | | |
| ntile | ntile | ntile | S | | | | | | | | | | | | | | | | | | | |
| ntile | ntile | ntile | S | | | | | | | | | | | | | | | | | | | |
| percent_rank | percent_rank | | S | | | | | | | | | | | | | | | | | | | |
| rank | rank | | S | | | | | | | | | | | | | | | | | | | |
| row_number | row_number | | S | | | | S | S | S | | | | | | | | | | | | | |
Expand All @@ -404,7 +404,7 @@ Gluten supports 199 functions. (Draw to right to see all data types)
| coalesce | | | PS | | | | | | | | | | | | | | | | | | | |
| crc32 | crc32 | | S | | | | | | | | | | | S | | | | | | | | |
| current_user | | | S* | | | | | | | | | | | S | | | | | | | | |
| current_catalog | | | S | | | | | | | | | | | | | | | | | | | |
| current_catalog | | | S | | | | | | | | | | | | | | | | | | | |
| current_database | | | S | | | | | | | | | | | | | | | | | | | |
| greatest | greatest | greatest | S | | | | | | S | S | S | S | S | | | | | | | | | |
| hash | hash | hash | S | | S | S | S | S | S | S | S | | | | | | | | | | | |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,8 @@ object ExpressionMappings {
Sig[CovPopulation](COVAR_POP),
Sig[CovSample](COVAR_SAMP),
Sig[Last](LAST),
Sig[First](FIRST)
Sig[First](FIRST),
Sig[Skewness](SKEWNESS)
)

/** Mapping Spark window expression to Substrait function name */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ object ExpressionNames {
final val FIRST = "first"
final val FIRST_IGNORE_NULL = "first_ignore_null"
final val APPROX_DISTINCT = "approx_distinct"
final val SKEWNESS = "skewness"

// Function names used by Substrait plan.
final val ADD = "add"
Expand Down

0 comments on commit b2e7805

Please sign in to comment.