Commit 2496f2a: fix generator fallback

marin-ma committed Jun 24, 2024
1 parent 7a4a07f

Showing 6 changed files with 129 additions and 18 deletions.
@@ -861,7 +861,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
       Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG),
       Sig[TransformKeys](TRANSFORM_KEYS),
-      Sig[TransformValues](TRANSFORM_VALUES)
+      Sig[TransformValues](TRANSFORM_VALUES),
+      // For test purposes only.
+      Sig[VeloxDummyExpression](VeloxDummyExpression.VELOX_DUMMY_EXPRESSION)
     )
   }

@@ -0,0 +1,77 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.gluten.expression

import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.DataType

abstract class DummyExpression(child: Expression) extends UnaryExpression with Serializable {
  private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable)

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, c => c)

  override def dataType: DataType = child.dataType

  override def eval(input: InternalRow): Any = {
    assert(input.numFields == 1, "The input row of DummyExpression should have only 1 field.")
    accessor(input, 0)
  }
}

// Can be used as a wrapper to force the wrapped expression to fall back, mocking the fallback
// behavior of a supported expression in Gluten that fails native validation.
case class VeloxDummyExpression(child: Expression)
  extends DummyExpression(child)
  with Transformable {
  override def getTransformer(
      childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer = {
    if (childrenTransformers.size != children.size) {
      throw new IllegalStateException(
        this.getClass.getSimpleName +
          ": getTransformer called before children transformers are initialized.")
    }

    GenericExpressionTransformer(
      VeloxDummyExpression.VELOX_DUMMY_EXPRESSION,
      childrenTransformers,
      this)
  }

  override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
}

object VeloxDummyExpression {
  val VELOX_DUMMY_EXPRESSION = "velox_dummy_expression"

  private val identifier = new FunctionIdentifier(VELOX_DUMMY_EXPRESSION)

  def registerFunctions(registry: FunctionRegistry): Unit = {
    registry.registerFunction(
      identifier,
      new ExpressionInfo(classOf[VeloxDummyExpression].getName, VELOX_DUMMY_EXPRESSION),
      (e: Seq[Expression]) => VeloxDummyExpression(e.head)
    )
  }

  def unregisterFunctions(registry: FunctionRegistry): Unit = {
    registry.dropFunction(identifier)
  }
}
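
For orientation, here is a minimal usage sketch of the new helper. This is an editorial sketch under assumptions, not part of the commit: it presumes a local SparkSession with the Gluten/Velox plugin enabled and a hypothetical temp view named t. Because velox_dummy_expression is registered as a supported signature but maps to a function name Velox does not provide, native validation fails for the wrapped projection and that part of the plan falls back to vanilla Spark:

    import org.apache.gluten.expression.VeloxDummyExpression
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    VeloxDummyExpression.registerFunctions(spark.sessionState.functionRegistry)

    spark.range(1).selectExpr("array(1, 2, 3) as a").createOrReplaceTempView("t")
    // The wrapped child projection falls back, while explode itself can still be
    // offloaded as GenerateExecTransformer: the scenario this commit fixes.
    spark.sql("SELECT explode(velox_dummy_expression(a)) FROM t").show()

    VeloxDummyExpression.unregisterFunctions(spark.sessionState.functionRegistry)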
@@ -27,7 +27,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo, Unevaluable}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -94,7 +94,8 @@ case class UDFExpression(
     dataType: DataType,
     nullable: Boolean,
     children: Seq[Expression])
-  extends Transformable {
+  extends Unevaluable
+  with Transformable {
   override protected def withNewChildrenInternal(
       newChildren: IndexedSeq[Expression]): Expression = {
     this.copy(children = newChildren)
@@ -19,6 +19,7 @@ package org.apache.gluten.execution
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.datasource.ArrowCSVFileFormat
 import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec
+import org.apache.gluten.expression.VeloxDummyExpression
 import org.apache.gluten.sql.shims.SparkShimLoader

 import org.apache.spark.SparkConf
@@ -45,6 +46,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
   override def beforeAll(): Unit = {
     super.beforeAll()
     createTPCHNotNullTables()
+    VeloxDummyExpression.registerFunctions(spark.sessionState.functionRegistry)
+  }
+
+  override def afterAll(): Unit = {
+    VeloxDummyExpression.unregisterFunctions(spark.sessionState.functionRegistry)
+    super.afterAll()
   }

   override protected def sparkConf: SparkConf = {
@@ -66,14 +73,20 @@

test("select_part_column") {
val df = runQueryAndCompare("select l_shipdate, l_orderkey from lineitem limit 1") {
df => { assert(df.schema.fields.length == 2) }
df =>
{
assert(df.schema.fields.length == 2)
}
}
checkLengthAndPlan(df, 1)
}

test("select_as") {
val df = runQueryAndCompare("select l_shipdate as my_col from lineitem limit 1") {
df => { assert(df.schema.fieldNames(0).equals("my_col")) }
df =>
{
assert(df.schema.fieldNames(0).equals("my_col"))
}
}
checkLengthAndPlan(df, 1)
}
@@ -1074,6 +1087,13 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
       // No ProjectExecTransformer is introduced.
       checkSparkOperatorChainMatch[GenerateExecTransformer, FilterExecTransformer]
     }
+
+    runQueryAndCompare(
+      s"""
+         |SELECT $func(${VeloxDummyExpression.VELOX_DUMMY_EXPRESSION}(a)) from t2;
+         |""".stripMargin) {
+      checkGlutenOperatorMatch[GenerateExecTransformer]
+    }
   }
 }
35 changes: 23 additions & 12 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc

@@ -26,6 +26,7 @@
#include "utils/ConfigExtractor.h"

#include "config/GlutenConfig.h"
#include "operators/plannodes/RowVectorStream.h"

namespace gluten {
namespace {
@@ -710,16 +711,24 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
 namespace {

 void extractUnnestFieldExpr(
-    std::shared_ptr<const core::ProjectNode> projNode,
+    std::shared_ptr<const core::PlanNode> child,
     int32_t index,
     std::vector<core::FieldAccessTypedExprPtr>& unnestFields) {
-  auto name = projNode->names()[index];
-  auto expr = projNode->projections()[index];
-  auto type = expr->type();
-
-  auto unnestFieldExpr = std::make_shared<core::FieldAccessTypedExpr>(type, name);
-  VELOX_CHECK_NOT_NULL(unnestFieldExpr, "the key in the Unnest operator only supports a field");
-  unnestFields.emplace_back(unnestFieldExpr);
+  if (auto projNode = std::dynamic_pointer_cast<const core::ProjectNode>(child)) {
+    auto name = projNode->names()[index];
+    auto expr = projNode->projections()[index];
+    auto type = expr->type();
+
+    auto unnestFieldExpr = std::make_shared<core::FieldAccessTypedExpr>(type, name);
+    VELOX_CHECK_NOT_NULL(unnestFieldExpr, "the key in the Unnest operator only supports a field");
+    unnestFields.emplace_back(unnestFieldExpr);
+  } else {
+    // Project fallback: the child is not a ProjectNode, so resolve the unnest
+    // field directly from the child's output type.
+    auto name = child->outputType()->names()[index];
+    auto field = child->outputType()->childAt(index);
+    auto unnestFieldExpr = std::make_shared<core::FieldAccessTypedExpr>(field, name);
+    unnestFields.emplace_back(unnestFieldExpr);
+  }
 }

 } // namespace
@@ -752,9 +761,10 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
   SubstraitParser::configSetInOptimization(generateRel.advanced_extension(), "injectedProject=");

   if (injectedProject) {
-    auto projNode = std::dynamic_pointer_cast<const core::ProjectNode>(childNode);
+    // Child should be either ProjectNode or ValueStreamNode in case of project fallback.
     VELOX_CHECK(
-        projNode != nullptr && projNode->names().size() > requiredChildOutput.size(),
+        std::dynamic_pointer_cast<const core::ProjectNode>(childNode) != nullptr ||
+            std::dynamic_pointer_cast<const ValueStreamNode>(childNode) != nullptr,
         "injectedProject is true, but the Project is missing or does not have the corresponding projection field")

     bool isStack = generateRel.has_advanced_extension() &&
@@ -768,7 +778,8 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
       // +- Project [fake_column#128, [1,2,3] AS _pre_0#129]
       // +- RewrittenNodeWall Scan OneRowRelation[fake_column#128]
       // The last projection column in GenerateRel's child (Project) is the column we need to unnest.
-      extractUnnestFieldExpr(projNode, projNode->projections().size() - 1, unnest);
+      auto index = childNode->outputType()->size() - 1;
+      extractUnnestFieldExpr(childNode, index, unnest);
     } else {
       // For stack function, e.g. stack(2, 1, 2, 3), a sample
       // input substrait plan is like the following:
@@ -782,10 +793,10 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
       auto generatorFunc = generator.scalar_function();
       auto numRows = SubstraitParser::getLiteralValue<int32_t>(generatorFunc.arguments(0).value().literal());
       auto numFields = static_cast<int32_t>(std::ceil((generatorFunc.arguments_size() - 1.0) / numRows));
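       // Worked example (editorial note, not part of the original diff): for
       // stack(2, 1, 2, 3), arguments_size = 4 and numRows = 2, so
       // numFields = ceil((4 - 1.0) / 2) = 2, i.e. the generator emits two
       // output columns and the last two child output columns are unnested.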
-      auto totalProjectCount = projNode->names().size();
+      auto totalProjectCount = childNode->outputType()->size();

       for (auto i = totalProjectCount - numFields; i < totalProjectCount; ++i) {
-        extractUnnestFieldExpr(projNode, i, unnest);
+        extractUnnestFieldExpr(childNode, i, unnest);
       }
     }
   } else {
@@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.HiveUDFTransformer
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

-trait Transformable extends Unevaluable {
+trait Transformable {
   def getTransformer(childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer
 }

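Why decouple Transformable from Unevaluable? The reasoning is inferred from the rest of the commit: VeloxDummyExpression must be both transformable, so Gluten attempts to offload it, and evaluable, so vanilla Spark can still execute it once native validation fails and the plan falls back; UDFExpression, which must never run on the JVM, re-adds Unevaluable explicitly. A minimal, self-contained illustration of what Unevaluable forbids, using a hypothetical expression name:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Unevaluable}
    import org.apache.spark.sql.types.{DataType, IntegerType}

    // Hypothetical expression that may only run natively: mixing in Unevaluable
    // makes any JVM-side eval() call throw, which a fallback path cannot tolerate.
    case class NativeOnly() extends LeafExpression with Unevaluable {
      override def dataType: DataType = IntegerType
      override def nullable: Boolean = false
    }

    // NativeOnly().eval(InternalRow.empty) // throws: cannot evaluate expression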