Skip to content

Commit

Permalink
[CORE] Fix fallback for spark sequence function with literal array da…
Browse files Browse the repository at this point in the history
…ta as input (#6433)
  • Loading branch information
gaoyangxiaozhu authored Jul 16, 2024
1 parent 86a683a commit dbbb4a7
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.gluten.execution

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.optimizer.NullPropagation
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -664,6 +665,15 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
}
}

test("Test sequence function optimized by Spark constant folding") {
withSQLConf(("spark.sql.optimizer.excludedRules", NullPropagation.ruleName)) {
runQueryAndCompare("""SELECT sequence(1, 5), l_orderkey
| from lineitem limit 100""".stripMargin) {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}
}

test("Test raise_error, assert_true function") {
runQueryAndCompare("""SELECT assert_true(l_orderkey >= 1), l_orderkey
| from lineitem limit 100""".stripMargin) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.GenericArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.*;

Expand Down Expand Up @@ -215,6 +217,19 @@ public static LiteralNode makeLiteral(Object obj, TypeNode typeNode) {

public static LiteralNode makeLiteral(Object obj, DataType dataType, Boolean nullable) {
TypeNode typeNode = ConverterUtils.getTypeNode(dataType, nullable);
if (obj instanceof UnsafeArrayData) {
UnsafeArrayData oldObj = (UnsafeArrayData) obj;
int numElements = oldObj.numElements();
Object[] elements = new Object[numElements];
DataType elementType = ((ArrayType) dataType).elementType();

for (int i = 0; i < numElements; i++) {
elements[i] = oldObj.get(i, elementType);
}

GenericArrayData newObj = new GenericArrayData(elements);
return makeListLiteral(newObj, typeNode);
}
return makeLiteral(obj, typeNode);
}

Expand Down

0 comments on commit dbbb4a7

Please sign in to comment.