Skip to content

Commit

Permalink
[CORE] Support KnownNullable and KnownNotNull (#5365)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yohahaha authored Apr 12, 2024
1 parent a8b3161 commit 62fc603
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ object ExpressionConverter extends SQLConfHelper with Logging {
val expressionsMap = ExpressionMappings.expressionsMap
exprs.map {
expr => replaceWithExpressionTransformerInternal(expr, attributeSeq, expressionsMap)
}.toSeq
}
}

def replaceWithExpressionTransformer(
Expand Down Expand Up @@ -85,7 +85,7 @@ object ExpressionConverter extends SQLConfHelper with Logging {
udf: ScalaUDF,
attributeSeq: Seq[Attribute],
expressionsMap: Map[Class[_], String]): ExpressionTransformer = {
if (!udf.udfName.isDefined) {
if (udf.udfName.isEmpty) {
throw new GlutenNotSupportException("UDF name is not found!")
}
val substraitExprName = UDFMappings.scalaUDFMap.get(udf.udfName.get)
Expand Down Expand Up @@ -488,7 +488,7 @@ object ExpressionConverter extends SQLConfHelper with Logging {
substraitExprName,
replaceWithExpressionTransformerInternal(rand.child, attributeSeq, expressionsMap),
rand)
case _: KnownFloatingPointNormalized | _: NormalizeNaNAndZero | _: PromotePrecision =>
case _: NormalizeNaNAndZero | _: PromotePrecision =>
ChildTransformer(
replaceWithExpressionTransformerInternal(expr.children.head, attributeSeq, expressionsMap)
)
Expand Down Expand Up @@ -570,6 +570,10 @@ object ExpressionConverter extends SQLConfHelper with Logging {
add.dataType,
add.nullable
)
case e: TaggingExpression =>
ChildTransformer(
replaceWithExpressionTransformerInternal(e.child, attributeSeq, expressionsMap)
)
case e: Transformable =>
val childrenTransformers =
e.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ object ExpressionMappings {
Sig[GetStructField](GET_STRUCT_FIELD),
Sig[CreateNamedStruct](NAMED_STRUCT),
// Directly use child expression transformer
Sig[KnownNotNull](KNOWN_NOT_NULL),
Sig[KnownFloatingPointNormalized](KNOWN_FLOATING_POINT_NORMALIZED),
Sig[NormalizeNaNAndZero](NORMALIZE_NANAND_ZERO),
// Specific expression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ object ExpressionNames {
final val MONOTONICALLY_INCREASING_ID = "monotonically_increasing_id"

// Directly use child expression transformer
final val KNOWN_NULLABLE = "known_nullable"
final val KNOWN_NOT_NULL = "known_not_null"
final val KNOWN_FLOATING_POINT_NORMALIZED = "known_floating_point_normalized"
final val NORMALIZE_NANAND_ZERO = "normalize_nanand_zero"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.gluten.sql.shims.spark33
import org.apache.gluten.GlutenConfig
import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects
import org.apache.gluten.expression.{ExpressionNames, Sig}
import org.apache.gluten.expression.ExpressionNames.KNOWN_NULLABLE
import org.apache.gluten.expression.ExpressionNames.TIMESTAMP_ADD
import org.apache.gluten.sql.shims.{ShimDescriptor, SparkShims}

Expand Down Expand Up @@ -71,6 +72,7 @@ class Spark33Shims extends SparkShims {
Sig[SplitPart](ExpressionNames.SPLIT_PART),
Sig[Sec](ExpressionNames.SEC),
Sig[Csc](ExpressionNames.CSC),
Sig[KnownNullable](KNOWN_NULLABLE),
Sig[Empty2Null](ExpressionNames.EMPTY2NULL),
Sig[TimestampAdd](TIMESTAMP_ADD)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.gluten.sql.shims.spark34

import org.apache.gluten.GlutenConfig
import org.apache.gluten.expression.{ExpressionNames, Sig}
import org.apache.gluten.expression.ExpressionNames.KNOWN_NULLABLE
import org.apache.gluten.sql.shims.{ShimDescriptor, SparkShims}

import org.apache.spark.{ShuffleUtils, SparkContext, SparkContextUtils, SparkException, TaskContext, TaskContextUtils}
Expand Down Expand Up @@ -75,7 +76,9 @@ class Spark34Shims extends SparkShims {
Sig[SplitPart](ExpressionNames.SPLIT_PART),
Sig[Sec](ExpressionNames.SEC),
Sig[Csc](ExpressionNames.CSC),
Sig[Empty2Null](ExpressionNames.EMPTY2NULL))
Sig[KnownNullable](KNOWN_NULLABLE),
Sig[Empty2Null](ExpressionNames.EMPTY2NULL)
)
}

override def aggregateExpressionMappings: Seq[Sig] = {
Expand Down

0 comments on commit 62fc603

Please sign in to comment.