Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-6544][CH] Support existence join #6548

Merged
merged 2 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ private static native long nativeBuild(
String joinKeys,
int joinType,
boolean hasMixedFiltCondition,
boolean isExistenceJoin,
byte[] namedStruct);

private StorageJoinBuilder() {}
Expand Down Expand Up @@ -89,6 +90,7 @@ public static long build(
joinKey,
joinType,
broadCastContext.hasMixedFiltCondition(),
broadCastContext.isExistenceJoin(),
toNameStruct(output).toByteArray());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftSemi}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.BuildSideRelation
Expand All @@ -44,6 +45,13 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
condition
) {

private val finalJoinType = joinType match {
case ExistenceJoin(_) =>
LeftSemi
case _ =>
joinType
}

override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
val streamedRDD = getColumnarInputRDDs(streamedPlan)
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
Expand All @@ -57,7 +65,13 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
}
val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
val context =
BroadCastHashJoinContext(Seq.empty, joinType, false, buildPlan.output, buildBroadcastTableId)
BroadCastHashJoinContext(
Seq.empty,
finalJoinType,
false,
joinType.isInstanceOf[ExistenceJoin],
buildPlan.output,
buildBroadcastTableId)
val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context)
streamedRDD :+ broadcastRDD
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

import io.substrait.proto.JoinRel

case class CHShuffledHashJoinExecTransformer(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
Expand Down Expand Up @@ -82,6 +84,7 @@ case class BroadCastHashJoinContext(
buildSideJoinKeys: Seq[Expression],
joinType: JoinType,
hasMixedFiltCondition: Boolean,
isExistenceJoin: Boolean,
buildSideStructure: Seq[Attribute],
buildHashTableId: String)

Expand Down Expand Up @@ -112,7 +115,7 @@ case class CHBroadcastHashJoinExecTransformer(
override protected def doValidateInternal(): ValidationResult = {
val shouldFallback =
CHJoinValidateUtil.shouldFallback(
BroadcastHashJoinStrategy(joinType),
BroadcastHashJoinStrategy(finalJoinType),
left.outputSet,
right.outputSet,
condition)
Expand Down Expand Up @@ -141,8 +144,9 @@ case class CHBroadcastHashJoinExecTransformer(
val context =
BroadCastHashJoinContext(
buildKeyExprs,
joinType,
finalJoinType,
isMixedCondition(condition),
joinType.isInstanceOf[ExistenceJoin],
buildPlan.output,
buildHashTableId)
val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context)
Expand All @@ -161,4 +165,33 @@ case class CHBroadcastHashJoinExecTransformer(
}
res
}

// ExistenceJoin is introduced in #SPARK-14781. It returns all rows from the left table with
// a new column to indecate whether the row is matched in the right table.
// Indeed, the ExistenceJoin is transformed into left any join in CH.
// We don't have left any join in substrait, so use left semi join instead.
// and isExistenceJoin is set to true to indicate that it is an existence join.
private val finalJoinType = joinType match {
case ExistenceJoin(_) =>
LeftSemi
case _ =>
joinType
}
override protected lazy val substraitJoinType: JoinRel.JoinType = {
joinType match {
case _: InnerLike =>
JoinRel.JoinType.JOIN_TYPE_INNER
case FullOuter =>
JoinRel.JoinType.JOIN_TYPE_OUTER
case LeftOuter | RightOuter =>
JoinRel.JoinType.JOIN_TYPE_LEFT
case LeftSemi | ExistenceJoin(_) =>
JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
case LeftAnti =>
JoinRel.JoinType.JOIN_TYPE_ANTI
case _ =>
// TODO: Support cross join with Cross Rel
JoinRel.JoinType.UNRECOGNIZED
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ object CHJoinValidateUtil extends Logging {
var shouldFallback = false
val joinType = joinStrategy.joinType
if (joinType.toString.contains("ExistenceJoin")) {
logError("Fallback for join type ExistenceJoin")
return true
}
if (joinType.sql.contains("INNER")) {
Expand All @@ -78,6 +79,9 @@ object CHJoinValidateUtil extends Logging {
case _ => false
}
}
if (shouldFallback) {
logError(s"Fallback for join type $joinType")
}
shouldFallback
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,14 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
Seq("q" + "%d".format(queryNum))
}
val noFallBack = queryNum match {
case i if i == 10 || i == 16 || i == 35 || i == 45 || i == 94 =>
// Q10 BroadcastHashJoin, ExistenceJoin
// Q16 ShuffledHashJoin, NOT condition
// Q35 BroadcastHashJoin, ExistenceJoin
// Q45 BroadcastHashJoin, ExistenceJoin
case i if !isAqe && (i == 10 || i == 16 || i == 35 || i == 94) =>
// q10 smj + existence join
// q16 smj + left semi + not condition
// q35 smj + existence join
// Q94 BroadcastHashJoin, LeftSemi, NOT condition
(false, false)
case i if isAqe && (i == 16 || i == 94) =>
(false, false)
case other => (true, false)
}
sqlNums.map((_, noFallBack._1, noFallBack._2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@ import org.apache.spark.SparkConf
class GlutenClickHouseTPCDSParquetSortMergeJoinSuite extends GlutenClickHouseTPCDSAbstractSuite {

override protected def excludedTpcdsQueries: Set[String] = Set(
// fallback due to left semi/anti
// fallback due to left semi/anti/existence join
"q8",
"q10",
"q14a",
"q14b",
"116",
"q23a",
"q23b",
"q35",
"q38",
"q51",
"q69",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,5 +500,36 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
compareResultsAgainstVanillaSpark(sql2, true, { _ => })

}

test("existence join") {
spark.sql("create table t1(a int, b int) using parquet")
spark.sql("create table t2(a int, b int) using parquet")
spark.sql("insert into t1 values(0, 0), (1, 2), (2, 3), (3, 4), (null, 5), (6, null)")
spark.sql("insert into t2 values(0, 0), (1, 2), (2, 3), (2,4), (null, 5), (6, null)")

val sql1 = """
|select * from t1 where exists (select 1 from t2 where t1.a = t2.a) or t1.a > 1
|""".stripMargin
compareResultsAgainstVanillaSpark(sql1, true, { _ => })

val sql2 = """
|select * from t1 where exists (select 1 from t2 where t1.a = t2.a) or t1.a > 3
|""".stripMargin
compareResultsAgainstVanillaSpark(sql2, true, { _ => })

val sql3 = """
|select * from t1 where exists (select 1 from t2 where t1.a = t2.a) or t1.b > 0
|""".stripMargin
compareResultsAgainstVanillaSpark(sql3, true, { _ => })

val sql4 = """
|select * from t1 where exists (select 1 from t2
|where t1.a = t2.a and t1.b = t2.b) or t1.a > 0
|""".stripMargin
compareResultsAgainstVanillaSpark(sql4, true, { _ => })

spark.sql("drop table t1")
spark.sql("drop table t2")
}
}
// scalastyle:off line.size.limit
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ object CHHashBuildBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark w
(
countsAndBytes.flatMap(_._2),
countsAndBytes.map(_._1).sum,
BroadCastHashJoinContext(Seq(child.output.head), Inner, false, child.output, "")
BroadCastHashJoinContext(Seq(child.output.head), Inner, false, false, child.output, "")
)
}
}
8 changes: 6 additions & 2 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,14 +1090,18 @@ void JoinUtil::reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols)
plan.addStep(std::move(project_step));
}

std::pair<DB::JoinKind, DB::JoinStrictness> JoinUtil::getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type)
std::pair<DB::JoinKind, DB::JoinStrictness>
JoinUtil::getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type, bool is_existence_join)
{
switch (join_type)
{
case substrait::JoinRel_JoinType_JOIN_TYPE_INNER:
return {DB::JoinKind::Inner, DB::JoinStrictness::All};
case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: {
if (is_existence_join)
return {DB::JoinKind::Left, DB::JoinStrictness::Any};
return {DB::JoinKind::Left, DB::JoinStrictness::Semi};
}
case substrait::JoinRel_JoinType_JOIN_TYPE_ANTI:
return {DB::JoinKind::Left, DB::JoinStrictness::Anti};
case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Common/CHUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ class JoinUtil
{
public:
static void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols);
static std::pair<DB::JoinKind, DB::JoinStrictness> getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type);
static std::pair<DB::JoinKind, DB::JoinStrictness> getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type, bool is_existence_join);
static std::pair<DB::JoinKind, DB::JoinStrictness> getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type);
};

Expand Down
3 changes: 2 additions & 1 deletion cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
const std::string & join_keys,
jint join_type,
bool has_mixed_join_condition,
bool is_existence_join,
const std::string & named_struct)
{
auto join_key_list = Poco::StringTokenizer(join_keys, ",");
Expand All @@ -112,7 +113,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
if (key.starts_with("BuiltBNLJBroadcastTable-"))
std::tie(kind, strictness) = JoinUtil::getCrossJoinKindAndStrictness(static_cast<substrait::CrossRel_JoinType>(join_type));
else
std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast<substrait::JoinRel_JoinType>(join_type));
std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast<substrait::JoinRel_JoinType>(join_type), is_existence_join);


substrait::NamedStruct substrait_struct;
Expand Down
1 change: 1 addition & 0 deletions cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
const std::string & join_keys,
jint join_type,
bool has_mixed_join_condition,
bool is_existence_join,
const std::string & named_struct);
void cleanBuildHashTable(const std::string & hash_table_id, jlong instance);
std::shared_ptr<StorageJoinFromReadBuffer> getJoin(const std::string & hash_table_id);
Expand Down
35 changes: 32 additions & 3 deletions cpp-ch/local-engine/Parser/JoinRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <Processors/QueryPlan/JoinStep.h>
#include <google/protobuf/wrappers.pb.h>
#include <Common/CHUtil.h>
#include <Functions/FunctionFactory.h>

#include <Poco/Logger.h>
#include <Common/logger_useful.h>
Expand All @@ -51,13 +52,13 @@ using namespace DB;

namespace local_engine
{
std::shared_ptr<DB::TableJoin> createDefaultTableJoin(substrait::JoinRel_JoinType join_type)
std::shared_ptr<DB::TableJoin> createDefaultTableJoin(substrait::JoinRel_JoinType join_type, bool is_existence_join)
{
auto & global_context = SerializedPlanParser::global_context;
auto table_join = std::make_shared<TableJoin>(
global_context->getSettings(), global_context->getGlobalTemporaryVolume(), global_context->getTempDataOnDisk());

std::pair<DB::JoinKind, DB::JoinStrictness> kind_and_strictness = JoinUtil::getJoinKindAndStrictness(join_type);
std::pair<DB::JoinKind, DB::JoinStrictness> kind_and_strictness = JoinUtil::getJoinKindAndStrictness(join_type, is_existence_join);
table_join->setKind(kind_and_strictness.first);
table_join->setStrictness(kind_and_strictness.second);
return table_join;
Expand Down Expand Up @@ -219,7 +220,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q
renamePlanColumns(*left, *right, *storage_join);
}

auto table_join = createDefaultTableJoin(join.type());
auto table_join = createDefaultTableJoin(join.type(), join_opt_info.is_existence_join);
DB::Block right_header_before_convert_step = right->getCurrentDataStream().header;
addConvertStep(*table_join, *left, *right);

Expand Down Expand Up @@ -351,11 +352,39 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q
query_plan = std::make_unique<QueryPlan>();
query_plan->unitePlans(std::move(join_step), {std::move(plans)});
}

JoinUtil::reorderJoinOutput(*query_plan, after_join_names);
/// Need to project the right table column into boolean type
if (join_opt_info.is_existence_join)
{
existenceJoinPostProject(*query_plan, left_names);
}

return query_plan;
}


/// We use left any join to implement ExistenceJoin.
/// The result columns of ExistenceJoin are left table columns + one flag column.
/// The flag column indicates whether a left row is matched or not. We build the flag column here.
/// The input plan's header is left table columns + right table columns. If one row in the right row is null,
/// we mark the flag 0, otherwise mark it 1.
void JoinRelParser::existenceJoinPostProject(DB::QueryPlan & plan, const DB::Names & left_input_cols)
{
auto actions_dag = std::make_shared<DB::ActionsDAG>(plan.getCurrentDataStream().header.getColumnsWithTypeAndName());
const auto * right_col_node = actions_dag->getInputs().back();
auto function_builder = DB::FunctionFactory::instance().get("isNotNull", getContext());
const auto * not_null_node = &actions_dag->addFunction(function_builder, {right_col_node}, right_col_node->result_name);
actions_dag->addOrReplaceInOutputs(*not_null_node);
DB::Names required_cols = left_input_cols;
required_cols.emplace_back(not_null_node->result_name);
actions_dag->removeUnusedActions(required_cols);
auto project_step = std::make_unique<DB::ExpressionStep>(plan.getCurrentDataStream(), actions_dag);
project_step->setStepDescription("ExistenceJoin Post Project");
steps.emplace_back(project_step.get());
plan.addStep(std::move(project_step));
}

void JoinRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right)
{
/// If the columns name in right table is duplicated with left table, we need to rename the right table's columns.
Expand Down
2 changes: 2 additions & 0 deletions cpp-ch/local-engine/Parser/JoinRelParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class JoinRelParser : public RelParser

void addPostFilter(DB::QueryPlan & plan, const substrait::JoinRel & join);

void existenceJoinPostProject(DB::QueryPlan & plan, const DB::Names & left_input_cols);

static std::unordered_set<DB::JoinTableSide> extractTableSidesFromExpression(
const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header);
};
Expand Down
3 changes: 2 additions & 1 deletion cpp-ch/local-engine/local_engine_jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,7 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild
jstring join_key_,
jint join_type_,
jboolean has_mixed_join_condition,
jboolean is_existence_join,
jbyteArray named_struct)
{
LOCAL_ENGINE_JNI_METHOD_START
Expand All @@ -1126,7 +1127,7 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild
DB::CompressedReadBuffer input(read_buffer_from_java_array);
local_engine::configureCompressedReadBuffer(input);
const auto * obj = make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin(
hash_table_id, input, row_count_, join_key, join_type_, has_mixed_join_condition, struct_string));
hash_table_id, input, row_count_, join_key, join_type_, has_mixed_join_condition, is_existence_join, struct_string));
return obj->instance();
LOCAL_ENGINE_JNI_METHOD_END(env, 0)
}
Expand Down
Loading