[GLUTEN-1874][CH] Fixes nullable mismatch in union (apache#1901)
Fixes a nullable mismatch in union for CH: when the inputs of a UNION ALL disagree on nullability, the native pipeline could produce a header that does not match the output schema Spark declares. The Spark plan's output schema is now carried in the substrait plan (RelRoot.output_schema), and the ClickHouse parser adds a final projection to reconcile any difference.
lgbo-ustc authored Jul 14, 2023
1 parent e05462d commit d942ee1
Showing 14 changed files with 157 additions and 65 deletions.
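Background for the fix: Spark derives the nullability of each UNION ALL output column as the OR of the corresponding child columns, while the ClickHouse pipeline keeps whatever header each stream actually produced. A minimal Scala sketch of the Spark-side rule (it mirrors Catalyst's Union.output; illustrative only, not code from this commit):

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.LongType

// The union output column is nullable if ANY child column is nullable.
def unionOutput(children: Seq[Seq[Attribute]]): Seq[Attribute] =
  children.transpose.map(attrs => attrs.head.withNullability(attrs.exists(_.nullable)))

// One stream guarded by IS NOT NULL (non-nullable), the other not (nullable):
val left = Seq(AttributeReference("n_regionkey", LongType, nullable = false)())
val right = Seq(AttributeReference("n_regionkey", LongType, nullable = true)())
unionOutput(Seq(left, right)).head.nullable // true, yet CH may build a non-nullable header

When the two sides disagree like this, the header produced by the native pipeline no longer matches the schema Spark declared, which is exactly the mismatch the changes below reconcile.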
@@ -155,4 +155,5 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
GlutenConfig.GLUTEN_CONFIG_PREFIX + GlutenConfig.GLUTEN_CLICKHOUSE_BACKEND

override def shuffleSupportedCodec(): Set[String] = GLUTNE_CLICKHOUSE_SHUFFLE_SUPPORTED_CODEC
override def needOutputSchemaForPlan(): Boolean = true
}
@@ -75,7 +75,9 @@ case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
if (leftCondition == null) {
// The computing for this filter is not needed.
context.registerEmptyRelToOperator(operatorId)
-      return childCtx
+      // Since some columns' nullability will be removed after this filter, we need to update the
+      // outputAttributes of child context.
+      return TransformContext(childCtx.inputAttributes, output, childCtx.root)
}

val currRel = if (childCtx != null) {
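Why the filter's own output matters here: the ClickHouse backend strips the Nullable wrapper from columns that an IS NOT NULL predicate guarantees, so the transformer's declared output can differ from childCtx.outputAttributes even when the filter itself is elided. A hedged sketch of that derivation (the notNullColumns set is hypothetical; not code from this commit):

import org.apache.spark.sql.catalyst.expressions.{Attribute, ExprId}

// Drop nullability on attributes the filter proves non-null; keep the rest.
def filterOutput(childOutput: Seq[Attribute], notNullColumns: Set[ExprId]): Seq[Attribute] =
  childOutput.map {
    a => if (a.nullable && notNullColumns.contains(a.exprId)) a.withNullability(false) else a
  }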
@@ -1438,6 +1438,41 @@ class GlutenClickHouseTPCHParquetSuite extends GlutenClickHouseTPCHAbstractSuite
compareResultsAgainstVanillaSpark(sql5, true, { _ => }, false)
}

test("GLUTEN-1874 not null in one stream") {
val sql =
"""
|select n_regionkey from (
| select *, row_number() over (partition by n_regionkey order by is_new) as rank from(
| select n_regionkey, 0 as is_new from nation where n_regionkey is not null
| union all
| select n_regionkey, 1 as is_new from (
| select n_regionkey,
| row_number() over (partition by n_regionkey order by n_nationkey) as rn from nation
| ) t0 where rn = 1
| ) t1
|) t2 where rank = 1
""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}

test("GLUTEN-1874 not null in both streams") {
val sql =
"""
|select n_regionkey from (
| select *, row_number() over (partition by n_regionkey order by is_new) as rank from(
| select n_regionkey, 0 as is_new from nation where n_regionkey is not null
| union all
| select n_regionkey, 1 as is_new from (
| select n_regionkey,
| row_number() over (partition by n_regionkey order by n_nationkey) as rn
| from nation where n_regionkey is not null
| ) t0 where rn = 1
| ) t1
|) t2 where rank = 1
""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}

test("GLUTEN-2095: test cast(string as binary)") {
runQueryAndCompare(
"select cast(n_nationkey as binary), cast(n_comment as binary) from nation"
51 changes: 48 additions & 3 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -1,4 +1,5 @@
#include "SerializedPlanParser.h"
#include <algorithm>
#include <memory>
#include <string_view>
#include <AggregateFunctions/AggregateFunctionFactory.h>
@@ -366,7 +367,7 @@ DB::QueryPlanPtr SerializedPlanParser::parseMergeTreeTable(const substrait::Read
{
auto input_header = query->getCurrentDataStream().header;
std::erase_if(non_nullable_columns, [input_header](auto item) -> bool { return !input_header.has(item); });
-    auto* remove_null_step = addRemoveNullableStep(*query, non_nullable_columns);
+    auto * remove_null_step = addRemoveNullableStep(*query, non_nullable_columns);
if (remove_null_step)
{
steps.emplace_back(remove_null_step);
@@ -396,7 +397,7 @@ SerializedPlanParser::parsePreWhereInfo(const substrait::Expression & rel, Block
prewhere_info->need_filter = true;
prewhere_info->remove_prewhere_column = true;
auto cols = prewhere_info->prewhere_actions->getRequiredColumnsNames();
// Keep it the same as the input.
prewhere_info->prewhere_actions->removeUnusedActions(Names{filter_name}, false, true);
prewhere_info->prewhere_actions->projectInput(false);
for (const auto & name : input.getNames())
@@ -458,6 +459,51 @@ QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
expression_step->setStepDescription("Rename Output");
query_plan->addStep(std::move(expression_step));
}

// Fixes issue-1874: keep the nullability of the final header as the output schema expects.
const auto & output_schema = root_rel.root().output_schema();
if (output_schema.types_size())
{
auto original_header = query_plan->getCurrentDataStream().header;
const auto & original_cols = original_header.getColumnsWithTypeAndName();
if (static_cast<size_t>(output_schema.types_size()) != original_cols.size())
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Mismatch output schema");
}
bool need_final_project = false;
DB::ColumnsWithTypeAndName final_cols;
for (int i = 0; i < output_schema.types_size(); ++i)
{
const auto & col = original_cols[i];
auto type = TypeParser::parseType(output_schema.types(i));
// At present, we only check nullable mismatch.
// intermediate aggregate data is special, no check here.
if (type->isNullable() != col.type->isNullable() && !typeid_cast<const DB::DataTypeAggregateFunction*>(col.type.get()))
{
if (type->isNullable())
{
final_cols.emplace_back(type->createColumn(), std::make_shared<DB::DataTypeNullable>(col.type), col.name);
}
else
{
final_cols.emplace_back(type->createColumn(), DB::removeNullable(col.type), col.name);
}
need_final_project = true;
}
else
{
final_cols.push_back(col);
}
}
if (need_final_project)
{
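// Note: makeConvertingActions matches columns by position here; only the Nullable wrapper changes, names and order are preserved.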
ActionsDAGPtr final_project
= ActionsDAG::makeConvertingActions(original_cols, final_cols, ActionsDAG::MatchColumnsMode::Position);
QueryPlanStepPtr final_project_step = std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), final_project);
final_project_step->setStepDescription("Project for output schema");
query_plan->addStep(std::move(final_project_step));
}
}
return query_plan;
}
else
@@ -2591,5 +2637,4 @@ void NonNullableColumnsResolver::visitNonNullable(const substrait::Expression &
}
// else, do nothing.
}

}
@@ -22,6 +22,7 @@
import io.glutenproject.substrait.extensions.ExtensionBuilder;
import io.glutenproject.substrait.extensions.FunctionMappingNode;
import io.glutenproject.substrait.rel.RelNode;
import io.glutenproject.substrait.type.TypeNode;

import java.util.ArrayList;
import java.util.Map;
@@ -36,12 +37,29 @@ public static PlanNode makePlan(ArrayList<FunctionMappingNode> mappingNodes,
return new PlanNode(mappingNodes, relNodes, outNames);
}

public static PlanNode makePlan(ArrayList<FunctionMappingNode> mappingNodes,
ArrayList<RelNode> relNodes,
ArrayList<String> outNames,
TypeNode outputSchema,
AdvancedExtensionNode extension) {
return new PlanNode(mappingNodes, relNodes, outNames, outputSchema, extension);
}

public static PlanNode makePlan(AdvancedExtensionNode extension) {
return new PlanNode(extension);
}

-  public static PlanNode makePlan(SubstraitContext subCtx, ArrayList<RelNode> relNodes,
+  public static PlanNode makePlan(SubstraitContext subCtx,
+                                  ArrayList<RelNode> relNodes,
                                   ArrayList<String> outNames) {
+    return makePlan(subCtx, relNodes, outNames, null, null);
}

public static PlanNode makePlan(SubstraitContext subCtx,
ArrayList<RelNode> relNodes,
ArrayList<String> outNames,
TypeNode outputSchema,
AdvancedExtensionNode extension) {
if (subCtx == null) {
throw new NullPointerException("ColumnarWholestageTransformer cannot doTansform.");
}
@@ -52,6 +70,9 @@ public static PlanNode makePlan(SubstraitContext subCtx, ArrayList<RelNode> relN
ExtensionBuilder.makeFunctionMapping(entry.getKey(), entry.getValue());
mappingNodes.add(mappingNode);
}
if (extension != null || outputSchema != null) {
return makePlan(mappingNodes, relNodes, outNames, outputSchema, extension);
}
return makePlan(mappingNodes, relNodes, outNames);
}

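A hedged usage sketch of the new five-argument overload from the Scala side (substraitContext and relNodes are assumed to be in scope; the real call site is in WholeStageTransformer further below):

import com.google.common.collect.Lists
import io.glutenproject.expression.ConverterUtils
import io.glutenproject.substrait.`type`.TypeBuilder
import io.glutenproject.substrait.plan.PlanBuilder
import org.apache.spark.sql.types.LongType

val outNames = Lists.newArrayList("n_regionkey#10")
// One nullable BIGINT column; the field nullability is what the native parser checks.
val typeNodes = Lists.newArrayList(ConverterUtils.getTypeNode(LongType, true))
val outputSchema = TypeBuilder.makeStruct(false, typeNodes)
// A non-null outputSchema (the extension may stay null) selects the new PlanNode constructor.
val planNode = PlanBuilder.makePlan(substraitContext, relNodes, outNames, outputSchema, null)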
@@ -20,6 +20,7 @@
import io.glutenproject.substrait.extensions.AdvancedExtensionNode;
import io.glutenproject.substrait.extensions.FunctionMappingNode;
import io.glutenproject.substrait.rel.RelNode;
import io.glutenproject.substrait.type.TypeNode;
import io.substrait.proto.Plan;
import io.substrait.proto.PlanRel;
import io.substrait.proto.RelRoot;
@@ -32,6 +33,7 @@ public class PlanNode implements Serializable {
private final ArrayList<RelNode> relNodes = new ArrayList<>();
private final ArrayList<String> outNames = new ArrayList<>();

private TypeNode outputSchema = null;
private AdvancedExtensionNode extension = null;

PlanNode(ArrayList<FunctionMappingNode> mappingNodes,
@@ -42,6 +44,18 @@ public class PlanNode implements Serializable {
this.outNames.addAll(outNames);
}

PlanNode(ArrayList<FunctionMappingNode> mappingNodes,
ArrayList<RelNode> relNodes,
ArrayList<String> outNames,
TypeNode outputSchema,
AdvancedExtensionNode extension) {
this.mappingNodes.addAll(mappingNodes);
this.relNodes.addAll(relNodes);
this.outNames.addAll(outNames);
this.outputSchema = outputSchema;
this.extension = extension;
}

PlanNode(AdvancedExtensionNode extension) {
this.extension = extension;
}
@@ -61,6 +75,9 @@ public Plan toProtobuf() {
for (String name : outNames) {
relRootBuilder.addNames(name);
}
if (outputSchema != null) {
relRootBuilder.setOutputSchema(outputSchema.toProtobuf().getStruct());
}
planRelBuilder.setRoot(relRootBuilder.build());

planBuilder.addRelations(planRelBuilder.build());
@@ -406,6 +406,7 @@ message RelRoot {
Rel input = 1;
// Field names in depth-first order
repeated string names = 2;
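// Expected output schema; used by the native parser to check and fix up the nullability of the final header (GLUTEN-1874)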
Type.Struct output_schema = 3;
}

// A relation (used internally in a plan)
@@ -87,4 +87,6 @@ trait BackendSettingsApi {
def rescaleDecimalIntegralExpression(): Boolean = false

def shuffleSupportedCodec(): Set[String]

def needOutputSchemaForPlan(): Boolean = false
}
@@ -31,7 +31,7 @@ import io.glutenproject.substrait.expression.ExpressionNode
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.plan.PlanBuilder
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
-import io.glutenproject.utils.BindReferencesUtil

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
@@ -309,8 +309,7 @@ case class ProjectExecTransformer(projectList: Seq[NamedExpression],
}
assert(currRel != null, "Project Rel should be valid")

-    val outputAttrs = BindReferencesUtil.bindReferencesWithNullable(output, inputAttributes)
-    TransformContext(inputAttributes, outputAttrs, currRel)
+    TransformContext(inputAttributes, output, currRel)
}

override def output: Seq[Attribute] = projectList.map(_.toAttribute)
@@ -30,7 +30,6 @@ import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.plan.PlanBuilder
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
-import io.glutenproject.utils.BindReferencesUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -269,8 +268,7 @@ case class ExpandExecTransformer(projections: Seq[Seq[Expression]],
child.output)
}
assert(currRel != null, "Expand Rel should be valid")
-    val outputAttrs = BindReferencesUtil.bindReferencesWithNullable(output, inputAttributes)
-    TransformContext(inputAttributes, outputAttrs, currRel)
+    TransformContext(inputAttributes, output, currRel)
}

protected override def doExecute(): RDD[InternalRow] =
@@ -30,7 +30,6 @@ import io.glutenproject.extension.ValidationResult
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.plan.PlanBuilder
-import io.glutenproject.utils.BindReferencesUtil
import io.substrait.proto.SortField
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -296,8 +295,7 @@ case class SortExecTransformer(sortOrder: Seq[SortOrder],
child.output)
}
assert(currRel != null, "Sort Rel should be valid")
-    val outputAttrs = BindReferencesUtil.bindReferencesWithNullable(output, inputAttributes)
-    TransformContext(inputAttributes, outputAttrs, currRel)
+    TransformContext(inputAttributes, output, currRel)
}

override def doExecuteColumnar(): RDD[ColumnarBatch] = {
@@ -18,15 +18,19 @@
package io.glutenproject.execution

import com.google.common.collect.Lists

import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.expression._
import io.glutenproject.extension.GlutenPlan
import io.glutenproject.metrics.{MetricsUpdater, NoopMetricsUpdater}
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.plan.{PlanBuilder, PlanNode}
import io.glutenproject.substrait.rel.RelNode
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.utils.SubstraitPlanPrinterUtil

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
@@ -164,12 +168,26 @@ case class WholeStageTransformer(child: SparkPlan)(val transformStageId: Int)
throw new NullPointerException(s"ColumnarWholestageTransformer can't doTansform on $child")
}
val outNames = new java.util.ArrayList[String]()
-    for (attr <- childCtx.outputAttributes) {
-      outNames.add(ConverterUtils.genColumnNameWithExprId(attr))
-    }
-    val planNode =
-      PlanBuilder.makePlan(substraitContext, Lists.newArrayList(childCtx.root), outNames)
+    val planNode = if (BackendsApiManager.getSettings.needOutputSchemaForPlan()) {
+      val outputTypeNodeList = new java.util.ArrayList[TypeNode]()
+      for (attr <- childCtx.outputAttributes) {
+        outNames.add(ConverterUtils.genColumnNameWithExprId(attr))
+        outputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
+      }
+
+      // Fixes issue-1874
+      val outputSchema = TypeBuilder.makeStruct(false, outputTypeNodeList)
+      PlanBuilder.makePlan(substraitContext,
+        Lists.newArrayList(childCtx.root),
+        outNames,
+        outputSchema,
+        null)
+    } else {
+      for (attr <- childCtx.outputAttributes) {
+        outNames.add(ConverterUtils.genColumnNameWithExprId(attr))
+      }
+      PlanBuilder.makePlan(substraitContext, Lists.newArrayList(childCtx.root), outNames)
+    }
planJson = SubstraitPlanPrinterUtil.substraitPlanToJson(planNode.toProtobuf)

WholestageTransformContext(
@@ -31,7 +31,6 @@ import io.glutenproject.substrait.expression.{ExpressionNode, WindowFunctionNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.plan.PlanBuilder
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
-import io.glutenproject.utils.BindReferencesUtil
import io.substrait.proto.SortField

import org.apache.spark.rdd.RDD
@@ -232,8 +231,7 @@ case class WindowExecTransformer(windowExpression: Seq[NamedExpression],
child.output)
}
assert(currRel != null, "Window Rel should be valid")
-    val outputAttrs = BindReferencesUtil.bindReferencesWithNullable(output, inputAttributes)
-    TransformContext(inputAttributes, outputAttrs, currRel)
+    TransformContext(inputAttributes, output, currRel)
}

override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {