[GLUTEN-8115][CORE] Refine the BuildSideRelation transform to support all scenarios (apache#8116)
yikf authored Dec 9, 2024
1 parent 152be37 commit 15f4cde
Showing 7 changed files with 90 additions and 50 deletions.
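At a glance, the diff below does three related things: the projection key handed to BuildSideRelation.transform is now expected to be an already-bound reference (and genProjectionsPlanNode takes a Seq[Expression] instead of a single key), every BuildSideRelation now carries the BroadcastMode it was built with, and transform uses that mode to decide what the native projection should emit (the HashedRelationBroadcastMode keys, or the full output for IdentityBroadcastMode) before the requested key column is extracted. ColumnarSubqueryBroadcastExec is adjusted so dynamic-pruning keys work both for LongHashedRelation-style packed keys and for ordinary columns.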
--------------------------------------------------------------------------------
@@ -22,55 +22,54 @@ import org.apache.gluten.substrait.expression.ExpressionNode
import org.apache.gluten.substrait.plan.{PlanBuilder, PlanNode}
import org.apache.gluten.substrait.rel.RelBuilder

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression}

import com.google.common.collect.Lists

import java.util

import scala.collection.JavaConverters._

object PlanNodesUtil {

def genProjectionsPlanNode(key: Expression, output: Seq[Attribute]): PlanNode = {
def genProjectionsPlanNode(key: Seq[Expression], output: Seq[Attribute]): PlanNode = {
val context = new SubstraitContext

var operatorId = context.nextOperatorId("ClickHouseBuildSideRelationReadIter")
val typeList = ConverterUtils.collectAttributeTypeNodes(output)
val nameList = ConverterUtils.collectAttributeNamesWithExprId(output)
val readRel = RelBuilder.makeReadRelForInputIterator(typeList, nameList, context, operatorId)

// replace attribute with BoundReference according to the output
val newBoundRefKey = key.transformDown {
case expression: AttributeReference =>
val columnInOutput = output.zipWithIndex.filter {
p: (Attribute, Int) => p._1.exprId == expression.exprId || p._1.name == expression.name
}
if (columnInOutput.isEmpty) {
throw new IllegalStateException(
s"Key $expression not found from build side relation output: $output")
}
if (columnInOutput.size != 1) {
throw new IllegalStateException(
s"More than one key $expression found from build side relation output: $output")
}
val boundReference = columnInOutput.head
BoundReference(boundReference._2, boundReference._1.dataType, boundReference._1.nullable)
case other => other
}

// project
operatorId = context.nextOperatorId("ClickHouseBuildSideRelationProjection")
val args = context.registeredFunction

val columnarProjExpr = ExpressionConverter
.replaceWithExpressionTransformer(newBoundRefKey, attributeSeq = output)
.replaceWithExpressionTransformer(key, attributeSeq = output)

val projExprNodeList = new java.util.ArrayList[ExpressionNode]()
projExprNodeList.add(columnarProjExpr.doTransform(args))
columnarProjExpr.foreach(e => projExprNodeList.add(e.doTransform(args)))

PlanBuilder.makePlan(
context,
Lists.newArrayList(
RelBuilder.makeProjectRel(readRel, projExprNodeList, context, operatorId, output.size)),
Lists.newArrayList(
ConverterUtils.genColumnNameWithExprId(ConverterUtils.getAttrFromExpr(key)))
Lists.newArrayList(genColumnNameWithExprId(key, output))
)
}

private def genColumnNameWithExprId(
key: Seq[Expression],
output: Seq[Attribute]): util.List[String] = {
key
.map {
k =>
val reference = k.collectFirst { case BoundReference(ordinal, _, _) => output(ordinal) }
assert(reference.isDefined)
reference.get
}
.map(ConverterUtils.genColumnNameWithExprId)
.toList
.asJava
}
}
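To make the new contract concrete, here is a minimal sketch of how the reworked genProjectionsPlanNode is meant to be called (attribute names, types and ordinals are invented for illustration, not taken from this commit): each key arrives already bound against the build-side output, which is what the new genColumnNameWithExprId helper relies on when it maps BoundReference ordinals back to output attributes.

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression}
import org.apache.spark.sql.types.{LongType, StringType}

// Hypothetical build-side output: (id BIGINT, name STRING).
val output: Seq[Attribute] = Seq(
  AttributeReference("id", LongType, nullable = false)(),
  AttributeReference("name", StringType, nullable = true)())

// The join key is already bound to ordinal 0 of that output, as transform() now guarantees.
val keys: Seq[Expression] = Seq(BoundReference(0, LongType, nullable = false))

// Builds a Substrait plan that reads the broadcast blocks and projects only the key column.
val plan = PlanNodesUtil.genProjectionsPlanNode(keys, output)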
--------------------------------------------------------------------------------
@@ -22,8 +22,8 @@ import org.apache.gluten.vectorized._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode}
import org.apache.spark.sql.execution.utils.CHExecUtil
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.storage.CHShuffleReadStreamFactory
@@ -72,18 +72,26 @@ case class ClickHouseBuildSideRelation(
}

/**
* Transform columnar broadcast value to Array[InternalRow] by key and distinct.
* Transform columnar broadcast value to Array[InternalRow] by key.
*
* @return
*/
override def transform(key: Expression): Array[InternalRow] = {
// native block reader
val blockReader = new CHStreamReader(CHShuffleReadStreamFactory.create(batches, true))
val broadCastIter: Iterator[ColumnarBatch] = IteratorUtil.createBatchIterator(blockReader)

val transformProjections = mode match {
case HashedRelationBroadcastMode(k, _) => k
case IdentityBroadcastMode => output
}

// Expression compute, return block iterator
val expressionEval = new SimpleExpressionEval(
new ColumnarNativeIterator(broadCastIter.asJava),
PlanNodesUtil.genProjectionsPlanNode(key, output))
PlanNodesUtil.genProjectionsPlanNode(transformProjections, output))

val proj = UnsafeProjection.create(Seq(key))

try {
// convert columnar to row
@@ -95,6 +103,7 @@ case class ClickHouseBuildSideRelation(
} else {
CHExecUtil
.getRowIterFromSparkRowInfo(block, batch.numColumns(), batch.numRows())
.map(proj)
.map(row => row.copy())
}
}.toArray
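Reading the ClickHouse transform above: the broadcast mode decides which columns the native SimpleExpressionEval projects out of the broadcast blocks (the join keys for HashedRelationBroadcastMode, or the full output for IdentityBroadcastMode), and the UnsafeProjection built from the already-bound key then picks the requested column out of each produced row. For example, for a relation broadcast with HashedRelationBroadcastMode over a single bigint key, the native step emits one-column blocks and a key of BoundReference(0, LongType, false) simply reads ordinal 0 of every row.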
--------------------------------------------------------------------------------
@@ -632,7 +632,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
}
numOutputRows += serialized.map(_.getNumRows).sum
dataSize += rawSize
ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized))
ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode)
}

override def doCanonicalizeForBroadcastMode(mode: BroadcastMode): BroadcastMode = {
--------------------------------------------------------------------------------
@@ -106,7 +106,8 @@ object BroadcastUtils {
}
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized)
serialized,
mode)
}
// Rebroadcast Velox relation.
context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
@@ -124,7 +125,8 @@
}
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized)
serialized,
mode)
}
// Rebroadcast Velox relation.
context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
--------------------------------------------------------------------------------
@@ -26,8 +26,11 @@ import org.apache.gluten.utils.ArrowAbiUtil
import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.utils.SparkArrowUtil
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -37,9 +40,19 @@ import org.apache.arrow.c.ArrowSchema

import scala.collection.JavaConverters.asScalaIteratorConverter

case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Array[Byte]])
case class ColumnarBuildSideRelation(
output: Seq[Attribute],
batches: Array[Array[Byte]],
mode: BroadcastMode)
extends BuildSideRelation {

private def transformProjection: UnsafeProjection = {
mode match {
case HashedRelationBroadcastMode(k, _) => UnsafeProjection.create(k)
case IdentityBroadcastMode => UnsafeProjection.create(output, output)
}
}

override def deserialized: Iterator[ColumnarBatch] = {
val runtime =
Runtimes.contextInstance(BackendsApiManager.getBackendName, "BuildSideRelation#deserialized")
@@ -84,8 +97,11 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Array[Byte]])
override def asReadOnlyCopy(): ColumnarBuildSideRelation = this

/**
* Transform columnar broadcast value to Array[InternalRow] by key and distinct. NOTE: This method
* was called in Spark Driver, should manage resources carefully.
* Transform columnar broadcast value to Array[InternalRow] by key.
*
* NOTE:
* - This method is called on the Spark driver, so manage resources carefully.
* - The "key" must already be a bound reference.
*/
override def transform(key: Expression): Array[InternalRow] = TaskResources.runUnsafe {
val runtime =
@@ -106,17 +122,7 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Array[Byte]])

var closed = false

val exprIds = output.map(_.exprId)
val projExpr = key.transformDown {
case attr: AttributeReference if !exprIds.contains(attr.exprId) =>
val i = output.count(_.name == attr.name)
if (i != 1) {
throw new IllegalArgumentException(s"Only one attr with the same name is supported: $key")
} else {
output.find(_.name == attr.name).get
}
}
val proj = UnsafeProjection.create(Seq(projExpr), output)
val proj = UnsafeProjection.create(Seq(key))

// Convert columnar to Row.
val jniWrapper = NativeColumnarToRowJniWrapper.create(runtime)
@@ -178,7 +184,7 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Array[Byte]])
rowId += 1
row
}
}.map(proj).map(_.copy())
}.map(transformProjection).map(proj).map(_.copy())
}
}
}
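The Velox-side transform above composes two row-level projections: transformProjection first reshapes each deserialized row according to the broadcast mode, and the already-bound key projection then extracts the requested column. A self-contained sketch of the IdentityBroadcastMode case, using only Spark classes (the schema and sample row are invented):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, UnsafeProjection}
import org.apache.spark.sql.types.LongType

// Hypothetical two-column build side (id, amount), as produced by the native columnar-to-row step.
val output = Seq(
  AttributeReference("id", LongType, nullable = false)(),
  AttributeReference("amount", LongType, nullable = false)())
val row = InternalRow(42L, 100L)

// IdentityBroadcastMode: transformProjection is an identity projection over `output`...
val transformProjection = UnsafeProjection.create(output, output)
// ...and the bound key then selects the wanted ordinal from the projected row.
val proj = UnsafeProjection.create(Seq(BoundReference(0, LongType, nullable = false)))
val keyRow = proj(transformProjection(row)).copy() // an UnsafeRow holding just 42L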
--------------------------------------------------------------------------------
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelation, HashJoin, LongHashedRelation}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.IntegralType
import org.apache.spark.util.ThreadUtils

import scala.concurrent.Future
@@ -64,6 +65,14 @@ case class ColumnarSubqueryBroadcastExec(
copy(name = "native-dpp", buildKeys = keys, child = child.canonicalized)
}

// Copied from org.apache.spark.sql.execution.joins.HashJoin#canRewriteAsLongType;
// we should keep this consistent with that method to correctly identify a LongHashedRelation.
private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
// TODO: support BooleanType, DateType and TimestampType
keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
keys.map(_.dataType.defaultSize).sum <= 8
}

@transient
private lazy val relationFuture: Future[Array[InternalRow]] = {
// relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
@@ -78,7 +87,13 @@
relation match {
case b: BuildSideRelation =>
// Transform columnar broadcast value to Array[InternalRow] by key.
b.transform(buildKeys(index)).distinct
if (canRewriteAsLongType(buildKeys)) {
b.transform(HashJoin.extractKeyExprAt(buildKeys, index)).distinct
} else {
b.transform(
BoundReference(index, buildKeys(index).dataType, buildKeys(index).nullable))
.distinct
}
case h: HashedRelation =>
val (iter, expr) = if (h.isInstanceOf[LongHashedRelation]) {
(h.keys(), HashJoin.extractKeyExprAt(buildKeys, index))
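The dispatch added above mirrors Spark's own HashJoin handling: when every build key is integral and the key widths sum to at most 8 bytes, the hashed relation stores a packed long key, so the expression handed to transform comes from HashJoin.extractKeyExprAt; otherwise a plain BoundReference at the subquery's key index is enough. A small sketch of the check itself (the key shapes are hypothetical):

import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression}
import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType, StringType}

// Same check as HashJoin#canRewriteAsLongType, which the new code copies.
def canRewriteAsLongType(keys: Seq[Expression]): Boolean =
  keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
    keys.map(_.dataType.defaultSize).sum <= 8

val intKeys: Seq[Expression] = Seq(
  BoundReference(0, IntegerType, nullable = false),
  BoundReference(1, IntegerType, nullable = false))
val mixedKeys: Seq[Expression] = Seq(
  BoundReference(0, StringType, nullable = false),
  BoundReference(1, LongType, nullable = false))

canRewriteAsLongType(intKeys)   // true: two 4-byte ints fit into one packed long
canRewriteAsLongType(mixedKeys) // false: fall back to BoundReference(index, dataType, nullable)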
--------------------------------------------------------------------------------
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.vectorized.ColumnarBatch

trait BuildSideRelation extends Serializable {
@@ -26,11 +27,19 @@ trait BuildSideRelation extends Serializable {
def deserialized: Iterator[ColumnarBatch]

/**
* Transform columnar broadcasted value to Array[InternalRow] by key and distinct.
* Transform columnar broadcasted value to Array[InternalRow] by key.
* @return
*/
def transform(key: Expression): Array[InternalRow]

/** Returns a read-only copy of this, to be safely used in current thread. */
def asReadOnlyCopy(): BuildSideRelation

/**
* The broadcast mode associated with this relation. Gluten broadcasts the original relation
* directly, so transforming the relation is a post-processing step over the broadcast value.
*
* Post-processing transforms can use this mode to produce rows in the desired format.
*/
val mode: BroadcastMode
}
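With mode on the trait, both backends carry the broadcast mode from the exchange into the relation: VeloxSparkPlanExecApi and BroadcastUtils now pass it when constructing or re-broadcasting ColumnarBuildSideRelation, and the ClickHouse and Velox transform implementations match on it to choose between the hashed-relation key projection and the identity projection over the full output.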
