Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VL] Fix: add validate check for Generate #3682

Merged
merged 11 commits into from
Nov 14, 2023
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ package io.glutenproject.execution

import org.apache.spark.SparkConf
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.functions.{avg, col}
import org.apache.spark.sql.execution.{GenerateExec, RDDScanExec}
import org.apache.spark.sql.functions.{avg, col, udf}
import org.apache.spark.sql.types.{DecimalType, StringType, StructField, StructType}

import scala.collection.JavaConverters
Expand Down Expand Up @@ -545,6 +545,45 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
runQueryAndCompare("SELECT c1, explode(array(c2)) FROM t") {
checkOperatorMatch[GenerateExecTransformer]
}

runQueryAndCompare("SELECT c1, explode(c3) FROM (SELECT c1, array(c2) as c3 FROM t)") {
checkOperatorMatch[GenerateExecTransformer]
}
}
}

test("Add the missing Generate validation check") {
jackylee-ch marked this conversation as resolved.
Show resolved Hide resolved
withTable("t") {
spark
.range(10)
.selectExpr("id as c1", "id as c2")
.write
.format("parquet")
.saveAsTable("t")

// Add a simple UDF to generate the unsupported case
val intToArrayFunc = udf((s: Int) => Array(s))
spark.udf.register("intToArray", intToArrayFunc)

// Testing unsupported case
runQueryAndCompare("SELECT explode(intToArray(c1)) from t;") {
df =>
{
getExecutedPlan(df).exists(plan => plan.exists(_.isInstanceOf[GenerateExec]))
}
}

// Testing unsupported case in case when
JkSelf marked this conversation as resolved.
Show resolved Hide resolved
runQueryAndCompare(
"""
|SELECT explode(case when size(intToArray(c1)) > 0
|then array(c1) else array(c2) end) from t;
|""".stripMargin) {
df =>
{
getExecutedPlan(df).exists(plan => plan.exists(_.isInstanceOf[GenerateExec]))
}
}
}
}
}
43 changes: 42 additions & 1 deletion cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ static const std::unordered_set<std::string> kBlackList = {
"split_part",
"factorial",
"concat_ws",
"from_json",
"rand",
"json_array_length",
"from_unixtime",
Expand Down Expand Up @@ -330,6 +331,15 @@ bool SubstraitToVeloxPlanValidator::validateCast(
return true;
}

bool SubstraitToVeloxPlanValidator::validateIfThen(
const ::substrait::Expression_IfThen& ifThen,
const RowTypePtr& inputType) {
for (const auto& ifThen : ifThen.ifs()) {
return validateExpression(ifThen.if_(), inputType) && validateExpression(ifThen.then(), inputType);
}
return true;
}

bool SubstraitToVeloxPlanValidator::validateExpression(
const ::substrait::Expression& expression,
const RowTypePtr& inputType) {
Expand All @@ -341,6 +351,8 @@ bool SubstraitToVeloxPlanValidator::validateExpression(
return validateLiteral(expression.literal(), inputType);
case ::substrait::Expression::RexTypeCase::kCast:
return validateCast(expression.cast(), inputType);
case ::substrait::Expression::RexTypeCase::kIfThen:
return validateIfThen(expression.if_then(), inputType);
default:
return true;
}
Expand Down Expand Up @@ -394,7 +406,36 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FetchRel& fetchR
}

bool SubstraitToVeloxPlanValidator::validate(const ::substrait::GenerateRel& generateRel) {
// TODO(yuan): add check
if (generateRel.has_input() && !validate(generateRel.input())) {
logValidateMsg("native validation failed due to: input validation fails in GenerateRel.");
return false;
}

// Get and validate the input types from extension.
if (!generateRel.has_advanced_extension()) {
logValidateMsg("native validation failed due to: Input types are expected in GenerateRel.");
return false;
}
const auto& extension = generateRel.advanced_extension();
std::vector<TypePtr> types;
if (!validateInputTypes(extension, types)) {
logValidateMsg("native validation failed due to: Validation failed for input types in GenerateRel.");
return false;
}

int32_t inputPlanNodeId = 0;
// Create the fake input names to be used in row type.
std::vector<std::string> names;
names.reserve(types.size());
for (uint32_t colIdx = 0; colIdx < types.size(); colIdx++) {
names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx));
}
auto rowType = std::make_shared<RowType>(std::move(names), std::move(types));

if (generateRel.has_generator() && !validateExpression(generateRel.generator(), rowType)) {
logValidateMsg("native validation failed due to: input validation fails in GenerateRel.");
return false;
}
return true;
}

Expand Down
3 changes: 3 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class SubstraitToVeloxPlanValidator {
/// Validate Substrait literal.
bool validateLiteral(const ::substrait::Expression_Literal& literal, const RowTypePtr& inputType);

/// Validate Substrait if-then expression.
bool validateIfThen(const ::substrait::Expression_IfThen& ifThen, const RowTypePtr& inputType);

/// Add necessary log for fallback
void logValidateMsg(const std::string& log) {
validateLog_.emplace_back(log);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ case class GenerateExecTransformer(
readRel
}
val projRel =
if (
BackendsApiManager.getSettings.insertPostProjectForGenerate() && needsProjection(generator)
) {
if (BackendsApiManager.getSettings.insertPostProjectForGenerate()) {
// need to insert one projection node for velox backend
val projectExpressions = new JArrayList[ExpressionNode]()
val childOutputNodes = child.output.indices
Expand Down Expand Up @@ -160,10 +158,6 @@ case class GenerateExecTransformer(
TransformContext(child.output, output, relNode)
}

def needsProjection(generator: Generator): Boolean = {
!generator.asInstanceOf[Explode].child.isInstanceOf[AttributeReference]
}

def getRelNode(
context: SubstraitContext,
operatorId: Long,
Expand Down
Loading