Support Spark's array_insert function in Gluten (Velox backend).

ArrayInsert exists only since Spark 3.4.0, so the common module matches it
by class simple name and extracts its children via a SparkShims hook that is
overridden in the spark34/spark35 shims; the legacyNegativeIndex flag is
passed down as an extra literal child. A validation test covers positive,
negative and out-of-range indices under both values of
spark.sql.legacy.negativeIndexInArrayInsert.

diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
index b8de30b1b06f..81da24f8ed47 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
@@ -1365,4 +1365,30 @@ abstract class ScalarFunctionsValidateSuite extends FunctionsValidateSuite {
       checkGlutenOperatorMatch[ProjectExecTransformer]
     }
   }
+
+  testWithSpecifiedSparkVersion("array insert", Some("3.4")) {
+    withTempPath {
+      path =>
+        Seq[Seq[Integer]](Seq(1, null, 5, 4), Seq(5, -1, 8, 9, -7, 2), Seq.empty, null)
+          .toDF("value")
+          .write
+          .parquet(path.getCanonicalPath)
+
+        spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("array_tbl")
+
+        Seq("true", "false").foreach {
+          legacyNegativeIndex =>
+            withSQLConf("spark.sql.legacy.negativeIndexInArrayInsert" -> legacyNegativeIndex) {
+              runQueryAndCompare("""
+                                   |select
+                                   |  array_insert(value, 1, 0), array_insert(value, 10, 0),
+                                   |  array_insert(value, -1, 0), array_insert(value, -10, 0)
+                                   |from array_tbl
+                                   |""".stripMargin) {
+                checkGlutenOperatorMatch[ProjectExecTransformer]
+              }
+            }
+        }
+    }
+  }
 }
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 6f6e2cf12ee2..606cbd96e026 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -633,6 +633,14 @@ object ExpressionConverter extends SQLConfHelper with Logging {
           replaceWithExpressionTransformer0(a.function, attributeSeq, expressionsMap),
           a
         )
+      case arrayInsert if arrayInsert.getClass.getSimpleName.equals("ArrayInsert") =>
+        // Since spark 3.4.0
+        val children = SparkShimLoader.getSparkShims.extractExpressionArrayInsert(arrayInsert)
+        GenericExpressionTransformer(
+          substraitExprName,
+          children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)),
+          arrayInsert
+        )
       case s: Shuffle =>
         GenericExpressionTransformer(
           substraitExprName,
diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index 96a615615179..f198bb7e17c9 100644
--- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -272,6 +272,7 @@ object ExpressionNames {
   final val SHUFFLE = "shuffle"
   final val ZIP_WITH = "zip_with"
   final val FLATTEN = "flatten"
+  final val ARRAY_INSERT = "array_insert"
 
   // Map functions
   final val CREATE_MAP = "map"
diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
index fa6ed18e9fa8..7671f236c917 100644
--- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
@@ -266,4 +266,8 @@ trait SparkShims {
       DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
     }
   }
+
+  def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
+    throw new UnsupportedOperationException("ArrayInsert not supported.")
+  }
 }
diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
index b277139e8300..5e42f66ba3c1 100644
--- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
@@ -81,7 +81,8 @@ class Spark34Shims extends SparkShims {
       Sig[TimestampAdd](ExpressionNames.TIMESTAMP_ADD),
       Sig[RoundFloor](ExpressionNames.FLOOR),
       Sig[RoundCeil](ExpressionNames.CEIL),
-      Sig[Mask](ExpressionNames.MASK)
+      Sig[Mask](ExpressionNames.MASK),
+      Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT)
     )
   }
 
@@ -492,4 +493,9 @@ class Spark34Shims extends SparkShims {
       RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
     )
   }
+
+  override def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
+    val expr = arrayInsert.asInstanceOf[ArrayInsert]
+    Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
+  }
 }
diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
index 6474c74fe8f3..ddb023b5a4e9 100644
--- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
@@ -81,7 +81,8 @@ class Spark35Shims extends SparkShims {
       Sig[Mask](ExpressionNames.MASK),
       Sig[TimestampAdd](ExpressionNames.TIMESTAMP_ADD),
       Sig[RoundFloor](ExpressionNames.FLOOR),
-      Sig[RoundCeil](ExpressionNames.CEIL)
+      Sig[RoundCeil](ExpressionNames.CEIL),
+      Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT)
     )
   }
 
@@ -517,4 +518,9 @@ class Spark35Shims extends SparkShims {
       RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
     )
   }
+
+  override def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
+    val expr = arrayInsert.asInstanceOf[ArrayInsert]
+    Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
+  }
 }