diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index cb706d817e71..d6e323679a8d 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -21,7 +21,6 @@ import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi} import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution._ import org.apache.gluten.expression._ -import org.apache.gluten.expression.ConverterUtils.FunctionConfig import org.apache.gluten.extension.{CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteToDateExpresstionRule} import org.apache.gluten.extension.columnar.AddTransformHintRule import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides @@ -62,7 +61,6 @@ import org.apache.spark.sql.extension.{CommonSubexpressionEliminateRule, Rewrite import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch -import com.google.common.collect.Lists import org.apache.commons.lang3.ClassUtils import java.lang.{Long => JLong} @@ -76,21 +74,12 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { override def batchType: Convention.BatchType = CHBatch /** Transform GetArrayItem to Substrait. */ - override def genGetArrayItemExpressionNode( + override def genGetArrayItemTransformer( substraitExprName: String, - functionMap: JMap[String, JLong], - leftNode: ExpressionNode, - rightNode: ExpressionNode, - original: GetArrayItem): ExpressionNode = { - val functionName = ConverterUtils.makeFuncName( - substraitExprName, - Seq(original.left.dataType, original.right.dataType), - FunctionConfig.OPT) - val exprNodes = Lists.newArrayList(leftNode, rightNode) - ExpressionBuilder.makeScalarFunction( - ExpressionBuilder.newScalarFunction(functionMap, functionName), - exprNodes, - ConverterUtils.getTypeNode(original.dataType, original.nullable)) + left: ExpressionTransformer, + right: ExpressionTransformer, + original: Expression): ExpressionTransformer = { + GetArrayItemTransformer(substraitExprName, left, right, original) } override def genProjectExecTransformer( diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala index 98cc4a930d2f..6403471c7414 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala @@ -222,3 +222,44 @@ case class CHRegExpReplaceTransformer( .doTransform(args) } } + +case class GetArrayItemTransformer( + substraitExprName: String, + left: ExpressionTransformer, + right: ExpressionTransformer, + original: Expression) + extends ExpressionTransformerWithOrigin { + + override def doTransform(args: java.lang.Object): ExpressionNode = { + // Ignore failOnError for clickhouse backend + val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val leftNode = left.doTransform(args) + var rightNode = right.doTransform(args) + + val getArrayItem = original.asInstanceOf[GetArrayItem] + + // In Spark, the index of getarrayitem starts from 0 + // But in CH, the index of arrayElement starts from 1, besides index argument must + // So we need to do transform: rightNode = add(rightNode, 1) + val addFunctionName = ConverterUtils.makeFuncName( + ExpressionNames.ADD, + Seq(IntegerType, getArrayItem.right.dataType), + FunctionConfig.OPT) + val addFunctionId = ExpressionBuilder.newScalarFunction(functionMap, addFunctionName) + val literalNode = ExpressionBuilder.makeLiteral(1.toInt, IntegerType, false) + rightNode = ExpressionBuilder.makeScalarFunction( + addFunctionId, + Lists.newArrayList(literalNode, rightNode), + ConverterUtils.getTypeNode(getArrayItem.right.dataType, getArrayItem.right.nullable)) + + val functionName = ConverterUtils.makeFuncName( + substraitExprName, + Seq(getArrayItem.left.dataType, getArrayItem.right.dataType), + FunctionConfig.OPT) + val exprNodes = Lists.newArrayList(leftNode, rightNode) + ExpressionBuilder.makeScalarFunction( + ExpressionBuilder.newScalarFunction(functionMap, functionName), + exprNodes, + ConverterUtils.getTypeNode(getArrayItem.dataType, getArrayItem.nullable)) + } +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 69e56b422561..c30e349529c6 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -22,14 +22,12 @@ import org.apache.gluten.datasource.ArrowConvertorRule import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution._ import org.apache.gluten.expression._ -import org.apache.gluten.expression.ConverterUtils.FunctionConfig import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet} import org.apache.gluten.extension._ import org.apache.gluten.extension.columnar.TransformHints import org.apache.gluten.extension.columnar.transition.Convention import org.apache.gluten.extension.columnar.transition.ConventionFunc.BatchOverride import org.apache.gluten.sql.shims.SparkShimLoader -import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, IfThenNode} import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSerializeResult} import org.apache.spark.{ShuffleDependency, SparkException} @@ -63,14 +61,10 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch -import com.google.common.collect.Lists import org.apache.commons.lang3.ClassUtils import javax.ws.rs.core.UriBuilder -import java.lang.{Long => JLong} -import java.util.{Map => JMap} - import scala.collection.mutable.ListBuffer class VeloxSparkPlanExecApi extends SparkPlanExecApi { @@ -91,49 +85,13 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { VeloxBatch } - /** - * Transform GetArrayItem to Substrait. - * - * arrCol[index] => IF(index < 0, null, ElementAt(arrCol, index + 1)) - */ - override def genGetArrayItemExpressionNode( + /** Transform GetArrayItem to Substrait. */ + override def genGetArrayItemTransformer( substraitExprName: String, - functionMap: JMap[String, JLong], - leftNode: ExpressionNode, - rightNode: ExpressionNode, - original: GetArrayItem): ExpressionNode = { - if (original.dataType.isInstanceOf[DecimalType]) { - val decimalType = original.dataType.asInstanceOf[DecimalType] - val precision = decimalType.precision - if (precision > 18) { - throw new GlutenNotSupportException( - "GetArrayItem not support decimal precision more than 18") - } - } - // ignore origin substraitExprName - val functionName = ConverterUtils.makeFuncName( - ExpressionMappings.expressionsMap(classOf[ElementAt]), - Seq(original.dataType), - FunctionConfig.OPT) - val exprNodes = Lists.newArrayList(leftNode, rightNode) - val resultNode = ExpressionBuilder.makeScalarFunction( - ExpressionBuilder.newScalarFunction(functionMap, functionName), - exprNodes, - ConverterUtils.getTypeNode(original.dataType, original.nullable)) - val nullNode = ExpressionBuilder.makeLiteral(null, original.dataType, false) - val lessThanFuncId = ExpressionBuilder.newScalarFunction( - functionMap, - ConverterUtils.makeFuncName( - ExpressionNames.LESS_THAN, - Seq(original.right.dataType, IntegerType), - FunctionConfig.OPT)) - // right node already add 1 - val literalNode = ExpressionBuilder.makeLiteral(1.toInt, IntegerType, false) - val lessThanFuncNode = ExpressionBuilder.makeScalarFunction( - lessThanFuncId, - Lists.newArrayList(rightNode, literalNode), - ConverterUtils.getTypeNode(BooleanType, true)) - new IfThenNode(Lists.newArrayList(lessThanFuncNode), Lists.newArrayList(nullNode), resultNode) + left: ExpressionTransformer, + right: ExpressionTransformer, + original: Expression): ExpressionTransformer = { + GenericExpressionTransformer(substraitExprName, Seq(left, right), original) } /** Transform NaNvl to Substrait. */ @@ -521,7 +479,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { resultAttrs: Seq[Attribute], child: SparkPlan, evalType: Int): SparkPlan = { - new ColumnarArrowEvalPythonExec(udfs, resultAttrs, child, evalType) + ColumnarArrowEvalPythonExec(udfs, resultAttrs, child, evalType) } /** diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala index cba8b6207a7f..d0df35b64e25 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala @@ -978,6 +978,23 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest { } } + testWithSpecifiedSparkVersion("get", Some("3.4")) { + withTempPath { + path => + Seq[Seq[Integer]](Seq(1, null, 5, 4), Seq(5, -1, 8, 9, -7, 2), Seq.empty, null) + .toDF("value") + .write + .parquet(path.getCanonicalPath) + + spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("array_tbl") + + runQueryAndCompare( + "select get(value, 0), get(value, 1), get(value, 2), get(value, 3) from array_tbl;") { + checkGlutenOperatorMatch[ProjectExecTransformer] + } + } + } + test("length") { runQueryAndCompare( "select length(c_comment), length(cast(c_comment as binary))" + diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc index 5a08d83337ec..6f221b78e9ac 100644 --- a/cpp/velox/substrait/SubstraitParser.cc +++ b/cpp/velox/substrait/SubstraitParser.cc @@ -401,6 +401,7 @@ std::unordered_map SubstraitParser::substraitVeloxFunc {"forall", "all_match"}, {"exists", "any_match"}, {"negative", "unaryminus"}, + {"get_array_item", "get"}, {"arrays_zip", "zip"}}; const std::unordered_map SubstraitParser::typeMap_ = { diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index 51f39a3abdbe..0b08ca20517b 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -194,9 +194,6 @@ bool SubstraitToVeloxPlanValidator::validateScalarFunction( } else if (name == "map_from_arrays") { LOG_VALIDATION_MSG("map_from_arrays is not supported."); return false; - } else if (name == "get_array_item") { - LOG_VALIDATION_MSG("get_array_item is not supported."); - return false; } else if (name == "concat") { for (const auto& type : types) { if (type.find("struct") != std::string::npos || type.find("map") != std::string::npos || diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index a6228e6715e8..7e72b1758c9a 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -187,12 +187,11 @@ trait SparkPlanExecApi { } /** Transform GetArrayItem to Substrait. */ - def genGetArrayItemExpressionNode( + def genGetArrayItemTransformer( substraitExprName: String, - functionMap: JMap[String, JLong], - leftNode: ExpressionNode, - rightNode: ExpressionNode, - original: GetArrayItem): ExpressionNode + left: ExpressionTransformer, + right: ExpressionTransformer, + original: Expression): ExpressionTransformer /** Transform NaNvl to Substrait. */ def genNaNvlTransformer( diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala index 68a464f13222..38f65c17893b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala @@ -16,15 +16,11 @@ */ package org.apache.gluten.expression -import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.expression.ConverterUtils.FunctionConfig import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - -import com.google.common.collect.Lists import scala.collection.JavaConverters._ @@ -55,41 +51,3 @@ case class CreateArrayTransformer( ExpressionBuilder.makeScalarFunction(functionId, childNodes, typeNode) } } - -case class GetArrayItemTransformer( - substraitExprName: String, - left: ExpressionTransformer, - right: ExpressionTransformer, - failOnError: Boolean, - original: GetArrayItem) - extends ExpressionTransformerWithOrigin { - - override def doTransform(args: java.lang.Object): ExpressionNode = { - // Ignore failOnError for clickhouse backend - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - val leftNode = left.doTransform(args) - var rightNode = right.doTransform(args) - - // In Spark, the index of getarrayitem starts from 0 - // But in CH and velox, the index of arrayElement starts from 1, besides index argument must - // So we need to do transform: rightNode = add(rightNode, 1) - val addFunctionName = ConverterUtils.makeFuncName( - ExpressionNames.ADD, - Seq(IntegerType, original.right.dataType), - FunctionConfig.OPT) - val addFunctionId = ExpressionBuilder.newScalarFunction(functionMap, addFunctionName) - val literalNode = ExpressionBuilder.makeLiteral(1.toInt, IntegerType, false) - rightNode = ExpressionBuilder.makeScalarFunction( - addFunctionId, - Lists.newArrayList(literalNode, rightNode), - ConverterUtils.getTypeNode(original.right.dataType, original.right.nullable)) - - BackendsApiManager.getSparkPlanExecApiInstance.genGetArrayItemExpressionNode( - substraitExprName, - functionMap, - leftNode, - rightNode, - original - ) - } -} diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index b64a23e860fa..e22a20e0dc4c 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -174,11 +174,10 @@ object ExpressionConverter extends SQLConfHelper with Logging { c.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) CreateArrayTransformer(substraitExprName, children, useStringTypeWhenEmpty = true, c) case g: GetArrayItem => - GetArrayItemTransformer( + BackendsApiManager.getSparkPlanExecApiInstance.genGetArrayItemTransformer( substraitExprName, replaceWithExpressionTransformerInternal(g.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(g.right, attributeSeq, expressionsMap), - g.failOnError, g ) case c: CreateMap => diff --git a/gluten-core/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala b/gluten-core/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala index 32266f1a6245..35afc731bc2e 100644 --- a/gluten-core/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala +++ b/gluten-core/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala @@ -60,13 +60,13 @@ abstract class GlutenQueryTest extends PlanTest { minSparkVersion: Option[String] = None, maxSparkVersion: Option[String] = None): Boolean = { var shouldRun = true - if (!minSparkVersion.isEmpty) { + if (minSparkVersion.isDefined) { shouldRun = isSparkVersionGE(minSparkVersion.get) - if (!maxSparkVersion.isEmpty) { + if (maxSparkVersion.isDefined) { shouldRun = shouldRun && isSparkVersionLE(maxSparkVersion.get) } } else { - if (!maxSparkVersion.isEmpty) { + if (maxSparkVersion.isDefined) { shouldRun = isSparkVersionLE(maxSparkVersion.get) } }