From a73ca729b6add58637de430b6fb4249525e6a6af Mon Sep 17 00:00:00 2001
From: Rong Ma
Date: Thu, 20 Jun 2024 09:57:04 +0000
Subject: [PATCH] add ut

---
 .../velox/VeloxSparkPlanExecApi.scala         |  6 +-
 .../gluten/expression/DummyExpression.scala   | 77 +++++++++++++++++++
 .../spark/sql/expression/UDFResolver.scala    |  5 +-
 .../gluten/execution/TestOperator.scala       | 18 ++++-
 .../expression/ExpressionConverter.scala      |  2 +-
 .../RewriteSparkPlanRulesManager.scala        |  9 +--
 6 files changed, 105 insertions(+), 12 deletions(-)
 create mode 100644 backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala

diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 1f868c4c2044a..10ad41e6d50c8 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -491,6 +491,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       projectList
     )
   }
+
   override def genCartesianProductExecTransformer(
       left: SparkPlan,
       right: SparkPlan,
@@ -861,13 +862,14 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
       Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG),
       Sig[TransformKeys](TRANSFORM_KEYS),
-      Sig[TransformValues](TRANSFORM_VALUES)
+      Sig[TransformValues](TRANSFORM_VALUES),
+      Sig[VeloxDummyExpression](DummyExpression.VELOX_DUMMY_EXPRESSION)
     )
   }
 
   override def genInjectedFunctions()
       : Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
-    UDFResolver.getFunctionSignatures
+    UDFResolver.getFunctionSignatures ++ DummyExpression.getInjectedFunctions
   }
 
   override def rewriteSpillPath(path: String): String = {
diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala
new file mode 100644
index 0000000000000..2d8e30676b18c
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types.DataType
+
+abstract class DummyExpression(child: Expression) extends UnaryExpression with Serializable {
+  private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable)
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    defineCodeGen(ctx, ev, c => c)
+
+  override def dataType: DataType = child.dataType
+
+  override def eval(input: InternalRow): Any = {
+    assert(input.numFields == 1, "The input row of DummyExpression should have only 1 field.")
+    accessor(input, 0)
+  }
+}
+
+case class SparkDummyExpression(child: Expression) extends DummyExpression(child) {
+  override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
+}
+
+case class VeloxDummyExpression(child: Expression)
+  extends DummyExpression(child)
+  with Transformable {
+  override def getTransformer(
+      childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer = {
+    if (childrenTransformers.size != children.size) {
+      throw new IllegalStateException(
+        this.getClass.getSimpleName +
+          ": getTransformer called before children transformer initialized.")
+    }
+
+    GenericExpressionTransformer(DummyExpression.VELOX_DUMMY_EXPRESSION, childrenTransformers, this)
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
+}
+
+object DummyExpression {
+  val VELOX_DUMMY_EXPRESSION = "velox_dummy_expression"
+
+  val SPARK_DUMMY_EXPRESSION = "spark_dummy_expression"
+
+  def getInjectedFunctions: Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
+    Seq(
+      (
+        new FunctionIdentifier(SPARK_DUMMY_EXPRESSION),
+        new ExpressionInfo(classOf[SparkDummyExpression].getName, SPARK_DUMMY_EXPRESSION),
+        (e: Seq[Expression]) => SparkDummyExpression(e.head)),
+      (
+        new FunctionIdentifier(VELOX_DUMMY_EXPRESSION),
+        new ExpressionInfo(classOf[VeloxDummyExpression].getName, VELOX_DUMMY_EXPRESSION),
+        (e: Seq[Expression]) => VeloxDummyExpression(e.head))
+    )
+  }
+}
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index 915fc554584c7..e45e8b6fa6d76 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -27,7 +27,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo, Unevaluable}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -94,7 +94,8 @@ case class UDFExpression(
     dataType: DataType,
     nullable: Boolean,
     children: Seq[Expression])
-  extends Transformable {
+  extends Unevaluable
+  with Transformable {
   override protected def withNewChildrenInternal(
       newChildren: IndexedSeq[Expression]): Expression = {
     this.copy(children = newChildren)
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
index a892b6f313a4e..f4fa542aed74b 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.execution
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.datasource.ArrowCSVFileFormat
 import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec
+import org.apache.gluten.expression.DummyExpression
 import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.SparkConf
@@ -66,14 +67,20 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
 
   test("select_part_column") {
     val df = runQueryAndCompare("select l_shipdate, l_orderkey from lineitem limit 1") {
-      df => { assert(df.schema.fields.length == 2) }
+      df =>
+        {
+          assert(df.schema.fields.length == 2)
+        }
     }
     checkLengthAndPlan(df, 1)
   }
 
   test("select_as") {
     val df = runQueryAndCompare("select l_shipdate as my_col from lineitem limit 1") {
-      df => { assert(df.schema.fieldNames(0).equals("my_col")) }
+      df =>
+        {
+          assert(df.schema.fieldNames(0).equals("my_col"))
+        }
     }
     checkLengthAndPlan(df, 1)
   }
@@ -1074,6 +1081,13 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
           // No ProjectExecTransformer is introduced.
          checkSparkOperatorChainMatch[GenerateExecTransformer, FilterExecTransformer]
        }
+
+        // Test pre-project operator fallback.
+        runQueryAndCompare(s"""
+                              |SELECT $func(${DummyExpression.VELOX_DUMMY_EXPRESSION}(a)) from t2;
+                              |""".stripMargin) {
+          checkSparkOperatorMatch[GenerateExec]
+        }
     }
   }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index b7b0889dc1eb8..da5625cd45e54 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.HiveUDFTransformer
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
-trait Transformable extends Unevaluable {
+trait Transformable {
   def getTransformer(childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer
 }
 
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
index ac663314bead5..57150760370f7 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
@@ -70,12 +70,11 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode]
   private def getTransformHintBack(
       origin: SparkPlan,
       rewrittenPlan: SparkPlan): Option[TransformHint] = {
-    // The rewritten plan may contain more nodes than origin, here use the node name to get it back
-    val target = rewrittenPlan.collect {
-      case p if p.nodeName == origin.nodeName => p
+    // Get possible fallback hint from all nodes before RewrittenNodeWall.
+    rewrittenPlan.map(TransformHints.getHintOption).find(_.isDefined) match {
+      case Some(hint) => hint
+      case None => None
     }
-    assert(target.size == 1)
-    TransformHints.getHintOption(target.head)
   }
 
   private def applyRewriteRules(origin: SparkPlan): (SparkPlan, Option[String]) = {
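
Reviewer note (not part of the patch): a minimal sketch of how the injected
dummy functions can be exercised once this change is applied. The local
session setup and the view name `t` are hypothetical; only the function names
`velox_dummy_expression` and `spark_dummy_expression` come from
DummyExpression.scala above.

    import org.apache.spark.sql.SparkSession

    // Assumes a session with the Gluten plugin enabled, so that
    // genInjectedFunctions has registered both dummy functions.
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    spark.range(10).toDF("a").createOrReplaceTempView("t")

    // VeloxDummyExpression is Transformable and has a registered Velox
    // signature, so a project using it is expected to stay offloaded.
    spark.sql("SELECT velox_dummy_expression(a) FROM t").show()

    // SparkDummyExpression has no transformer, so a project using it
    // should fall back to vanilla Spark execution.
    spark.sql("SELECT spark_dummy_expression(a) FROM t").show()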