Skip to content

Commit

Permalink
enable cartesion product
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Jul 19, 2024
1 parent 2e908c0 commit afd5f5b
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -299,4 +299,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging {

override def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = true

override def supportCartesianProductExec(): Boolean = true

}
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,15 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
left: SparkPlan,
right: SparkPlan,
condition: Option[Expression]): CartesianProductExecTransformer =
throw new GlutenNotSupportException(
"CartesianProductExecTransformer is not supported in ch backend.")
if (!condition.isEmpty) {
throw new GlutenNotSupportException(
"CartesianProductExecTransformer with condition is not supported in ch backend.")
} else {
CartesianProductExecTransformer(
ColumnarCartesianProductBridge(left),
ColumnarCartesianProductBridge(right),
condition)
}

override def genBroadcastNestedLoopJoinExecTransformer(
left: SparkPlan,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
// for ch
val joinParametersStr = new StringBuffer("JoinParameters:")
joinParametersStr
.append("isBHJ=")
.append(1)
.append("\n")
.append("buildHashTableId=")
.append(buildBroadcastTableId)
.append("\n")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2769,5 +2769,17 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr

spark.sql("drop table tb_date")
}

test("test CartesianProductExec") {
withSQLConf(("spark.sql.autoBroadcastJoinThreshold", "-1")) {
val sql = """
|select t1.n_regionkey, t2.n_regionkey from
|(select n_regionkey from nation) t1
|cross join
|(select n_regionkey from nation) t2
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}
}
}
// scalastyle:on line.size.limit
70 changes: 46 additions & 24 deletions cpp-ch/local-engine/Parser/CrossRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ using namespace DB;

namespace local_engine
{

std::shared_ptr<DB::TableJoin> createCrossTableJoin(substrait::CrossRel_JoinType join_type)
{
auto & global_context = SerializedPlanParser::global_context;
Expand Down Expand Up @@ -148,16 +147,19 @@ DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB:
optimization_info.ParseFromString(join.advanced_extension().optimization().value());
auto join_opt_info = JoinOptimizationInfo::parse(optimization_info.value());
const auto & storage_join_key = join_opt_info.storage_join_key;
auto storage_join = BroadCastJoinBuilder::getJoin(storage_join_key) ;
renamePlanColumns(*left, *right, *storage_join);
auto storage_join = join_opt_info.is_broadcast ? BroadCastJoinBuilder::getJoin(storage_join_key) : nullptr;
if (storage_join)
renamePlanColumns(*left, *right, *storage_join);
auto table_join = createCrossTableJoin(join.type());
DB::Block right_header_before_convert_step = right->getCurrentDataStream().header;
addConvertStep(*table_join, *left, *right);

// Add a check to find error easily.
if(!blocksHaveEqualStructure(right_header_before_convert_step, right->getCurrentDataStream().header))
if (!blocksHaveEqualStructure(right_header_before_convert_step, right->getCurrentDataStream().header))
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}",
throw DB::Exception(
DB::ErrorCodes::LOGICAL_ERROR,
"For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}",
left->getCurrentDataStream().header.dumpNames(),
right_header_before_convert_step.dumpNames(),
right->getCurrentDataStream().header.dumpNames());
Expand All @@ -173,28 +175,48 @@ DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB:
auto right_header = right->getCurrentDataStream().header;

QueryPlanPtr query_plan;
table_join->addDisjunct();
auto broadcast_hash_join = storage_join->getJoinLocked(table_join, context);
// table_join->resetKeys();
QueryPlanStepPtr join_step = std::make_unique<FilledJoinStep>(left->getCurrentDataStream(), broadcast_hash_join, 8192);

join_step->setStepDescription("STORAGE_JOIN");
steps.emplace_back(join_step.get());
left->addStep(std::move(join_step));
query_plan = std::move(left);
/// hold right plan for profile
extra_plan_holder.emplace_back(std::move(right));

addPostFilter(*query_plan, join);
Names cols;
for (auto after_join_name : after_join_names)
if (storage_join)
{
if (BlockUtil::VIRTUAL_ROW_COUNT_COLUMN == after_join_name)
continue;
/// FIXME: There is mistake in HashJoin::needUsedFlagsForPerRightTableRow which returns true when
/// join clauses is empty. But in fact there should not be any join clause in cross join.
table_join->addDisjunct();

auto broadcast_hash_join = storage_join->getJoinLocked(table_join, context);
// table_join->resetKeys();
QueryPlanStepPtr join_step = std::make_unique<FilledJoinStep>(left->getCurrentDataStream(), broadcast_hash_join, 8192);

join_step->setStepDescription("STORAGE_JOIN");
steps.emplace_back(join_step.get());
left->addStep(std::move(join_step));
query_plan = std::move(left);
/// hold right plan for profile
extra_plan_holder.emplace_back(std::move(right));

addPostFilter(*query_plan, join);
Names cols;
for (auto after_join_name : after_join_names)
{
if (BlockUtil::VIRTUAL_ROW_COUNT_COLUMN == after_join_name)
continue;

cols.emplace_back(after_join_name);
cols.emplace_back(after_join_name);
}
JoinUtil::reorderJoinOutput(*query_plan, cols);
}
else
{
JoinPtr hash_join = std::make_shared<HashJoin>(table_join, right->getCurrentDataStream().header.cloneEmpty());
QueryPlanStepPtr join_step = std::make_unique<DB::JoinStep>(left->getCurrentDataStream(), right->getCurrentDataStream(), hash_join, 8192, 1, false);
join_step->setStepDescription("CROSS_JOIN");
steps.emplace_back(join_step.get());
std::vector<QueryPlanPtr> plans;
plans.emplace_back(std::move(left));
plans.emplace_back(std::move(right));

query_plan = std::make_unique<QueryPlan>();
query_plan->unitePlans(std::move(join_step), {std::move(plans)});
JoinUtil::reorderJoinOutput(*query_plan, after_join_names);
}
JoinUtil::reorderJoinOutput(*query_plan, cols);

return query_plan;
}
Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Parser/JoinRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,6 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q

QueryPlanPtr query_plan;

/// Support only one join clause.
table_join->addDisjunct();
/// some examples to explain when the post_join_filter is not empty
/// - on t1.key = t2.key and t1.v1 > 1 and t2.v1 > 1, 't1.v1> 1' is in the post filter. but 't2.v1 > 1'
/// will be pushed down into right table by spark and is not in the post filter. 't1.key = t2.key ' is
Expand Down Expand Up @@ -430,6 +428,8 @@ void JoinRelParser::collectJoinKeys(
{
if (!join_rel.has_expression())
return;
/// Support only one join clause.
table_join.addDisjunct();
const auto & expr = join_rel.expression();
auto & join_clause = table_join.getClauses().back();
std::list<const const substrait::Expression *> expressions_stack;
Expand Down

0 comments on commit afd5f5b

Please sign in to comment.