From 2cdeed7b330078d3acfb42e745e3d9e8714a96ae Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Fri, 7 Jun 2024 19:39:39 +0800
Subject: [PATCH] fixup

---
 .../expression/ExpressionConverter.scala     | 19 ++------
 .../ColumnarSubqueryBroadcastExec.scala      | 48 +++++++------
 2 files changed, 23 insertions(+), 44 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 6a61db2cd4ae..97d986259106 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -700,20 +700,11 @@ object ExpressionConverter extends SQLConfHelper with Logging {
       val newChild = from.child match {
         case exchange: BroadcastExchangeExec =>
           toColumnarBroadcastExchange(exchange)
-        case aqe: AdaptiveSparkPlanExec if !aqe.supportsColumnar =>
-          // ColumnarSubqueryBroadcastExec strictly requires for
-          // child with columnar output. AQE with supportsColumnar=false
-          // may produce row plan that will fail the subsequent processing.
-          // Thus we replace it with supportsColumnar=true to make sure
-          // columnar output is emitted from AQE.
-          val newAqe = AdaptiveSparkPlanExec(
-            aqe.inputPlan,
-            aqe.context,
-            aqe.preprocessingRules,
-            aqe.isSubquery,
-            supportsColumnar = true)
-          newAqe.copyTagsFrom(aqe)
-          newAqe
+        case aqe: AdaptiveSparkPlanExec =>
+          // Keep the child as-is if it is AQE, even if its supportsColumnar == false.
+          // ColumnarSubqueryBroadcastExec is compatible with both row-based
+          // and columnar inputs.
+          aqe
         case other => other
       }
       val out = ColumnarSubqueryBroadcastExec(from.name, from.index, from.buildKeys, newChild)
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
index a74d428fd452..eaa180516c49 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
@@ -24,8 +24,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
-import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelation, HashJoin, LongHashedRelation}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.util.ThreadUtils
@@ -75,35 +73,25 @@ case class ColumnarSubqueryBroadcastExec(
     SQLExecution.withExecutionId(session, executionId) {
       val rows = GlutenTimeMetric.millis(longMetric("collectTime")) {
         _ =>
-          val exchangeChild = child match {
-            case exec: ReusedExchangeExec =>
-              exec.child
-            case _ =>
-              child
-          }
-          if (
-            exchangeChild.isInstanceOf[ColumnarBroadcastExchangeExec] ||
-            exchangeChild.isInstanceOf[AdaptiveSparkPlanExec]
-          ) {
-            // transform broadcasted columnar value to Array[InternalRow] by key
-            exchangeChild
-              .executeBroadcast[BuildSideRelation]
-              .value
-              .transform(buildKeys(index))
-              .distinct
-          } else {
-            val broadcastRelation = exchangeChild.executeBroadcast[HashedRelation]().value
-            val (iter, expr) = if (broadcastRelation.isInstanceOf[LongHashedRelation]) {
-              (broadcastRelation.keys(), HashJoin.extractKeyExprAt(buildKeys, index))
-            } else {
-              (
-                broadcastRelation.keys(),
-                BoundReference(index, buildKeys(index).dataType, buildKeys(index).nullable))
-            }
+          val relation = child.executeBroadcast[Any]().value
+          relation match {
+            case b: BuildSideRelation =>
+              b.transform(buildKeys(index)).distinct
+            case h: HashedRelation =>
+              val (iter, expr) = if (h.isInstanceOf[LongHashedRelation]) {
+                (h.keys(), HashJoin.extractKeyExprAt(buildKeys, index))
+              } else {
+                (
+                  h.keys(),
+                  BoundReference(index, buildKeys(index).dataType, buildKeys(index).nullable))
+              }
 
-            val proj = UnsafeProjection.create(expr)
-            val keyIter = iter.map(proj).map(_.copy())
-            keyIter.toArray[InternalRow].distinct
+              val proj = UnsafeProjection.create(expr)
+              val keyIter = iter.map(proj).map(_.copy())
+              keyIter.toArray[InternalRow].distinct
+            case other =>
+              throw new UnsupportedOperationException(
+                s"Unrecognizable broadcast relation: $other")
           }
       }
       val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes).sum
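
Note for reviewers: the key idea in the second hunk is to dispatch on the
runtime type of the materialized broadcast value, rather than inspecting the
child plan node as the old ReusedExchangeExec / ColumnarBroadcastExchangeExec /
AdaptiveSparkPlanExec checks did. Below is a minimal, self-contained Scala
sketch of that dispatch pattern. All names in it (Relation, ColumnarRelation,
RowRelation, collectKeys, BroadcastDispatchSketch) are hypothetical stand-ins
for illustration, not Gluten or Spark APIs.

// Sketch only: branch on what the broadcast produced, not on who produced it.
sealed trait Relation
final case class ColumnarRelation(keys: Seq[Long]) extends Relation
final case class RowRelation(keys: Seq[Long]) extends Relation

object BroadcastDispatchSketch {
  // Any producer is acceptable; only the value's runtime type matters.
  def collectKeys(relation: Any): Seq[Long] = relation match {
    case c: ColumnarRelation => c.keys.distinct
    case r: RowRelation => r.keys.distinct
    case other =>
      throw new UnsupportedOperationException(
        s"Unrecognizable broadcast relation: $other")
  }

  def main(args: Array[String]): Unit = {
    println(collectKeys(ColumnarRelation(Seq(1L, 1L, 2L)))) // List(1, 2)
    println(collectKeys(RowRelation(Seq(3L, 4L, 4L))))      // List(3, 4)
  }
}

This is also why the first hunk can drop the supportsColumnar rewrite in
ExpressionConverter: once the consumer accepts both relation shapes, the AQE
child no longer needs to be forced into columnar mode.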