From a784e45662117b2192a166d24d463257df7311c4 Mon Sep 17 00:00:00 2001
From: lgbo
Date: Mon, 29 Jul 2024 10:00:01 +0800
Subject: [PATCH] [GLUTEN-6544][CH] Support existence join (#6548)

* support existence join

* fixed tests
---
 .../gluten/vectorized/StorageJoinBuilder.java |  2 +
 ...oadcastNestedLoopJoinExecTransformer.scala | 16 +++++++-
 .../execution/CHHashJoinExecTransformer.scala | 37 ++++++++++++++++++-
 .../gluten/utils/CHJoinValidateUtil.scala     |  4 ++
 .../GlutenClickHouseTPCDSAbstractSuite.scala  | 11 +++---
 ...kHouseTPCDSParquetSortMergeJoinSuite.scala |  5 ++-
 .../execution/GlutenClickHouseTPCHSuite.scala | 31 ++++++++++++++++
 .../benchmarks/CHHashBuildBenchmark.scala     |  2 +-
 cpp-ch/local-engine/Common/CHUtil.cpp         |  8 +++-
 cpp-ch/local-engine/Common/CHUtil.h           |  2 +-
 .../Join/BroadCastJoinBuilder.cpp             |  3 +-
 .../local-engine/Join/BroadCastJoinBuilder.h  |  1 +
 cpp-ch/local-engine/Parser/JoinRelParser.cpp  | 35 ++++++++++++++++--
 cpp-ch/local-engine/Parser/JoinRelParser.h    |  2 +
 cpp-ch/local-engine/local_engine_jni.cpp      |  3 +-
 15 files changed, 144 insertions(+), 18 deletions(-)

diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
index 27725998feeb..ae7b89120cd4 100644
--- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
+++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
@@ -45,6 +45,7 @@ private static native long nativeBuild(
       String joinKeys,
       int joinType,
       boolean hasMixedFiltCondition,
+      boolean isExistenceJoin,
       byte[] namedStruct);
 
   private StorageJoinBuilder() {}
@@ -89,6 +90,7 @@ public static long build(
         joinKey,
         joinType,
         broadCastContext.hasMixedFiltCondition(),
+        broadCastContext.isExistenceJoin(),
         toNameStruct(output).toByteArray());
   }
 
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala
index d1dc76045338..3aab5a6eb998 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.rpc.GlutenDriverEndpoint
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.optimizer.BuildSide
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftSemi}
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.joins.BuildSideRelation
@@ -44,6 +45,13 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
       condition
     ) {
 
+  private val finalJoinType = joinType match {
+    case ExistenceJoin(_) =>
+      LeftSemi
+    case _ =>
+      joinType
+  }
+
   override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
     val streamedRDD = getColumnarInputRDDs(streamedPlan)
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
@@ -57,7 +65,13 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
     }
     val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
     val context =
-      BroadCastHashJoinContext(Seq.empty, joinType, false, buildPlan.output, buildBroadcastTableId)
+      BroadCastHashJoinContext(
+        Seq.empty,
+        finalJoinType,
+        false,
+        joinType.isInstanceOf[ExistenceJoin],
+        buildPlan.output,
+        buildBroadcastTableId)
     val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context)
     streamedRDD :+ broadcastRDD
   }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index 48870892d290..c44156373528 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+import io.substrait.proto.JoinRel
+
 case class CHShuffledHashJoinExecTransformer(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
@@ -82,6 +84,7 @@ case class BroadCastHashJoinContext(
     buildSideJoinKeys: Seq[Expression],
     joinType: JoinType,
     hasMixedFiltCondition: Boolean,
+    isExistenceJoin: Boolean,
     buildSideStructure: Seq[Attribute],
     buildHashTableId: String)
 
@@ -112,7 +115,7 @@ case class CHBroadcastHashJoinExecTransformer(
 
   override protected def doValidateInternal(): ValidationResult = {
     val shouldFallback = CHJoinValidateUtil.shouldFallback(
-      BroadcastHashJoinStrategy(joinType),
+      BroadcastHashJoinStrategy(finalJoinType),
       left.outputSet,
       right.outputSet,
       condition)
@@ -141,8 +144,9 @@ case class CHBroadcastHashJoinExecTransformer(
     val context =
       BroadCastHashJoinContext(
         buildKeyExprs,
-        joinType,
+        finalJoinType,
         isMixedCondition(condition),
+        joinType.isInstanceOf[ExistenceJoin],
         buildPlan.output,
         buildHashTableId)
     val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context)
@@ -161,4 +165,33 @@ case class CHBroadcastHashJoinExecTransformer(
     }
     res
   }
+
+  // ExistenceJoin was introduced in SPARK-14781. It returns all rows from the left table,
+  // plus a new column that indicates whether each row has a match in the right table.
+  // In ClickHouse, ExistenceJoin is implemented as a left any join. Substrait has no
+  // left any join, so we use left semi join instead and set isExistenceJoin to true
+  // to indicate that the join is actually an existence join.
+  private val finalJoinType = joinType match {
+    case ExistenceJoin(_) =>
+      LeftSemi
+    case _ =>
+      joinType
+  }
+
+  override protected lazy val substraitJoinType: JoinRel.JoinType = {
+    joinType match {
+      case _: InnerLike =>
+        JoinRel.JoinType.JOIN_TYPE_INNER
+      case FullOuter =>
+        JoinRel.JoinType.JOIN_TYPE_OUTER
+      case LeftOuter | RightOuter =>
+        JoinRel.JoinType.JOIN_TYPE_LEFT
+      case LeftSemi | ExistenceJoin(_) =>
+        JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
+      case LeftAnti =>
+        JoinRel.JoinType.JOIN_TYPE_ANTI
+      case _ =>
+        // TODO: Support cross join with Cross Rel
+        JoinRel.JoinType.UNRECOGNIZED
+    }
+  }
 }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
index dae8e6e073a1..08b5ef5b2ef0 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
@@ -55,6 +55,7 @@ object CHJoinValidateUtil extends Logging {
     var shouldFallback = false
     val joinType = joinStrategy.joinType
     if (joinType.toString.contains("ExistenceJoin")) {
+      logError("Fallback for join type ExistenceJoin")
       return true
     }
     if (joinType.sql.contains("INNER")) {
@@ -78,6 +79,9 @@ object CHJoinValidateUtil extends Logging {
         case _ => false
       }
     }
+    if (shouldFallback) {
+      logError(s"Fallback for join type $joinType")
+    }
     shouldFallback
   }
 }
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
index f0712bf5af10..9787182ed93f 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
@@ -58,13 +58,14 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
       Seq("q" + "%d".format(queryNum))
     }
     val noFallBack = queryNum match {
-      case i if i == 10 || i == 16 || i == 35 || i == 45 || i == 94 =>
-        // Q10 BroadcastHashJoin, ExistenceJoin
-        // Q16 ShuffledHashJoin, NOT condition
-        // Q35 BroadcastHashJoin, ExistenceJoin
-        // Q45 BroadcastHashJoin, ExistenceJoin
+      case i if !isAqe && (i == 10 || i == 16 || i == 35 || i == 94) =>
+        // Q10 SortMergeJoin, ExistenceJoin
+        // Q16 SortMergeJoin, LeftSemi, NOT condition
+        // Q35 SortMergeJoin, ExistenceJoin
         // Q94 BroadcastHashJoin, LeftSemi, NOT condition
         (false, false)
+      case i if isAqe && (i == 16 || i == 94) =>
+        (false, false)
       case other => (true, false)
     }
     sqlNums.map((_, noFallBack._1, noFallBack._2))
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
index 3f7816cb84ad..3ec4e31a4109 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
@@ -23,12 +23,15 @@ import org.apache.spark.SparkConf
 class GlutenClickHouseTPCDSParquetSortMergeJoinSuite extends GlutenClickHouseTPCDSAbstractSuite {
 
   override protected def excludedTpcdsQueries: Set[String] = Set(
-    // fallback due to left semi/anti
+    // fallback due to left semi/anti/existence join
    "q8",
+    "q10",
"q14a", "q14b", + "116", "q23a", "q23b", + "q35", "q38", "q51", "q69", diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala index d26891ddb1ea..1c09449c817f 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala @@ -500,5 +500,36 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite { compareResultsAgainstVanillaSpark(sql2, true, { _ => }) } + + test("existence join") { + spark.sql("create table t1(a int, b int) using parquet") + spark.sql("create table t2(a int, b int) using parquet") + spark.sql("insert into t1 values(0, 0), (1, 2), (2, 3), (3, 4), (null, 5), (6, null)") + spark.sql("insert into t2 values(0, 0), (1, 2), (2, 3), (2,4), (null, 5), (6, null)") + + val sql1 = """ + |select * from t1 where exists (select 1 from t2 where t1.a = t2.a) or t1.a > 1 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql1, true, { _ => }) + + val sql2 = """ + |select * from t1 where exists (select 1 from t2 where t1.a = t2.a) or t1.a > 3 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql2, true, { _ => }) + + val sql3 = """ + |select * from t1 where exists (select 1 from t2 where t1.a = t2.a) or t1.b > 0 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql3, true, { _ => }) + + val sql4 = """ + |select * from t1 where exists (select 1 from t2 + |where t1.a = t2.a and t1.b = t2.b) or t1.a > 0 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql4, true, { _ => }) + + spark.sql("drop table t1") + spark.sql("drop table t2") + } } // scalastyle:off line.size.limit diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala index 8d4bee554625..141bf5eea5cb 100644 --- a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala @@ -104,7 +104,7 @@ object CHHashBuildBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark w ( countsAndBytes.flatMap(_._2), countsAndBytes.map(_._1).sum, - BroadCastHashJoinContext(Seq(child.output.head), Inner, false, child.output, "") + BroadCastHashJoinContext(Seq(child.output.head), Inner, false, false, child.output, "") ) } } diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 787277dbefb1..3a699b50e302 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -1089,14 +1089,18 @@ void JoinUtil::reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols) plan.addStep(std::move(project_step)); } -std::pair JoinUtil::getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type) +std::pair +JoinUtil::getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type, bool is_existence_join) { switch (join_type) { case substrait::JoinRel_JoinType_JOIN_TYPE_INNER: return {DB::JoinKind::Inner, DB::JoinStrictness::All}; - case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: + case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: { + if (is_existence_join) + return {DB::JoinKind::Left, DB::JoinStrictness::Any}; return {DB::JoinKind::Left, 
DB::JoinStrictness::Semi}; + } case substrait::JoinRel_JoinType_JOIN_TYPE_ANTI: return {DB::JoinKind::Left, DB::JoinStrictness::Anti}; case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 98139fb49a5b..b45c6ab3c4d2 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -313,7 +313,7 @@ class JoinUtil { public: static void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols); - static std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type); + static std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type, bool is_existence_join); static std::pair getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type); }; diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp index 4d5eae6dc0b5..c21cc8ba3524 100644 --- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp +++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp @@ -99,6 +99,7 @@ std::shared_ptr buildJoin( const std::string & join_keys, jint join_type, bool has_mixed_join_condition, + bool is_existence_join, const std::string & named_struct) { auto join_key_list = Poco::StringTokenizer(join_keys, ","); @@ -112,7 +113,7 @@ std::shared_ptr buildJoin( if (key.starts_with("BuiltBNLJBroadcastTable-")) std::tie(kind, strictness) = JoinUtil::getCrossJoinKindAndStrictness(static_cast(join_type)); else - std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast(join_type)); + std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast(join_type), is_existence_join); substrait::NamedStruct substrait_struct; diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h index 3d2e67f9df10..a97bd77a84d0 100644 --- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h +++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h @@ -37,6 +37,7 @@ std::shared_ptr buildJoin( const std::string & join_keys, jint join_type, bool has_mixed_join_condition, + bool is_existence_join, const std::string & named_struct); void cleanBuildHashTable(const std::string & hash_table_id, jlong instance); std::shared_ptr getJoin(const std::string & hash_table_id); diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index 460311e289b6..24ba7acdb654 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include @@ -50,13 +51,13 @@ using namespace DB; namespace local_engine { -std::shared_ptr createDefaultTableJoin(substrait::JoinRel_JoinType join_type) +std::shared_ptr createDefaultTableJoin(substrait::JoinRel_JoinType join_type, bool is_existence_join) { auto & global_context = SerializedPlanParser::global_context; auto table_join = std::make_shared( global_context->getSettings(), global_context->getGlobalTemporaryVolume(), global_context->getTempDataOnDisk()); - std::pair kind_and_strictness = JoinUtil::getJoinKindAndStrictness(join_type); + std::pair kind_and_strictness = JoinUtil::getJoinKindAndStrictness(join_type, is_existence_join); table_join->setKind(kind_and_strictness.first); table_join->setStrictness(kind_and_strictness.second); return table_join; @@ -218,7 +219,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q renamePlanColumns(*left, *right, *storage_join); } - auto 
table_join = createDefaultTableJoin(join.type()); + auto table_join = createDefaultTableJoin(join.type(), join_opt_info.is_existence_join); DB::Block right_header_before_convert_step = right->getCurrentDataStream().header; addConvertStep(*table_join, *left, *right); @@ -350,11 +351,39 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q query_plan = std::make_unique(); query_plan->unitePlans(std::move(join_step), {std::move(plans)}); } + JoinUtil::reorderJoinOutput(*query_plan, after_join_names); + /// Need to project the right table column into boolean type + if (join_opt_info.is_existence_join) + { + existenceJoinPostProject(*query_plan, left_names); + } return query_plan; } + +/// We use left any join to implement ExistenceJoin. +/// The result columns of ExistenceJoin are left table columns + one flag column. +/// The flag column indicates whether a left row is matched or not. We build the flag column here. +/// The input plan's header is left table columns + right table columns. If one row in the right row is null, +/// we mark the flag 0, otherwise mark it 1. +void JoinRelParser::existenceJoinPostProject(DB::QueryPlan & plan, const DB::Names & left_input_cols) +{ + auto actions_dag = std::make_shared(plan.getCurrentDataStream().header.getColumnsWithTypeAndName()); + const auto * right_col_node = actions_dag->getInputs().back(); + auto function_builder = DB::FunctionFactory::instance().get("isNotNull", getContext()); + const auto * not_null_node = &actions_dag->addFunction(function_builder, {right_col_node}, right_col_node->result_name); + actions_dag->addOrReplaceInOutputs(*not_null_node); + DB::Names required_cols = left_input_cols; + required_cols.emplace_back(not_null_node->result_name); + actions_dag->removeUnusedActions(required_cols); + auto project_step = std::make_unique(plan.getCurrentDataStream(), actions_dag); + project_step->setStepDescription("ExistenceJoin Post Project"); + steps.emplace_back(project_step.get()); + plan.addStep(std::move(project_step)); +} + void JoinRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right) { /// If the columns name in right table is duplicated with left table, we need to rename the right table's columns. 
diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h
index e6d31e6d31d6..ee1155cb4712 100644
--- a/cpp-ch/local-engine/Parser/JoinRelParser.h
+++ b/cpp-ch/local-engine/Parser/JoinRelParser.h
@@ -66,6 +66,8 @@ class JoinRelParser : public RelParser
 
     void addPostFilter(DB::QueryPlan & plan, const substrait::JoinRel & join);
 
+    void existenceJoinPostProject(DB::QueryPlan & plan, const DB::Names & left_input_cols);
+
     static std::unordered_set extractTableSidesFromExpression(
         const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header);
 };
diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp
index 17d087bb82ff..a6ca55052ef9 100644
--- a/cpp-ch/local-engine/local_engine_jni.cpp
+++ b/cpp-ch/local-engine/local_engine_jni.cpp
@@ -1094,6 +1094,7 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild
     jstring join_key_,
     jint join_type_,
     jboolean has_mixed_join_condition,
+    jboolean is_existence_join,
     jbyteArray named_struct)
 {
     LOCAL_ENGINE_JNI_METHOD_START
@@ -1107,7 +1108,7 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild
     DB::CompressedReadBuffer input(read_buffer_from_java_array);
     local_engine::configureCompressedReadBuffer(input);
     const auto * obj = make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin(
-        hash_table_id, input, row_count_, join_key, join_type_, has_mixed_join_condition, struct_string));
+        hash_table_id, input, row_count_, join_key, join_type_, has_mixed_join_condition, is_existence_join, struct_string));
     return obj->instance();
     LOCAL_ENGINE_JNI_METHOD_END(env, 0)
 }
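
For reference, ExistenceJoin is the plan node Spark produces for an EXISTS or IN subquery that is OR-ed with another predicate, which is exactly the shape the new "existence join" test above exercises. The sketch below restates, as a small standalone Scala object, the rewrite both transformers in this patch apply: since Substrait has no "left any" join, ExistenceJoin is encoded as LEFT_SEMI plus an isExistenceJoin flag so the native side can still build a left any join. It only assumes Spark's catalyst join types and is an illustration, not code from the patch.

import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, JoinType, LeftSemi}

// Standalone restatement of the join-type mapping used by
// CHBroadcastHashJoinExecTransformer and CHBroadcastNestedLoopJoinExecTransformer.
object ExistenceJoinMapping {
  // The join type that is actually sent to the native side.
  def finalJoinType(joinType: JoinType): JoinType = joinType match {
    case ExistenceJoin(_) => LeftSemi
    case _ => joinType
  }

  // The flag carried in BroadCastHashJoinContext / the Substrait plan options.
  def isExistenceJoin(joinType: JoinType): Boolean = joinType.isInstanceOf[ExistenceJoin]
}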
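The native existenceJoinPostProject step can also be pictured at the DataFrame level: after a left join whose right side is reduced to distinct keys (standing in for ClickHouse's left any join), the existence flag is simply whether the right-side key came back non-null. This is a hypothetical illustration of the semantics only; the table data, column names, and the exists_flag name are made up, and Gluten never builds the flag this way on the Spark side.

import org.apache.spark.sql.SparkSession

// Hypothetical DataFrame-level model of the flag the ClickHouse post-projection derives.
object ExistenceFlagSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("existence-flag-sketch").getOrCreate()
    import spark.implicits._

    val t1 = Seq((0, 0), (1, 2), (2, 3), (3, 4)).toDF("a", "b")
    val t2 = Seq((0, 0), (2, 3), (2, 4)).toDF("a", "b")

    // Deduplicate the right-side key so the left join keeps each t1 row exactly once,
    // mimicking left-any-join strictness; a null right key then means "no match".
    val rightKeys = t2.select($"a".as("r_a")).distinct()
    val flagged = t1
      .join(rightKeys, t1("a") === rightKeys("r_a"), "left")
      .withColumn("exists_flag", rightKeys("r_a").isNotNull)
      .drop("r_a")

    flagged.show()
    spark.stop()
  }
}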