From 12cf403ee813e78d8f6b354485c0a3f9767b804b Mon Sep 17 00:00:00 2001
From: lgbo-ustc
Date: Tue, 23 Jul 2024 16:22:58 +0800
Subject: [PATCH] try replace smj with hj

---
 .../clickhouse/CHSparkPlanExecApi.scala       |   7 +-
 .../RewriteSortMergeJoinToHashJoinRule.scala  | 137 ++++++++++++++++++
 2 files changed, 141 insertions(+), 3 deletions(-)
 create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 32f372956f0ae..34f71f51ad1f0 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi}
 import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution._
 import org.apache.gluten.expression._
-import org.apache.gluten.extension.{CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteToDateExpresstionRule}
+import org.apache.gluten.extension.{CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
 import org.apache.gluten.extension.columnar.AddFallbackTagRule
 import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
 import org.apache.gluten.extension.columnar.transition.Convention
@@ -555,8 +555,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
    *
    * @return
    */
-  override def genExtendedQueryStagePrepRules(): List[SparkSession => Rule[SparkPlan]] =
+  override def genExtendedQueryStagePrepRules(): List[SparkSession => Rule[SparkPlan]] = {
     List(spark => FallbackBroadcastHashJoinPrepQueryStage(spark))
+  }
 
   /**
    * Generate extended Analyzers. Currently only for ClickHouse backend.
@@ -597,7 +598,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
    * @return
    */
   override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] =
-    List()
+    List(spark => RewriteSortMergeJoinToHashJoinRule(spark))
 
   override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = {
     List()
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
new file mode 100644
index 0000000000000..048545361a8ae
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.execution._
+import org.apache.gluten.utils.{CHJoinValidateUtil, ShuffleHashJoinStrategy, SortMergeJoinStrategy}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.joins._
+
+/**
+ * If a SortMergeJoinExec cannot be offloaded, try to replace it with a shuffled hash join
+ * (CHShuffledHashJoinExecTransformer) instead.
+ *
+ * This rule is applied after spark plan nodes are transformed into columnar ones.
+ */
+case class RewriteSortMergeJoinToHashJoinRule(session: SparkSession)
+  extends Rule[SparkPlan]
+  with Logging {
+
+  override def apply(plan: SparkPlan): SparkPlan = visitPlan(plan)
+
+  /**
+   * Walks the plan top-down: a SortMergeJoinExec is handed to tryReplaceSortMergeJoin, any other
+   * node is rebuilt on top of its visited children.
+   */
+  private def visitPlan(plan: SparkPlan): SparkPlan = {
+    plan match {
+      case smj: SortMergeJoinExec =>
+        tryReplaceSortMergeJoin(smj)
+      case other =>
+        other.withNewChildren(other.children.map(visitPlan))
+    }
+  }
+
+  /**
+   * Replaces `smj` with a shuffled hash join unless CHJoinValidateUtil.shouldFallback rejects the
+   * hash join or a child cannot be rewritten; in those cases `smj` is returned unchanged.
+   */
+  private def tryReplaceSortMergeJoin(smj: SortMergeJoinExec): SparkPlan = {
+    // An existence join is validated as a left semi join.
+    val finalJoinType = smj.joinType match {
+      case ExistenceJoin(_) => LeftSemi
+      case other => other
+    }
+    val needFallback = CHJoinValidateUtil.shouldFallback(
+      ShuffleHashJoinStrategy(finalJoinType),
+      smj.left.outputSet,
+      smj.right.outputSet,
+      smj.condition)
+    if (needFallback) {
+      // The hash join cannot be offloaded either, so don't replace the sort merge join.
+      logInfo("Cannot offload this join by hash join algorithm")
+      smj
+    } else {
+      replaceSortMergeJoinWithHashJoin(smj)
+    }
+  }
+
+  /**
+   * Builds the hash join on top of the rewritten children of `smj`. Returns `smj` unchanged when
+   * either child cannot be rewritten.
+   */
+  private def replaceSortMergeJoinWithHashJoin(smj: SortMergeJoinExec): SparkPlan = {
+    val newChildren = for {
+      newLeft <- replaceSortMergeJoinChild(smj.left)
+      newRight <- replaceSortMergeJoinChild(smj.right)
+    } yield (newLeft, newRight)
+    newChildren match {
+      case Some((newLeft, newRight)) =>
+        CHShuffledHashJoinExecTransformer(
+          smj.leftKeys,
+          smj.rightKeys,
+          smj.joinType,
+          BuildRight,
+          smj.condition,
+          newLeft,
+          newRight,
+          smj.isSkewJoin)
+      case None =>
+        // Some cases that we cannot handle.
+        logInfo("Apply on sort merge children failed")
+        smj
+    }
+  }
+
+  /**
+   * Rewrites one child of a sort merge join that is being replaced by a hash join. Returns None
+   * when the child has a shape this rule does not handle, so that the parent join is kept.
+   */
+  private def replaceSortMergeJoinChild(plan: SparkPlan): Option[SparkPlan] = {
+    plan match {
+      case sort: SortExecTransformer =>
+        sort.child match {
+          case hashShuffle: ColumnarShuffleExchangeExec =>
+            // drop sort node, return the shuffle node directly
+            Some(hashShuffle.withNewChildren(hashShuffle.children.map(visitPlan)))
+          case columnarPlan: TransformSupport =>
+            Some(visitPlan(columnarPlan))
+          case unknown =>
+            // other cases that we don't know
+            logInfo(s"Expected ColumnarShuffleExchangeExec, got ${unknown.getClass}")
+            None
+        }
+      case smj: SortMergeJoinExec =>
+        // A nested join goes through the same validation as a top level one. If it stays a
+        // sort merge join, the parent join is kept as well.
+        tryReplaceSortMergeJoin(smj) match {
+          case _: SortMergeJoinExec => None
+          case replaced => Some(replaced)
+        }
+      case _: TransformSupport => Some(visitPlan(plan))
+      case _ =>
+        logInfo(s"Expected Columnar node, got ${plan.getClass}")
+        None
+    }
+  }
+}