diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 606cbd96e026..6e59d4da856e 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -481,6 +481,15 @@ object ExpressionConverter extends SQLConfHelper with Logging {
           substraitExprName,
           replaceWithExpressionTransformer0(c.child, attributeSeq, expressionsMap),
           c)
+      // Matched by simple class name instead of by type (presumably because the class is
+      // not available in every supported Spark version -- TODO confirm). Only the first
+      // child is converted; the node itself is kept as the original expression.
+      case c if c.getClass.getSimpleName == "CheckOverflowInTableInsert" =>
+        ChildTransformer(
+          substraitExprName,
+          replaceWithExpressionTransformer0(c.children.head, attributeSeq, expressionsMap),
+          c
+        )
       case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) =>
         DecimalArithmeticUtil.checkAllowDecimalArithmetic()
         if (!BackendsApiManager.getSettings.transformCheckOverflow) {
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/sources/GlutenInsertSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/sources/GlutenInsertSuite.scala
index 74c4df197759..efbabb0e1778 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/sources/GlutenInsertSuite.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/sources/GlutenInsertSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.executor.OutputMetrics
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.execution.{ColumnarWriteFilesExec, CommandResultExec, GlutenImplicits, QueryExecution}
+import org.apache.spark.sql.execution.{ColumnarWriteFilesExec, CommandResultExec, GlutenImplicits, ProjectExec, QueryExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -623,6 +623,23 @@ class GlutenInsertSuite
       }
     }
   }
+
+  // Regression check for GLUTEN-7213: the insert plan must contain a
+  // ProjectExecTransformer and no vanilla ProjectExec.
+  testGluten("GLUTEN-7213: Check no fallback with CheckOverflowInTableInsert") {
+    withTable("t1", "t2") {
+      sql("create table t1 (a float) using parquet")
+      sql("insert into t1 values(1.1)")
+      sql("create table t2 (b decimal(10,4)) using parquet")
+
+      val df = sql("insert overwrite t2 select * from t1")
+      val executedPlan = df.queryExecution.executedPlan
+        .asInstanceOf[CommandResultExec]
+        .commandPhysicalPlan
+      assert(find(executedPlan)(_.isInstanceOf[ProjectExecTransformer]).isDefined)
+      assert(find(executedPlan)(_.isInstanceOf[ProjectExec]).isEmpty)
+    }
+  }
 }
 
 class GlutenRenameFromSparkStagingToFinalDirAlwaysTurnsFalseFilesystem extends RawLocalFileSystem {
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/sources/GlutenInsertSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/sources/GlutenInsertSuite.scala
index 1cb905e10abf..aeb6df49d3c0 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/sources/GlutenInsertSuite.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/sources/GlutenInsertSuite.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.sources
 
 import org.apache.gluten.GlutenColumnarWriteTestSupport
-import org.apache.gluten.execution.SortExecTransformer
+import org.apache.gluten.execution.{ProjectExecTransformer, SortExecTransformer}
 import org.apache.gluten.extension.GlutenPlan
 
 import org.apache.spark.SparkConf
@@ -25,7 +25,7 @@ import org.apache.spark.executor.OutputMetrics
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.execution.{CommandResultExec, GlutenImplicits, QueryExecution, SparkPlan}
+import org.apache.spark.sql.execution.{CommandResultExec, GlutenImplicits, ProjectExec, QueryExecution, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -595,6 +595,21 @@ class GlutenInsertSuite
       }
     }
   }
+
+  // Regression check for GLUTEN-7213: the plan under the write node must contain a
+  // ProjectExecTransformer and no vanilla ProjectExec.
+  testGluten("GLUTEN-7213: Check no fallback with CheckOverflowInTableInsert") {
+    withTable("t1", "t2") {
+      sql("create table t1 (a float) using parquet")
+      sql("insert into t1 values(1.1)")
+      sql("create table t2 (b decimal(10,4)) using parquet")
+
+      val df = sql("insert overwrite t2 select * from t1")
+      val (_, child) = checkWriteFilesAndGetChild(df)
+      assert(find(child)(_.isInstanceOf[ProjectExecTransformer]).isDefined)
+      assert(find(child)(_.isInstanceOf[ProjectExec]).isEmpty)
+    }
+  }
 }
 
 class GlutenRenameFromSparkStagingToFinalDirAlwaysTurnsFalseFilesystem extends RawLocalFileSystem {
diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index c45f0b2d4e9c..0e08e013cb17 100644
--- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -314,6 +314,7 @@ object ExpressionNames {
   final val INLINE = "inline"
   final val POSEXPLODE = "posexplode"
   final val CHECK_OVERFLOW = "check_overflow"
+  final val CHECK_OVERFLOW_IN_TABLE_INSERT = "check_overflow_in_table_insert"
   final val MAKE_DECIMAL = "make_decimal"
   final val PROMOTE_PRECISION = "promote_precision"
   final val SPARK_PARTITION_ID = "spark_partition_id"
diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
index 5e42f66ba3c1..558d7f60d5eb 100644
--- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
@@ -82,7 +82,8 @@ class Spark34Shims extends SparkShims {
       Sig[RoundFloor](ExpressionNames.FLOOR),
       Sig[RoundCeil](ExpressionNames.CEIL),
       Sig[Mask](ExpressionNames.MASK),
-      Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT)
+      Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT),
+      Sig[CheckOverflowInTableInsert](ExpressionNames.CHECK_OVERFLOW_IN_TABLE_INSERT)
     )
   }
 
diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
index ddb023b5a4e9..4a6590161c4a 100644
--- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
@@ -82,7 +82,8 @@ class Spark35Shims extends SparkShims {
       Sig[TimestampAdd](ExpressionNames.TIMESTAMP_ADD),
       Sig[RoundFloor](ExpressionNames.FLOOR),
       Sig[RoundCeil](ExpressionNames.CEIL),
-      Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT)
+      Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT),
+      Sig[CheckOverflowInTableInsert](ExpressionNames.CHECK_OVERFLOW_IN_TABLE_INSERT)
     )
   }
 