diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java index 27725998feebe..ae7b89120cd4d 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java @@ -45,6 +45,7 @@ private static native long nativeBuild( String joinKeys, int joinType, boolean hasMixedFiltCondition, + boolean isExistenceJoin, byte[] namedStruct); private StorageJoinBuilder() {} @@ -89,6 +90,7 @@ public static long build( joinKey, joinType, broadCastContext.hasMixedFiltCondition(), + broadCastContext.isExistenceJoin(), toNameStruct(output).toByteArray()); } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala index d1dc76045338d..3aab5a6eb9986 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.rpc.GlutenDriverEndpoint import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.BuildSide +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftSemi} import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.BuildSideRelation @@ -44,6 +45,13 @@ case class CHBroadcastNestedLoopJoinExecTransformer( condition ) { + private val finalJoinType = joinType match { + case ExistenceJoin(_) => + LeftSemi + case _ => + joinType + } + override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = { val streamedRDD = getColumnarInputRDDs(streamedPlan) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) @@ -57,7 +65,13 @@ case class CHBroadcastNestedLoopJoinExecTransformer( } val broadcast = buildPlan.executeBroadcast[BuildSideRelation]() val context = - BroadCastHashJoinContext(Seq.empty, joinType, false, buildPlan.output, buildBroadcastTableId) + BroadCastHashJoinContext( + Seq.empty, + finalJoinType, + false, + joinType.isInstanceOf[ExistenceJoin], + buildPlan.output, + buildBroadcastTableId) val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context) streamedRDD :+ broadcastRDD } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 48870892d290e..aa0faa8974117 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch +import io.substrait.proto.JoinRel + case class CHShuffledHashJoinExecTransformer( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -82,6 +84,7 @@ case class BroadCastHashJoinContext( buildSideJoinKeys: Seq[Expression], joinType: JoinType, hasMixedFiltCondition: Boolean, + isExistenceJoin: Boolean, buildSideStructure: Seq[Attribute], buildHashTableId: String) @@ -112,7 +115,7 @@ case class CHBroadcastHashJoinExecTransformer( override protected def doValidateInternal(): ValidationResult = { val shouldFallback = CHJoinValidateUtil.shouldFallback( - BroadcastHashJoinStrategy(joinType), + BroadcastHashJoinStrategy(finalJoinType), left.outputSet, right.outputSet, condition) @@ -141,8 +144,9 @@ case class CHBroadcastHashJoinExecTransformer( val context = BroadCastHashJoinContext( buildKeyExprs, - joinType, + finalJoinType, isMixedCondition(condition), + joinType.isInstanceOf[ExistenceJoin], buildPlan.output, buildHashTableId) val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context) @@ -161,4 +165,32 @@ case class CHBroadcastHashJoinExecTransformer( } res } + + // ExistenceJoin is introduced in #SPARK-14781 + // Indeed, the ExistenceJoin is transformed into left any join in CH. + // We don't have left any join in substrait, so use left semi join instean. + // and isExistenceJoin is set to true to indicate that it is an existence join. + private val finalJoinType = joinType match { + case ExistenceJoin(_) => + LeftSemi + case _ => + joinType + } + override protected lazy val substraitJoinType: JoinRel.JoinType = { + joinType match { + case _: InnerLike => + JoinRel.JoinType.JOIN_TYPE_INNER + case FullOuter => + JoinRel.JoinType.JOIN_TYPE_OUTER + case LeftOuter | RightOuter => + JoinRel.JoinType.JOIN_TYPE_LEFT + case LeftSemi | ExistenceJoin(_) => + JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI + case LeftAnti => + JoinRel.JoinType.JOIN_TYPE_ANTI + case _ => + // TODO: Support cross join with Cross Rel + JoinRel.JoinType.UNRECOGNIZED + } + } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala index d26891ddb1eaa..1c09449c817fb 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala @@ -500,5 +500,36 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite { compareResultsAgainstVanillaSpark(sql2, true, { _ => }) } + + test("existence join") { + spark.sql("create table t1(a int, b int) using parquet") + spark.sql("create table t2(a int, b int) using parquet") + spark.sql("insert into t1 values(0, 0), (1, 2), (2, 3), (3, 4), (null, 5), (6, null)") + spark.sql("insert into t2 values(0, 0), (1, 2), (2, 3), (2,4), (null, 5), (6, null)") + + val sql1 = """ + |select * from t1 where exists (select 1 from t2 where t1.a = t2.a) or t1.a > 1 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql1, true, { _ => }) + + val sql2 = """ + |select * from t1 where exists (select 1 from t2 where t1.a = t2.a) or t1.a > 3 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql2, true, { _ => }) + + val sql3 = """ + |select * from t1 where exists (select 1 from t2 where t1.a = t2.a) or t1.b > 0 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql3, true, { _ => }) + + val sql4 = """ + |select * from t1 where exists (select 1 from t2 + |where t1.a = t2.a and t1.b = t2.b) or t1.a > 0 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql4, true, { _ => }) + + spark.sql("drop table t1") + spark.sql("drop table t2") + } } // scalastyle:off line.size.limit diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala index 8d4bee5546253..141bf5eea5cb9 100644 --- a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala @@ -104,7 +104,7 @@ object CHHashBuildBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark w ( countsAndBytes.flatMap(_._2), countsAndBytes.map(_._1).sum, - BroadCastHashJoinContext(Seq(child.output.head), Inner, false, child.output, "") + BroadCastHashJoinContext(Seq(child.output.head), Inner, false, false, child.output, "") ) } } diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index b74c18dd14af1..a2ec4cf192f5b 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -1090,14 +1090,18 @@ void JoinUtil::reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols) plan.addStep(std::move(project_step)); } -std::pair JoinUtil::getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type) +std::pair +JoinUtil::getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type, bool is_existence_join) { switch (join_type) { case substrait::JoinRel_JoinType_JOIN_TYPE_INNER: return {DB::JoinKind::Inner, DB::JoinStrictness::All}; - case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: + case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: { + if (is_existence_join) + return {DB::JoinKind::Left, DB::JoinStrictness::Any}; return {DB::JoinKind::Left, DB::JoinStrictness::Semi}; + } case substrait::JoinRel_JoinType_JOIN_TYPE_ANTI: return {DB::JoinKind::Left, DB::JoinStrictness::Anti}; case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 7be3f86dc2303..10c6ce8f65110 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -311,7 +311,7 @@ class JoinUtil { public: static void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols); - static std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type); + static std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type, bool is_existence_join); static std::pair getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type); }; diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp index 4d5eae6dc0b57..c21cc8ba35246 100644 --- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp +++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp @@ -99,6 +99,7 @@ std::shared_ptr buildJoin( const std::string & join_keys, jint join_type, bool has_mixed_join_condition, + bool is_existence_join, const std::string & named_struct) { auto join_key_list = Poco::StringTokenizer(join_keys, ","); @@ -112,7 +113,7 @@ std::shared_ptr buildJoin( if (key.starts_with("BuiltBNLJBroadcastTable-")) std::tie(kind, strictness) = JoinUtil::getCrossJoinKindAndStrictness(static_cast(join_type)); else - std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast(join_type)); + std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast(join_type), is_existence_join); substrait::NamedStruct substrait_struct; diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h index 3d2e67f9df101..a97bd77a84d09 100644 --- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h +++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h @@ -37,6 +37,7 @@ std::shared_ptr buildJoin( const std::string & join_keys, jint join_type, bool has_mixed_join_condition, + bool is_existence_join, const std::string & named_struct); void cleanBuildHashTable(const std::string & hash_table_id, jlong instance); std::shared_ptr getJoin(const std::string & hash_table_id); diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index 09a5152174fdf..bb3470c1ccb34 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -51,13 +52,13 @@ using namespace DB; namespace local_engine { -std::shared_ptr createDefaultTableJoin(substrait::JoinRel_JoinType join_type) +std::shared_ptr createDefaultTableJoin(substrait::JoinRel_JoinType join_type, bool is_existence_join) { auto & global_context = SerializedPlanParser::global_context; auto table_join = std::make_shared( global_context->getSettings(), global_context->getGlobalTemporaryVolume(), global_context->getTempDataOnDisk()); - std::pair kind_and_strictness = JoinUtil::getJoinKindAndStrictness(join_type); + std::pair kind_and_strictness = JoinUtil::getJoinKindAndStrictness(join_type, is_existence_join); table_join->setKind(kind_and_strictness.first); table_join->setStrictness(kind_and_strictness.second); return table_join; @@ -219,7 +220,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q renamePlanColumns(*left, *right, *storage_join); } - auto table_join = createDefaultTableJoin(join.type()); + auto table_join = createDefaultTableJoin(join.type(), join_opt_info.is_existence_join); DB::Block right_header_before_convert_step = right->getCurrentDataStream().header; addConvertStep(*table_join, *left, *right); @@ -351,11 +352,32 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q query_plan = std::make_unique(); query_plan->unitePlans(std::move(join_step), {std::move(plans)}); } + + /// Need to project the right table column into boolean type + if (join_opt_info.is_existence_join) + { + existenceJoinPostProject(*query_plan); + } JoinUtil::reorderJoinOutput(*query_plan, after_join_names); return query_plan; } +void JoinRelParser::existenceJoinPostProject(DB::QueryPlan & plan) +{ + /// The last column if a column from right table' keys. If there is no matched for a left row, the value is null + /// in this column. We have to convert this column into flags to indicate whether a left row is matched or not. + auto actions_dag = std::make_shared(plan.getCurrentDataStream().header.getColumnsWithTypeAndName()); + const auto * right_col_node = actions_dag->getInputs().back(); + auto function_builder = DB::FunctionFactory::instance().get("isNotNull", getContext()); + const auto * not_null_node = &actions_dag->addFunction(function_builder, {right_col_node}, right_col_node->result_name); + actions_dag->addOrReplaceInOutputs(*not_null_node); + auto project_step = std::make_unique(plan.getCurrentDataStream(), actions_dag); + project_step->setStepDescription("ExistenceJoin Post Project"); + steps.emplace_back(project_step.get()); + plan.addStep(std::move(project_step)); +} + void JoinRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right) { /// If the columns name in right table is duplicated with left table, we need to rename the right table's columns. diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h index e6d31e6d31d6e..1d5d50e175a02 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.h +++ b/cpp-ch/local-engine/Parser/JoinRelParser.h @@ -66,6 +66,8 @@ class JoinRelParser : public RelParser void addPostFilter(DB::QueryPlan & plan, const substrait::JoinRel & join); + void existenceJoinPostProject(DB::QueryPlan & plan); + static std::unordered_set extractTableSidesFromExpression( const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header); };