Skip to content

Commit

Permalink
Add `array_insert` function support for Spark 3.4+
Browse files Browse the repository at this point in the history
  • Loading branch information
ivoson committed Sep 5, 2024
1 parent d289b54 commit 07be952
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1365,4 +1365,30 @@ abstract class ScalarFunctionsValidateSuite extends FunctionsValidateSuite {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}

testWithSpecifiedSparkVersion("array insert", Some("3.4")) {
withTempPath {
path =>
Seq[Seq[Integer]](Seq(1, null, 5, 4), Seq(5, -1, 8, 9, -7, 2), Seq.empty, null)
.toDF("value")
.write
.parquet(path.getCanonicalPath)

spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("array_tbl")

Seq("true", "false").foreach {
legacyNegativeIndex =>
withSQLConf("spark.sql.legacy.negativeIndexInArrayInsert" -> legacyNegativeIndex) {
runQueryAndCompare("""
|select
| array_insert(value, 1, 0), array_insert(value, 10, 0),
| array_insert(value, -1, 0), array_insert(value, -10, 0)
|from array_tbl
|""".stripMargin) {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,14 @@ object ExpressionConverter extends SQLConfHelper with Logging {
replaceWithExpressionTransformer0(a.function, attributeSeq, expressionsMap),
a
)
      case arrayInsert if arrayInsert.getClass.getSimpleName.equals("ArrayInsert") =>
        // ArrayInsert only exists since Spark 3.4.0, so it is matched by simple class
        // name here and deconstructed via the shim layer instead of being referenced
        // directly (a direct reference would not compile against older Spark versions).
        val children = SparkShimLoader.getSparkShims.extractExpressionArrayInsert(arrayInsert)
        GenericExpressionTransformer(
          substraitExprName,
          children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)),
          arrayInsert
        )
case s: Shuffle =>
GenericExpressionTransformer(
substraitExprName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ object ExpressionNames {
final val SHUFFLE = "shuffle"
final val ZIP_WITH = "zip_with"
final val FLATTEN = "flatten"
final val ARRAY_INSERT = "array_insert"

// Map functions
final val CREATE_MAP = "map"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,4 +266,8 @@ trait SparkShims {
DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
}
}

  /**
   * Extracts the child expressions of an `ArrayInsert` expression so callers can build a
   * Substrait function call from them without referencing the class directly.
   * `ArrayInsert` only exists since Spark 3.4, hence this default implementation throws;
   * shims for Spark 3.4+ are expected to override it.
   */
  def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
    throw new UnsupportedOperationException("ArrayInsert not supported.")
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ class Spark34Shims extends SparkShims {
Sig[TimestampAdd](ExpressionNames.TIMESTAMP_ADD),
Sig[RoundFloor](ExpressionNames.FLOOR),
Sig[RoundCeil](ExpressionNames.CEIL),
Sig[Mask](ExpressionNames.MASK)
Sig[Mask](ExpressionNames.MASK),
Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT)
)
}

Expand Down Expand Up @@ -492,4 +493,9 @@ class Spark34Shims extends SparkShims {
RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
)
}

override def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
val expr = arrayInsert.asInstanceOf[ArrayInsert]
Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ class Spark35Shims extends SparkShims {
Sig[Mask](ExpressionNames.MASK),
Sig[TimestampAdd](ExpressionNames.TIMESTAMP_ADD),
Sig[RoundFloor](ExpressionNames.FLOOR),
Sig[RoundCeil](ExpressionNames.CEIL)
Sig[RoundCeil](ExpressionNames.CEIL),
Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT)
)
}

Expand Down Expand Up @@ -517,4 +518,9 @@ class Spark35Shims extends SparkShims {
RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
)
}

override def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
val expr = arrayInsert.asInstanceOf[ArrayInsert]
Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
}
}

0 comments on commit 07be952

Please sign in to comment.