From 62fc603819df9fcc67dbca6179b5afe7b11bc419 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Fri, 12 Apr 2024 11:57:23 +0800 Subject: [PATCH] [CORE] Support KnownNullable and KnownNotNull (#5365) --- .../apache/gluten/expression/ExpressionConverter.scala | 10 +++++++--- .../apache/gluten/expression/ExpressionMappings.scala | 1 + .../org/apache/gluten/expression/ExpressionNames.scala | 2 ++ .../apache/gluten/sql/shims/spark33/Spark33Shims.scala | 2 ++ .../apache/gluten/sql/shims/spark34/Spark34Shims.scala | 5 ++++- 5 files changed, 16 insertions(+), 4 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 26295678aacf..66cac9b0d06e 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -54,7 +54,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { val expressionsMap = ExpressionMappings.expressionsMap exprs.map { expr => replaceWithExpressionTransformerInternal(expr, attributeSeq, expressionsMap) - }.toSeq + } } def replaceWithExpressionTransformer( @@ -85,7 +85,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { udf: ScalaUDF, attributeSeq: Seq[Attribute], expressionsMap: Map[Class[_], String]): ExpressionTransformer = { - if (!udf.udfName.isDefined) { + if (udf.udfName.isEmpty) { throw new GlutenNotSupportException("UDF name is not found!") } val substraitExprName = UDFMappings.scalaUDFMap.get(udf.udfName.get) @@ -488,7 +488,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { substraitExprName, replaceWithExpressionTransformerInternal(rand.child, attributeSeq, expressionsMap), rand) - case _: KnownFloatingPointNormalized | _: NormalizeNaNAndZero | _: PromotePrecision => + case _: NormalizeNaNAndZero | _: PromotePrecision => ChildTransformer( replaceWithExpressionTransformerInternal(expr.children.head, attributeSeq, expressionsMap) ) @@ -570,6 +570,10 @@ object ExpressionConverter extends SQLConfHelper with Logging { add.dataType, add.nullable ) + case e: TaggingExpression => + ChildTransformer( + replaceWithExpressionTransformerInternal(e.child, attributeSeq, expressionsMap) + ) case e: Transformable => val childrenTransformers = e.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index ce410842b353..618798b15c00 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -241,6 +241,7 @@ object ExpressionMappings { Sig[GetStructField](GET_STRUCT_FIELD), Sig[CreateNamedStruct](NAMED_STRUCT), // Directly use child expression transformer + Sig[KnownNotNull](KNOWN_NOT_NULL), Sig[KnownFloatingPointNormalized](KNOWN_FLOATING_POINT_NORMALIZED), Sig[NormalizeNaNAndZero](NORMALIZE_NANAND_ZERO), // Specific expression diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala index d0be6b599ba8..ca8b098aa32d 100644 --- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala +++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala @@ -287,6 +287,8 @@ object ExpressionNames { final val MONOTONICALLY_INCREASING_ID = "monotonically_increasing_id" // Directly use child expression transformer + final val KNOWN_NULLABLE = "known_nullable" + final val KNOWN_NOT_NULL = "known_not_null" final val KNOWN_FLOATING_POINT_NORMALIZED = "known_floating_point_normalized" final val NORMALIZE_NANAND_ZERO = "normalize_nanand_zero" diff --git a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala index d20266508964..8537211e9819 100644 --- a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala +++ b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala @@ -19,6 +19,7 @@ package org.apache.gluten.sql.shims.spark33 import org.apache.gluten.GlutenConfig import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects import org.apache.gluten.expression.{ExpressionNames, Sig} +import org.apache.gluten.expression.ExpressionNames.KNOWN_NULLABLE import org.apache.gluten.expression.ExpressionNames.TIMESTAMP_ADD import org.apache.gluten.sql.shims.{ShimDescriptor, SparkShims} @@ -71,6 +72,7 @@ class Spark33Shims extends SparkShims { Sig[SplitPart](ExpressionNames.SPLIT_PART), Sig[Sec](ExpressionNames.SEC), Sig[Csc](ExpressionNames.CSC), + Sig[KnownNullable](KNOWN_NULLABLE), Sig[Empty2Null](ExpressionNames.EMPTY2NULL), Sig[TimestampAdd](TIMESTAMP_ADD) ) diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala index b667ead63814..fe06d7857a86 100644 --- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala +++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala @@ -18,6 +18,7 @@ package org.apache.gluten.sql.shims.spark34 import org.apache.gluten.GlutenConfig import org.apache.gluten.expression.{ExpressionNames, Sig} +import org.apache.gluten.expression.ExpressionNames.KNOWN_NULLABLE import org.apache.gluten.sql.shims.{ShimDescriptor, SparkShims} import org.apache.spark.{ShuffleUtils, SparkContext, SparkContextUtils, SparkException, TaskContext, TaskContextUtils} @@ -75,7 +76,9 @@ class Spark34Shims extends SparkShims { Sig[SplitPart](ExpressionNames.SPLIT_PART), Sig[Sec](ExpressionNames.SEC), Sig[Csc](ExpressionNames.CSC), - Sig[Empty2Null](ExpressionNames.EMPTY2NULL)) + Sig[KnownNullable](KNOWN_NULLABLE), + Sig[Empty2Null](ExpressionNames.EMPTY2NULL) + ) } override def aggregateExpressionMappings: Seq[Sig] = {