diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
index 9dcb7ee3c41d..d6511f7a4a29 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
@@ -22,13 +22,17 @@ import org.apache.gluten.substrait.expression.ExpressionNode
 import org.apache.gluten.substrait.plan.{PlanBuilder, PlanNode}
 import org.apache.gluten.substrait.rel.RelBuilder
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression}
 
 import com.google.common.collect.Lists
 
+import java.util
+
+import scala.collection.JavaConverters._
+
 object PlanNodesUtil {
 
-  def genProjectionsPlanNode(key: Expression, output: Seq[Attribute]): PlanNode = {
+  def genProjectionsPlanNode(key: Seq[Expression], output: Seq[Attribute]): PlanNode = {
     val context = new SubstraitContext
     var operatorId = context.nextOperatorId("ClickHouseBuildSideRelationReadIter")
 
@@ -36,41 +40,36 @@ object PlanNodesUtil {
     val nameList = ConverterUtils.collectAttributeNamesWithExprId(output)
     val readRel = RelBuilder.makeReadRelForInputIterator(typeList, nameList, context, operatorId)
 
-    // replace attribute to BoundRefernce according to the output
-    val newBoundRefKey = key.transformDown {
-      case expression: AttributeReference =>
-        val columnInOutput = output.zipWithIndex.filter {
-          p: (Attribute, Int) => p._1.exprId == expression.exprId || p._1.name == expression.name
-        }
-        if (columnInOutput.isEmpty) {
-          throw new IllegalStateException(
-            s"Key $expression not found from build side relation output: $output")
-        }
-        if (columnInOutput.size != 1) {
-          throw new IllegalStateException(
-            s"More than one key $expression found from build side relation output: $output")
-        }
-        val boundReference = columnInOutput.head
-        BoundReference(boundReference._2, boundReference._1.dataType, boundReference._1.nullable)
-      case other => other
-    }
-
     // project
     operatorId = context.nextOperatorId("ClickHouseBuildSideRelationProjection")
     val args = context.registeredFunction
 
     val columnarProjExpr = ExpressionConverter
-      .replaceWithExpressionTransformer(newBoundRefKey, attributeSeq = output)
+      .replaceWithExpressionTransformer(key, attributeSeq = output)
 
     val projExprNodeList = new java.util.ArrayList[ExpressionNode]()
-    projExprNodeList.add(columnarProjExpr.doTransform(args))
+    columnarProjExpr.foreach(e => projExprNodeList.add(e.doTransform(args)))
 
     PlanBuilder.makePlan(
       context,
       Lists.newArrayList(
         RelBuilder.makeProjectRel(readRel, projExprNodeList, context, operatorId, output.size)),
-      Lists.newArrayList(
-        ConverterUtils.genColumnNameWithExprId(ConverterUtils.getAttrFromExpr(key)))
+      Lists.newArrayList(genColumnNameWithExprId(key, output))
     )
   }
+
+  private def genColumnNameWithExprId(
+      key: Seq[Expression],
+      output: Seq[Attribute]): util.List[String] = {
+    key
+      .map {
+        k =>
+          val reference = k.collectFirst { case BoundReference(ordinal, _, _) => output(ordinal) }
+          assert(reference.isDefined)
+          reference.get
+      }
+      .map(ConverterUtils.genColumnNameWithExprId)
+      .toList
+      .asJava
+  }
 }
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/joins/ClickHouseBuildSideRelation.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/joins/ClickHouseBuildSideRelation.scala
index 92887f16d70a..668525ba0a40 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/joins/ClickHouseBuildSideRelation.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/joins/ClickHouseBuildSideRelation.scala
@@ -22,8 +22,8 @@ import org.apache.gluten.vectorized._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode}
 import org.apache.spark.sql.execution.utils.CHExecUtil
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage.CHShuffleReadStreamFactory
@@ -72,7 +72,7 @@ case class ClickHouseBuildSideRelation(
   }
 
   /**
-   * Transform columnar broadcast value to Array[InternalRow] by key and distinct.
+   * Transform columnar broadcast value to Array[InternalRow] by key.
    *
    * @return
    */
@@ -80,10 +80,18 @@ case class ClickHouseBuildSideRelation(
     // native block reader
     val blockReader = new CHStreamReader(CHShuffleReadStreamFactory.create(batches, true))
     val broadCastIter: Iterator[ColumnarBatch] = IteratorUtil.createBatchIterator(blockReader)
+
+    val transformProjections = mode match {
+      case HashedRelationBroadcastMode(k, _) => k
+      case IdentityBroadcastMode => output
+    }
+
     // Expression compute, return block iterator
     val expressionEval = new SimpleExpressionEval(
       new ColumnarNativeIterator(broadCastIter.asJava),
-      PlanNodesUtil.genProjectionsPlanNode(key, output))
+      PlanNodesUtil.genProjectionsPlanNode(transformProjections, output))
+
+    val proj = UnsafeProjection.create(Seq(key))
 
     try {
       // convert columnar to row
@@ -95,6 +103,7 @@ case class ClickHouseBuildSideRelation(
         } else {
           CHExecUtil
             .getRowIterFromSparkRowInfo(block, batch.numColumns(), batch.numRows())
+            .map(proj)
             .map(row => row.copy())
         }
       }.toArray
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 2df2e2718eaa..d837ac423407 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -632,7 +632,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     }
     numOutputRows += serialized.map(_.getNumRows).sum
     dataSize += rawSize
-    ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized))
+    ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode)
   }
 
   override def doCanonicalizeForBroadcastMode(mode: BroadcastMode): BroadcastMode = {
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
index 11a8cc980904..c5323d4f8d50 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
@@ -106,7 +106,8 @@ object BroadcastUtils {
         }
         ColumnarBuildSideRelation(
           SparkShimLoader.getSparkShims.attributesFromStruct(schema),
-          serialized)
+          serialized,
+          mode)
       }
       // Rebroadcast Velox relation.
       context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
@@ -124,7 +125,8 @@ object BroadcastUtils {
         }
         ColumnarBuildSideRelation(
           SparkShimLoader.getSparkShims.attributesFromStruct(schema),
-          serialized)
+          serialized,
+          mode)
       }
       // Rebroadcast Velox relation.
       context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index fa3d348967d5..977357990c43 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -26,8 +26,11 @@ import org.apache.gluten.utils.ArrowAbiUtil
 import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
 import org.apache.spark.sql.execution.joins.BuildSideRelation
+import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.utils.SparkArrowUtil
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -37,9 +40,19 @@ import org.apache.arrow.c.ArrowSchema
 
 import scala.collection.JavaConverters.asScalaIteratorConverter
 
-case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Array[Byte]])
+case class ColumnarBuildSideRelation(
+    output: Seq[Attribute],
+    batches: Array[Array[Byte]],
+    mode: BroadcastMode)
   extends BuildSideRelation {
 
+  private def transformProjection: UnsafeProjection = {
+    mode match {
+      case HashedRelationBroadcastMode(k, _) => UnsafeProjection.create(k)
+      case IdentityBroadcastMode => UnsafeProjection.create(output, output)
+    }
+  }
+
   override def deserialized: Iterator[ColumnarBatch] = {
     val runtime =
       Runtimes.contextInstance(BackendsApiManager.getBackendName, "BuildSideRelation#deserialized")
@@ -84,8 +97,11 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
   override def asReadOnlyCopy(): ColumnarBuildSideRelation = this
 
   /**
-   * Transform columnar broadcast value to Array[InternalRow] by key and distinct. NOTE: This method
-   * was called in Spark Driver, should manage resources carefully.
+   * Transform columnar broadcast value to Array[InternalRow] by key.
+   *
+   * NOTE:
+   *   - This method is called in the Spark Driver, so resources should be managed carefully.
+   *   - The "key" must already be a bound reference.
    */
   override def transform(key: Expression): Array[InternalRow] = TaskResources.runUnsafe {
     val runtime =
@@ -106,17 +122,7 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
 
     var closed = false
 
-    val exprIds = output.map(_.exprId)
-    val projExpr = key.transformDown {
-      case attr: AttributeReference if !exprIds.contains(attr.exprId) =>
-        val i = output.count(_.name == attr.name)
-        if (i != 1) {
-          throw new IllegalArgumentException(s"Only one attr with the same name is supported: $key")
-        } else {
-          output.find(_.name == attr.name).get
-        }
-    }
-    val proj = UnsafeProjection.create(Seq(projExpr), output)
+    val proj = UnsafeProjection.create(Seq(key))
 
     // Convert columnar to Row.
     val jniWrapper = NativeColumnarToRowJniWrapper.create(runtime)
@@ -178,7 +184,7 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
             rowId += 1
             row
           }
-        }.map(proj).map(_.copy())
+        }.map(transformProjection).map(proj).map(_.copy())
     }
   }
 }
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
index 6275fbb3aa3c..12280cc42aed 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelation, HashJoin, LongHashedRelation}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.types.IntegralType
 import org.apache.spark.util.ThreadUtils
 
 import scala.concurrent.Future
@@ -64,6 +65,14 @@ case class ColumnarSubqueryBroadcastExec(
     copy(name = "native-dpp", buildKeys = keys, child = child.canonicalized)
   }
 
+  // Copied from org.apache.spark.sql.execution.joins.HashJoin#canRewriteAsLongType.
+  // We should keep consistent with it to identify a LongHashedRelation.
+  private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
+    // TODO: support BooleanType, DateType and TimestampType
+    keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
+    keys.map(_.dataType.defaultSize).sum <= 8
+  }
+
   @transient
   private lazy val relationFuture: Future[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
@@ -78,7 +87,13 @@ case class ColumnarSubqueryBroadcastExec(
       relation match {
         case b: BuildSideRelation =>
           // Transform columnar broadcast value to Array[InternalRow] by key.
-          b.transform(buildKeys(index)).distinct
+          if (canRewriteAsLongType(buildKeys)) {
+            b.transform(HashJoin.extractKeyExprAt(buildKeys, index)).distinct
+          } else {
+            b.transform(
+              BoundReference(index, buildKeys(index).dataType, buildKeys(index).nullable))
+              .distinct
+          }
         case h: HashedRelation =>
           val (iter, expr) = if (h.isInstanceOf[LongHashedRelation]) {
             (h.keys(), HashJoin.extractKeyExprAt(buildKeys, index))
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala
index 60f3e2ffd966..e9dbeb560c68 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 trait BuildSideRelation extends Serializable {
@@ -26,11 +27,19 @@ trait BuildSideRelation extends Serializable {
   def deserialized: Iterator[ColumnarBatch]
 
   /**
-   * Transform columnar broadcasted value to Array[InternalRow] by key and distinct.
+   * Transform columnar broadcasted value to Array[InternalRow] by key.
    * @return
    */
  def transform(key: Expression): Array[InternalRow]
 
   /** Returns a read-only copy of this, to be safely used in current thread. */
   def asReadOnlyCopy(): BuildSideRelation
+
+  /**
+   * The broadcast mode associated with this relation. In Gluten the original relation is
+   * broadcast directly, so transforming the relation is a post-processing step.
+   *
+   * Post-processing transforms can use this mode to obtain rows in the desired format.
+   */
+  val mode: BroadcastMode
 }
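
Reviewer note (not part of the patch): with this change the broadcast value always carries the original build-side rows, and key extraction becomes a row-level post-processing step driven by the relation's BroadcastMode. The plain-Scala sketch below shows the intended composition; the object and method names (BuildSideRelationTransformSketch, postProcess, rows) are illustrative only and are not identifiers introduced by the patch.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode}
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode

object BuildSideRelationTransformSketch {
  // Hypothetical standalone version of the row pipeline used by ColumnarBuildSideRelation:
  // 1. project the broadcast rows into the layout the BroadcastMode describes,
  // 2. then extract the DPP key, which must already be a bound reference.
  def postProcess(
      mode: BroadcastMode,
      output: Seq[Attribute],
      key: Expression,
      rows: Iterator[InternalRow]): Iterator[InternalRow] = {
    val toModeLayout = mode match {
      // HashedRelationBroadcastMode carries the build-side join-key expressions.
      case HashedRelationBroadcastMode(keys, _) => UnsafeProjection.create(keys)
      // IdentityBroadcastMode keeps the original output layout unchanged.
      case IdentityBroadcastMode => UnsafeProjection.create(output, output)
    }
    val extractKey = UnsafeProjection.create(Seq(key))
    // Copy each row because UnsafeProjection reuses its output buffer.
    rows.map(toModeLayout).map(extractKey).map(_.copy())
  }
}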
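A second note on the DPP branch in ColumnarSubqueryBroadcastExec: when every build key is integral and the keys pack into at most 8 bytes, the broadcast keys were rewritten into a single long (the LongHashedRelation layout), so the index-th key has to be unpacked with HashJoin.extractKeyExprAt; otherwise the post-projected row keeps one column per key and a plain BoundReference on the key's ordinal is enough. A small, self-contained illustration of the copied predicate follows; the attribute names are made up.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.types._

object CanRewriteAsLongTypeSketch {
  // Same check as the patch copies from HashJoin#canRewriteAsLongType.
  def canRewriteAsLongType(keys: Seq[Expression]): Boolean =
    keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
      keys.map(_.dataType.defaultSize).sum <= 8

  def main(args: Array[String]): Unit = {
    val i1 = AttributeReference("i1", IntegerType)()
    val i2 = AttributeReference("i2", IntegerType)()
    val l = AttributeReference("l", LongType)()
    val s = AttributeReference("s", StringType)()

    println(canRewriteAsLongType(Seq(i1, i2))) // true: 4 + 4 bytes pack into one long
    println(canRewriteAsLongType(Seq(l)))      // true: a single 8-byte key
    println(canRewriteAsLongType(Seq(i1, l)))  // false: 12 bytes do not fit in 8
    println(canRewriteAsLongType(Seq(s)))      // false: non-integral key
  }
}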