Skip to content

Commit

Permalink
add ut
Browse files Browse the repository at this point in the history
  • Loading branch information
marin-ma committed Jun 20, 2024
1 parent b4ac7d9 commit a73ca72
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
projectList
)
}

override def genCartesianProductExecTransformer(
left: SparkPlan,
right: SparkPlan,
Expand Down Expand Up @@ -861,13 +862,14 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG),
Sig[TransformKeys](TRANSFORM_KEYS),
Sig[TransformValues](TRANSFORM_VALUES)
Sig[TransformValues](TRANSFORM_VALUES),
Sig[VeloxDummyExpression](DummyExpression.VELOX_DUMMY_EXPRESSION)
)
}

override def genInjectedFunctions()
: Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
UDFResolver.getFunctionSignatures
UDFResolver.getFunctionSignatures ++ DummyExpression.getInjectedFunctions
}

override def rewriteSpillPath(path: String): String = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.expression

import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.DataType

abstract class DummyExpression(child: Expression) extends UnaryExpression with Serializable {
private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
defineCodeGen(ctx, ev, c => c)

override def dataType: DataType = child.dataType

override def eval(input: InternalRow): Any = {
assert(input.numFields == 1, "The input row of DummyExpression should have only 1 field.")
accessor(input, 0)
}
}

case class SparkDummyExpression(child: Expression) extends DummyExpression(child) {
override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

case class VeloxDummyExpression(child: Expression)
extends DummyExpression(child)
with Transformable {
override def getTransformer(
childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer = {
if (childrenTransformers.size != children.size) {
throw new IllegalStateException(
this.getClass.getSimpleName +
": getTransformer called before children transformer initialized.")
}

GenericExpressionTransformer(DummyExpression.VELOX_DUMMY_EXPRESSION, childrenTransformers, this)
}

override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

object DummyExpression {
val VELOX_DUMMY_EXPRESSION = "velox_dummy_expression"

val SPARK_DUMMY_EXPRESSION = "spark_dummy_expression"

def getInjectedFunctions: Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
Seq(
(
new FunctionIdentifier(SPARK_DUMMY_EXPRESSION),
new ExpressionInfo(classOf[SparkDummyExpression].getName, SPARK_DUMMY_EXPRESSION),
(e: Seq[Expression]) => SparkDummyExpression(e.head)),
(
new FunctionIdentifier(VELOX_DUMMY_EXPRESSION),
new ExpressionInfo(classOf[VeloxDummyExpression].getName, VELOX_DUMMY_EXPRESSION),
(e: Seq[Expression]) => VeloxDummyExpression(e.head))
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
Expand Down Expand Up @@ -94,7 +94,8 @@ case class UDFExpression(
dataType: DataType,
nullable: Boolean,
children: Seq[Expression])
extends Transformable {
extends Unevaluable
with Transformable {
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression = {
this.copy(children = newChildren)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.gluten.execution
import org.apache.gluten.GlutenConfig
import org.apache.gluten.datasource.ArrowCSVFileFormat
import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec
import org.apache.gluten.expression.DummyExpression
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.SparkConf
Expand Down Expand Up @@ -66,14 +67,20 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla

test("select_part_column") {
val df = runQueryAndCompare("select l_shipdate, l_orderkey from lineitem limit 1") {
df => { assert(df.schema.fields.length == 2) }
df =>
{
assert(df.schema.fields.length == 2)
}
}
checkLengthAndPlan(df, 1)
}

test("select_as") {
val df = runQueryAndCompare("select l_shipdate as my_col from lineitem limit 1") {
df => { assert(df.schema.fieldNames(0).equals("my_col")) }
df =>
{
assert(df.schema.fieldNames(0).equals("my_col"))
}
}
checkLengthAndPlan(df, 1)
}
Expand Down Expand Up @@ -1074,6 +1081,13 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
// No ProjectExecTransformer is introduced.
checkSparkOperatorChainMatch[GenerateExecTransformer, FilterExecTransformer]
}

// Test pre-project operator fallback.
runQueryAndCompare(s"""
|SELECT $func(${DummyExpression.VELOX_DUMMY_EXPRESSION}(a)) from t2;
|""".stripMargin) {
checkSparkOperatorMatch[GenerateExec]
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.HiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

trait Transformable extends Unevaluable {
trait Transformable {
def getTransformer(childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,11 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode]
private def getTransformHintBack(
origin: SparkPlan,
rewrittenPlan: SparkPlan): Option[TransformHint] = {
// The rewritten plan may contain more nodes than origin, here use the node name to get it back
val target = rewrittenPlan.collect {
case p if p.nodeName == origin.nodeName => p
// Get possible fallback hint from all nodes before RewrittenNodeWall.
rewrittenPlan.map(TransformHints.getHintOption).find(_.isDefined) match {
case Some(hint) => hint
case None => None
}
assert(target.size == 1)
TransformHints.getHintOption(target.head)
}

private def applyRewriteRules(origin: SparkPlan): (SparkPlan, Option[String]) = {
Expand Down

0 comments on commit a73ca72

Please sign in to comment.