diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ReorderJoinTablesRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ReorderJoinTablesRule.scala
index 7db259409ac1..4cedaae25684 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ReorderJoinTablesRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ReorderJoinTablesRule.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive._
 
 case class ReorderJoinTablesRule(session: SparkSession) extends Rule[SparkPlan] with Logging {
-  logError(s"ReorderJoinTablesRule is enabled: ${CHBackendSettings.enableReorderHashJoinTables}")
   override def apply(plan: SparkPlan): SparkPlan = {
     if (CHBackendSettings.enableReorderHashJoinTables) {
       visitPlan(plan)
@@ -59,9 +58,6 @@ case class ReorderJoinTablesRule(session: SparkSession) extends Rule[SparkPlan]
     val threshold = CHBackendSettings.reorderHashJoinTablesThreshold
     val isLeftLarger = leftQueryStageRow.get > rightQueryStageRow.get * threshold
     val isRightLarger = leftQueryStageRow.get * threshold < rightQueryStageRow.get
-    logError(
-      s"xxx isLeftLarger:$isLeftLarger, isRightLarger:$isRightLarger, " +
-        s"buildside:${hashJoin.buildSide}, join type: ${hashJoin.joinType}")
     hashJoin.joinType match {
       case Inner =>
         if (isRightLarger && hashJoin.buildSide == BuildRight) {
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
index 8e9583492255..fc22add2d880 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
@@ -17,14 +17,17 @@
 package org.apache.gluten.execution
 
 import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.optimizer._
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.CoalescedPartitionSpec
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec}
 
 class GlutenClickHouseColumnarShuffleAQESuite
   extends GlutenClickHouseTPCHAbstractSuite
-  with AdaptiveSparkPlanHelper {
+  with AdaptiveSparkPlanHelper
+  with Logging {
 
   override protected val tablesPath: String = basePath + "/tpch-data-ch"
   override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch"
@@ -184,6 +187,25 @@ class GlutenClickHouseColumnarShuffleAQESuite
     spark.sql("insert into t1 select id as a, id as b from range(100000)")
     spark.sql("insert into t1 select id as a, id as b from range(100)")
 
+    def isExpectedJoinNode(plan: SparkPlan, joinType: JoinType, buildSide: BuildSide): Boolean = {
+      plan match {
+        case join: CHShuffledHashJoinExecTransformer =>
+          join.joinType == joinType && join.buildSide == buildSide
+        case _ => false
+      }
+    }
+
+    def collectExpectedJoinNode(
+        plan: SparkPlan,
+        joinType: JoinType,
+        buildSide: BuildSide): Seq[SparkPlan] = {
+      if (isExpectedJoinNode(plan, joinType, buildSide)) {
+        Seq(plan) ++ plan.children.flatMap(collectExpectedJoinNode(_, joinType, buildSide))
+      } else {
+        plan.children.flatMap(collectExpectedJoinNode(_, joinType, buildSide))
+      }
+    }
+
     var sql = """
                 |select * from t2 left join t1 on t1.a = t2.a
                 |""".stripMargin
@@ -192,14 +214,10 @@ class GlutenClickHouseColumnarShuffleAQESuite
       true,
       {
        df =>
-          val nodes = df.queryExecution.executedPlan.collect {
-            case node => node.getClass.getSimpleName
-          }
-          assert(nodes == Seq())
           val joins = df.queryExecution.executedPlan.collect {
-            case joinExec: CHShuffledHashJoinExecTransformer
-                if (joinExec.joinType == RightOuter) && (joinExec.buildSide == BuildRight) =>
-              joinExec
+            case adaptiveNode: AdaptiveSparkPlanExec =>
+              collectExpectedJoinNode(adaptiveNode.executedPlan, RightOuter, BuildRight)
+            case _ => Seq()
           }
           assert(joins.size == 1)
       }
@@ -214,9 +232,9 @@ class GlutenClickHouseColumnarShuffleAQESuite
       {
         df =>
           val joins = df.queryExecution.executedPlan.collect {
-            case joinExec: CHShuffledHashJoinExecTransformer
-                if (joinExec.joinType == LeftOuter) && (joinExec.buildSide == BuildRight) =>
-              joinExec
+            case adaptiveNode: AdaptiveSparkPlanExec =>
+              collectExpectedJoinNode(adaptiveNode.executedPlan, LeftOuter, BuildRight)
+            case _ => Seq()
           }
           assert(joins.size == 1)
       }
@@ -231,9 +249,9 @@ class GlutenClickHouseColumnarShuffleAQESuite
       {
         df =>
           val joins = df.queryExecution.executedPlan.collect {
-            case joinExec: CHShuffledHashJoinExecTransformer
-                if (joinExec.joinType == RightOuter) && (joinExec.buildSide == BuildRight) =>
-              joinExec
+            case adaptiveNode: AdaptiveSparkPlanExec =>
+              collectExpectedJoinNode(adaptiveNode.executedPlan, RightOuter, BuildRight)
+            case _ => Seq()
           }
           assert(joins.size == 1)
       }
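Note on the test change (not part of the patch): under AQE the final physical plan is wrapped in an AdaptiveSparkPlanExec, which is a leaf node that hides its inner executedPlan, so a flat executedPlan.collect never reaches the CHShuffledHashJoinExecTransformer; the tests therefore match the AdaptiveSparkPlanExec node and recurse into its executedPlan via collectExpectedJoinNode. Below is a minimal, self-contained sketch of that collect-and-recurse pattern; PlanNode and isMatch are hypothetical stand-ins for SparkPlan and isExpectedJoinNode, not types from the patch.

object CollectJoinNodeSketch {
  // Toy stand-in for SparkPlan: a named node with children.
  final case class PlanNode(name: String, children: Seq[PlanNode] = Seq.empty)

  // Pre-order traversal: keep the current node if it matches, then recurse
  // into every child so matches nested under a match are also collected.
  def collectMatching(plan: PlanNode)(isMatch: PlanNode => Boolean): Seq[PlanNode] = {
    val below = plan.children.flatMap(collectMatching(_)(isMatch))
    if (isMatch(plan)) plan +: below else below
  }

  def main(args: Array[String]): Unit = {
    val plan = PlanNode(
      "project",
      Seq(PlanNode("shuffledHashJoin", Seq(PlanNode("scan"), PlanNode("shuffledHashJoin")))))
    // Finds both the outer join and the nested one beneath it.
    assert(collectMatching(plan)(_.name == "shuffledHashJoin").size == 2)
  }
}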