Skip to content

Commit

Permalink
try replace smj with hj
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Jul 24, 2024
1 parent 23af01d commit 07b41fb
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi}
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.extension.{CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteToDateExpresstionRule}
import org.apache.gluten.extension.{CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
import org.apache.gluten.extension.columnar.AddFallbackTagRule
import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
import org.apache.gluten.extension.columnar.transition.Convention
Expand Down Expand Up @@ -555,8 +555,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
*
* @return
*/
override def genExtendedQueryStagePrepRules(): List[SparkSession => Rule[SparkPlan]] =
override def genExtendedQueryStagePrepRules(): List[SparkSession => Rule[SparkPlan]] = {
List(spark => FallbackBroadcastHashJoinPrepQueryStage(spark))
}

/**
* Generate extended Analyzers. Currently only for ClickHouse backend.
Expand Down Expand Up @@ -597,7 +598,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
* @return
*/
override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] =
List()
List(spark => RewriteSortMergeJoinToHashJoinRule(spark))

override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = {
List()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension

import org.apache.gluten.execution._
import org.apache.gluten.utils.{CHJoinValidateUtil, ShuffleHashJoinStrategy, SortMergeJoinStrategy}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.joins._

// import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
// If a SortMergeJoinExec cannot be offloaded, try to replace it with ShuffledHashJoinExec
// instead.
// This is rule is applied after spark plan nodes are transformed into columnar ones.
case class RewriteSortMergeJoinToHashJoinRule(session: SparkSession)
extends Rule[SparkPlan]
with Logging {
override def apply(plan: SparkPlan): SparkPlan = {
visitPlan(plan)
}

private def visitPlan(plan: SparkPlan): SparkPlan = {
plan match {
case smj: SortMergeJoinExec =>
tryReplaceSortMergeJoin(smj)
case other =>
other.withNewChildren(other.children.map(visitPlan))
}
}

private def tryReplaceSortMergeJoin(smj: SortMergeJoinExec): SparkPlan = {
var needFallback = CHJoinValidateUtil.shouldFallback(
SortMergeJoinStrategy(smj.joinType),
smj.left.outputSet,
smj.right.outputSet,
smj.condition)
if (!needFallback) {
val debugNode = smj.withNewChildren(Seq())
throw new IllegalStateException(s"SortMergeJoin should have been offloaded. $debugNode")
}

// cannot offload SortMergeJoin, try to replace it with ShuffledHashJoin
val finalJoinType = smj.joinType match {
case ExistenceJoin(_) =>
LeftSemi
case _ =>
smj.joinType
}
needFallback = CHJoinValidateUtil.shouldFallback(
ShuffleHashJoinStrategy(finalJoinType),
smj.left.outputSet,
smj.right.outputSet,
smj.condition)
// also cannot offload HashJoin, don't replace it.
if (needFallback) {
logInfo(s"Cannot offload this join by hash join algorithm")
return smj
} else {
replaceSortMergeJoinWithHashJoin(smj)
}
}

private def replaceSortMergeJoinWithHashJoin(smj: SortMergeJoinExec): SparkPlan = {
val newLeft = replaceSortMergeJoinChild(smj.left)
val newRight = replaceSortMergeJoinChild(smj.right)
// Some cases that we cannot handle.
if (newLeft == null || newRight == null) {
logInfo("Apply on sort merge children failed")
return smj
}

CHShuffledHashJoinExecTransformer(
smj.leftKeys,
smj.rightKeys,
smj.joinType,
BuildRight,
smj.condition,
newLeft,
newRight,
smj.isSkewJoin)
}

private def replaceSortMergeJoinChild(plan: SparkPlan): SparkPlan = {
plan match {
case sort: SortExecTransformer =>
sort.child match {
case hashShuffle: ColumnarShuffleExchangeExec =>
// drop sort node, return the shuffle node direclty
hashShuffle.withNewChildren(hashShuffle.children.map(visitPlan))
case columnarPlan: TransformSupport =>
visitPlan(columnarPlan)
case _ =>
// other cases that we don't know
logInfo(s"Expected ColumnarShuffleExchangeExec, got ${sort.child.getClass}")
null
}
case smj: SortMergeJoinExec =>
val newChild = replaceSortMergeJoinWithHashJoin(smj)
if (newChild.isInstanceOf[SortMergeJoinExec]) {
null
} else {
newChild
}
case _: TransformSupport => visitPlan(plan)
case _ =>
logInfo(s"Expected Columnar node, got ${plan.getClass}")
null
}
}
}

0 comments on commit 07b41fb

Please sign in to comment.