diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 7b8d523a6d27f..b48da15683e85 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -852,7 +852,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
       Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG),
       Sig[TransformKeys](TRANSFORM_KEYS),
-      Sig[TransformValues](TRANSFORM_VALUES)
+      Sig[TransformValues](TRANSFORM_VALUES),
+      // For test purposes.
+      Sig[VeloxDummyExpression](VeloxDummyExpression.VELOX_DUMMY_EXPRESSION)
     )
   }
diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala
new file mode 100644
index 0000000000000..e2af66b599d3d
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types.DataType
+
+abstract class DummyExpression(child: Expression) extends UnaryExpression with Serializable {
+  private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable)
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    defineCodeGen(ctx, ev, c => c)
+
+  override def dataType: DataType = child.dataType
+
+  override def eval(input: InternalRow): Any = {
+    assert(input.numFields == 1, "The input row of DummyExpression should have only 1 field.")
+    accessor(input, 0)
+  }
+}
+
+// Can be used as a wrapper to force the wrapped expression to fall back, mocking the fallback
+// behavior of a supported expression in Gluten that fails native validation.
+case class VeloxDummyExpression(child: Expression)
+  extends DummyExpression(child)
+  with Transformable {
+  override def getTransformer(
+      childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer = {
+    if (childrenTransformers.size != children.size) {
+      throw new IllegalStateException(
+        this.getClass.getSimpleName +
+          ": getTransformer called before children transformers are initialized.")
+    }
+
+    GenericExpressionTransformer(
+      VeloxDummyExpression.VELOX_DUMMY_EXPRESSION,
+      childrenTransformers,
+      this)
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
+}
+
+object VeloxDummyExpression {
+  val VELOX_DUMMY_EXPRESSION = "velox_dummy_expression"
+
+  private val identifier = new FunctionIdentifier(VELOX_DUMMY_EXPRESSION)
+
+  def registerFunctions(registry: FunctionRegistry): Unit = {
+    registry.registerFunction(
+      identifier,
+      new ExpressionInfo(classOf[VeloxDummyExpression].getName, VELOX_DUMMY_EXPRESSION),
+      (e: Seq[Expression]) => VeloxDummyExpression(e.head)
+    )
+  }
+
+  def unregisterFunctions(registry: FunctionRegistry): Unit = {
+    registry.dropFunction(identifier)
+  }
+}
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index 915fc554584c7..e45e8b6fa6d76 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -27,7 +27,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo, Unevaluable}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -94,7 +94,8 @@ case class UDFExpression(
     dataType: DataType,
     nullable: Boolean,
     children: Seq[Expression])
-  extends Transformable {
+  extends Unevaluable
+  with Transformable {
   override protected def withNewChildrenInternal(
       newChildren: IndexedSeq[Expression]): Expression = {
     this.copy(children = newChildren)
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
index a892b6f313a4e..9b47a519cd284 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.execution
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.datasource.ArrowCSVFileFormat
 import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec
+import org.apache.gluten.expression.VeloxDummyExpression
 import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.SparkConf
@@ -45,6 +46,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
   override def beforeAll(): Unit = {
     super.beforeAll()
     createTPCHNotNullTables()
+    VeloxDummyExpression.registerFunctions(spark.sessionState.functionRegistry)
+  }
+
+  override def afterAll(): Unit = {
+    VeloxDummyExpression.unregisterFunctions(spark.sessionState.functionRegistry)
+    super.afterAll()
   }
 
   override protected def sparkConf: SparkConf = {
@@ -66,14 +73,20 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
   test("select_part_column") {
     val df = runQueryAndCompare("select l_shipdate, l_orderkey from lineitem limit 1") {
-      df => { assert(df.schema.fields.length == 2) }
+      df =>
+        {
+          assert(df.schema.fields.length == 2)
+        }
     }
     checkLengthAndPlan(df, 1)
   }
 
   test("select_as") {
     val df = runQueryAndCompare("select l_shipdate as my_col from lineitem limit 1") {
-      df => { assert(df.schema.fieldNames(0).equals("my_col")) }
+      df =>
+        {
+          assert(df.schema.fieldNames(0).equals("my_col"))
+        }
     }
     checkLengthAndPlan(df, 1)
   }
@@ -1074,6 +1087,13 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
       // No ProjectExecTransformer is introduced.
       checkSparkOperatorChainMatch[GenerateExecTransformer, FilterExecTransformer]
     }
+
+      runQueryAndCompare(
+        s"""
+           |SELECT $func(${VeloxDummyExpression.VELOX_DUMMY_EXPRESSION}(a)) from t2;
+           |""".stripMargin) {
+        checkGlutenOperatorMatch[GenerateExecTransformer]
+      }
     }
   }
 }
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 8b8a9262403c4..73047b2f49073 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -26,6 +26,7 @@
 #include "utils/ConfigExtractor.h"
 
 #include "config/GlutenConfig.h"
+#include "operators/plannodes/RowVectorStream.h"
 
 namespace gluten {
 namespace {
@@ -710,16 +711,23 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
 namespace {
 
 void extractUnnestFieldExpr(
-    std::shared_ptr<const core::ProjectNode> projNode,
+    std::shared_ptr<const core::PlanNode> child,
     int32_t index,
     std::vector<core::FieldAccessTypedExprPtr>& unnestFields) {
-  auto name = projNode->names()[index];
-  auto expr = projNode->projections()[index];
-  auto type = expr->type();
+  if (auto projNode = std::dynamic_pointer_cast<const core::ProjectNode>(child)) {
+    auto name = projNode->names()[index];
+    auto expr = projNode->projections()[index];
+    auto type = expr->type();
 
-  auto unnestFieldExpr = std::make_shared<core::FieldAccessTypedExpr>(type, name);
-  VELOX_CHECK_NOT_NULL(unnestFieldExpr, " the key in unnest Operator only support field");
-  unnestFields.emplace_back(unnestFieldExpr);
+    auto unnestFieldExpr = std::make_shared<core::FieldAccessTypedExpr>(type, name);
+    VELOX_CHECK_NOT_NULL(unnestFieldExpr, " the key in unnest Operator only support field");
+    unnestFields.emplace_back(unnestFieldExpr);
+  } else {
+    auto name = child->outputType()->names()[index];
+    auto field = child->outputType()->childAt(index);
+    auto unnestFieldExpr = std::make_shared<core::FieldAccessTypedExpr>(field, name);
+    unnestFields.emplace_back(unnestFieldExpr);
+  }
 }
 } // namespace
@@ -752,10 +760,13 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
       SubstraitParser::configSetInOptimization(generateRel.advanced_extension(), "injectedProject=");
 
   if (injectedProject) {
-    auto projNode = std::dynamic_pointer_cast<const core::ProjectNode>(childNode);
+    // Child should be either ProjectNode or ValueStreamNode in case of project fallback.
     VELOX_CHECK(
-        projNode != nullptr && projNode->names().size() > requiredChildOutput.size(),
-        "injectedProject is true, but the Project is missing or does not have the corresponding projection field")
+        (std::dynamic_pointer_cast<const core::ProjectNode>(childNode) != nullptr ||
+         std::dynamic_pointer_cast<const ValueStreamNode>(childNode) != nullptr) &&
+            childNode->outputType()->size() > requiredChildOutput.size(),
+        "injectedProject is true, but the ProjectNode or ValueStreamNode (in case of projection fallback)"
+        " is missing or does not have the corresponding projection field")
 
     bool isStack = generateRel.has_advanced_extension() &&
         SubstraitParser::configSetInOptimization(generateRel.advanced_extension(), "isStack=");
@@ -768,7 +779,8 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
       //   +- Project [fake_column#128, [1,2,3] AS _pre_0#129]
       //     +- RewrittenNodeWall Scan OneRowRelation[fake_column#128]
       // The last projection column in GeneratorRel's child(Project) is the column we need to unnest
-      extractUnnestFieldExpr(projNode, projNode->projections().size() - 1, unnest);
+      auto index = childNode->outputType()->size() - 1;
+      extractUnnestFieldExpr(childNode, index, unnest);
     } else {
       // For stack function, e.g. stack(2, 1,2,3), a sample
       // input substrait plan is like the following:
@@ -782,10 +794,10 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
       auto generatorFunc = generator.scalar_function();
       auto numRows = SubstraitParser::getLiteralValue<int32_t>(generatorFunc.arguments(0).value().literal());
       auto numFields = static_cast<int32_t>(std::ceil((generatorFunc.arguments_size() - 1.0) / numRows));
-      auto totalProjectCount = projNode->names().size();
+      auto totalProjectCount = childNode->outputType()->size();
 
       for (auto i = totalProjectCount - numFields; i < totalProjectCount; ++i) {
-        extractUnnestFieldExpr(projNode, i, unnest);
+        extractUnnestFieldExpr(childNode, i, unnest);
       }
     }
   } else {
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index b7b0889dc1eb8..da5625cd45e54 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.HiveUDFTransformer
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
-trait Transformable extends Unevaluable {
+trait Transformable {
   def getTransformer(childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer
 }
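
Reviewer note: a minimal usage sketch for the new test-only wrapper, relying only on the VeloxDummyExpression API added in this patch (the lineitem table is assumed to exist, as in the TPC-H suites). Because velox_dummy_expression is registered as a Gluten signature but has no native Velox counterpart, wrapping a column with it fails native validation and forces the enclosing stage onto the fallback path:

    import org.apache.gluten.expression.VeloxDummyExpression

    // Register the test-only function, wrap a column to trigger the fallback
    // path, then unregister it again.
    VeloxDummyExpression.registerFunctions(spark.sessionState.functionRegistry)
    try {
      spark
        .sql(s"SELECT ${VeloxDummyExpression.VELOX_DUMMY_EXPRESSION}(l_orderkey) FROM lineitem")
        .show()
    } finally {
      VeloxDummyExpression.unregisterFunctions(spark.sessionState.functionRegistry)
    }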
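
The SubstraitToVeloxPlan.cc change is what the new generate test exercises: when the projection Gluten injects under a Generate falls back, the Unnest's child reaches the native plan as a ValueStreamNode rather than a ProjectNode, so the unnest field must be resolved from the child's output type. A sketch of a query shape hitting that branch, modeled on the new test above (it assumes a table t2 with an array column a, as the test does, and uses explode as a stand-in for the test's $func):

    // The dummy wrapper makes the injected pre-projection fail native
    // validation, so the generator input arrives through a ValueStreamNode
    // and extractUnnestFieldExpr takes its else-branch.
    runQueryAndCompare(
      s"""
         |SELECT explode(${VeloxDummyExpression.VELOX_DUMMY_EXPRESSION}(a)) FROM t2;
         |""".stripMargin) {
      checkGlutenOperatorMatch[GenerateExecTransformer]
    }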