From dbbb4a7dfad14f7dffa208d70cb3ea587c31633e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=AB=98=E9=98=B3=E9=98=B3?=
Date: Tue, 16 Jul 2024 14:20:55 +0800
Subject: [PATCH] [CORE] Fix fallback for spark sequence function with literal array data as input (#6433)

---
Reviewer note (text in this section is ignored by `git am`):
  * ExpressionBuilder.makeLiteral(Object, DataType, Boolean): when the value is
    an UnsafeArrayData, its elements are read one by one into an Object[] and
    wrapped in a GenericArrayData, which is then passed to makeListLiteral.
    All other values still go through makeLiteral(obj, typeNode).
  * ScalarFunctionsValidateSuite: new test runs `SELECT sequence(1, 5), l_orderkey`
    with the NullPropagation optimizer rule excluded and checks that the plan
    contains a ProjectExecTransformer.
  * Line breaks and indentation of this patch were reconstructed from the hunk
    counts; verify whitespace against the target tree if `git apply` rejects
    a context line.

 .../execution/ScalarFunctionsValidateSuite.scala | 10 ++++++++++
 .../substrait/expression/ExpressionBuilder.java | 15 +++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
index 39c1b4560646..3b9e2479547f 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
@@ -17,6 +17,7 @@
 package org.apache.gluten.execution
 
 import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.optimizer.NullPropagation
 import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.types._
 
@@ -664,6 +665,15 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
     }
   }
 
+  test("Test sequence function optimized by Spark constant folding") {
+    withSQLConf(("spark.sql.optimizer.excludedRules", NullPropagation.ruleName)) {
+      runQueryAndCompare("""SELECT sequence(1, 5), l_orderkey
+                           | from lineitem limit 100""".stripMargin) {
+        checkGlutenOperatorMatch[ProjectExecTransformer]
+      }
+    }
+  }
+
   test("Test raise_error, assert_true function") {
     runQueryAndCompare("""SELECT assert_true(l_orderkey >= 1), l_orderkey
                          | from lineitem limit 100""".stripMargin) {
diff --git a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
index e322e1528cac..16ae5412ea76 100644
--- a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
+++ b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
@@ -23,7 +23,9 @@
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.Attribute;
 import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
 import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.GenericArrayData;
 import org.apache.spark.sql.catalyst.util.MapData;
 import org.apache.spark.sql.types.*;
 
@@ -215,6 +217,19 @@ public static LiteralNode makeLiteral(Object obj, TypeNode typeNode) {
 
   public static LiteralNode makeLiteral(Object obj, DataType dataType, Boolean nullable) {
     TypeNode typeNode = ConverterUtils.getTypeNode(dataType, nullable);
+    if (obj instanceof UnsafeArrayData) {
+      UnsafeArrayData oldObj = (UnsafeArrayData) obj;
+      int numElements = oldObj.numElements();
+      Object[] elements = new Object[numElements];
+      DataType elementType = ((ArrayType) dataType).elementType();
+
+      for (int i = 0; i < numElements; i++) {
+        elements[i] = oldObj.get(i, elementType);
+      }
+
+      GenericArrayData newObj = new GenericArrayData(elements);
+      return makeListLiteral(newObj, typeNode);
+    }
     return makeLiteral(obj, typeNode);
   }
 