feat: Support Broadcast HashJoin #211

Merged (16 commits) on Mar 26, 2024
51 changes: 51 additions & 0 deletions common/src/main/java/org/apache/comet/CometArrowStreamWriter.java
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet;

import java.io.IOException;
import java.nio.channels.WritableByteChannel;

import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.compression.NoCompressionCodec;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;

/**
 * A custom `ArrowStreamWriter` that allows writing batches from different roots to the same
 * stream. Arrow's `ArrowStreamWriter` cannot change the root after initialization.
 */
public class CometArrowStreamWriter extends ArrowStreamWriter {
public CometArrowStreamWriter(
VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
super(root, provider, out);
}

public void writeMoreBatch(VectorSchemaRoot root) throws IOException {
VectorUnloader unloader =
new VectorUnloader(
root, /*includeNullCount*/ true, NoCompressionCodec.INSTANCE, /*alignBuffers*/ true);

try (ArrowRecordBatch batch = unloader.getRecordBatch()) {
writeRecordBatch(batch);
}
}
}
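
As a self-contained illustration of why this class exists (a sketch only, not part of the PR; the allocator, vector name, and sample values are made up), two batches backed by two different VectorSchemaRoots can be written to a single Arrow IPC stream:

import java.io.ByteArrayOutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{FieldVector, IntVector, VectorSchemaRoot}
import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider

import org.apache.comet.CometArrowStreamWriter

val allocator = new RootAllocator(Long.MaxValue)

// Build a single-column root holding the given values.
def makeRoot(values: Seq[Int]): VectorSchemaRoot = {
  val vector = new IntVector("v", allocator)
  vector.allocateNew(values.length)
  values.zipWithIndex.foreach { case (v, i) => vector.set(i, v) }
  vector.setValueCount(values.length)
  val root = new VectorSchemaRoot(Seq[FieldVector](vector).asJava)
  root.setRowCount(values.length)
  root
}

val first = makeRoot(Seq(1, 2, 3))
val second = makeRoot(Seq(4, 5)) // a different root with the same schema

val out = new ByteArrayOutputStream()
val writer =
  new CometArrowStreamWriter(first, new MapDictionaryProvider(), Channels.newChannel(out))
writer.start()
writer.writeBatch()           // serializes the batch currently held by `first`
writer.writeMoreBatch(second) // a stock ArrowStreamWriter could not switch roots like this
writer.end()
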
18 changes: 9 additions & 9 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -20,6 +20,7 @@
package org.apache.comet.vector

import java.io.OutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -28,10 +29,11 @@ import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictiona
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.SparkException
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.CometArrowStreamWriter

class NativeUtil {
private val allocator = new RootAllocator(Long.MaxValue)
private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider
@@ -46,29 +48,27 @@ class NativeUtil {
* the output stream
*/
def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = {
var schemaRoot: Option[VectorSchemaRoot] = None
var writer: Option[ArrowStreamWriter] = None
var writer: Option[CometArrowStreamWriter] = None
var rowCount = 0

batches.foreach { batch =>
val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
val root = schemaRoot.getOrElse(new VectorSchemaRoot(fieldVectors.asJava))
val root = new VectorSchemaRoot(fieldVectors.asJava)
val provider = batchProviderOpt.getOrElse(dictionaryProvider)
Review thread on this line:

@advancedxy (Contributor, Mar 22, 2024): One related question: what if incoming batches have different dictionary providers?

Member (author): If a batch has its own provider, it should be returned in batchProviderOpt?

Contributor: But the writer is reused. Once the writer is created, a new dictionary provider (if it differs from the previous one) from later batches is never used or written?

Member (author): Oh I see. I suppose the dictionary provider is the same across batches. That seems to be the reason the dictionary provider exists, i.e. to store dictionary values for arrays/batches.

Contributor: Hmm. It seems getBatchFieldVectors only checks for the same dictionary provider across arrays, but not across batches. Maybe we should add that too? Anyway, it is out of this PR's scope; maybe a separate issue can track it.
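
To make the suggestion above concrete, a cross-batch guard could look roughly like this (a sketch only, not part of this PR; the class name, method name, error message, and the choice of identity comparison are all assumptions):

import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.spark.SparkException

// Sketch: remember the provider chosen for the first batch of a stream and reject
// later batches that carry a different one, since dictionaries are written only
// once when the Arrow IPC stream starts.
class DictionaryProviderGuard {
  private var streamProvider: Option[DictionaryProvider] = None

  def checkOrSet(provider: DictionaryProvider): Unit = streamProvider match {
    case Some(existing) if !(existing eq provider) =>
      throw new SparkException(
        "All batches serialized to one stream must use the same dictionary provider")
    case None => streamProvider = Some(provider)
    case _ => // same provider as before, nothing to do
  }
}

// Inside serializeBatches, right after `provider` is resolved for each batch:
//   guard.checkOrSet(provider)
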


if (writer.isEmpty) {
writer = Some(new ArrowStreamWriter(root, provider, out))
writer = Some(new CometArrowStreamWriter(root, provider, Channels.newChannel(out)))
writer.get.start()
writer.get.writeBatch()
} else {
writer.get.writeMoreBatch(root)
}
writer.get.writeBatch()

root.clear()
schemaRoot = Some(root)

rowCount += batch.numRows()
}

writer.map(_.end())
schemaRoot.map(_.close())

rowCount
}
@@ -23,7 +23,7 @@ import java.nio.channels.ReadableByteChannel

import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.vector.StreamReader
import org.apache.comet.vector._

class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] {

@@ -36,6 +36,13 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
return true
}

// Release the previous batch.
// If it is not released, when closing the reader, arrow library will complain about
// memory leak.
if (currentBatch != null) {
currentBatch.close()
}

Comment on lines +39 to +45:

Member (author): We need to release the batch before loading the next one, because ArrowStreamReader loads data into the same vectors of the root internally. After loading the next batch, close would release the just-loaded batch instead of the previous one.

Member: Is this related to the memory leak we saw?

Contributor: This sounds like a data corruption problem. If the just-loaded batch is closed/released, wouldn't the just-loaded ColumnarBatch be corrupted? Yet the CI passed without any issue previously. When working on #206, I also found it inconvenient to use Arrow Java's memory API; it requires extra caution to allocate and release ArrowBuf correctly.

Member (author): It is not related to that memory leak, although I suspected it before too. For shuffle, a channel only contains one batch, so ArrowReaderIterator doesn't hit this issue.

Member (author): Due to #211 (comment), this issue was not exposed before. I find the Arrow Java API hard to use and somewhat counter-intuitive, especially compared with arrow-rs.

Member: Yes, I feel the same pain when using Java Arrow. I think in the long term we'd better switch away from it. It should be relatively easy, except for the Java Arrow Flight feature.
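
To make the root-reuse behavior described above concrete, here is a minimal sketch (not from this PR; `process` is a placeholder callback) of how ArrowStreamReader recycles a single VectorSchemaRoot across loadNextBatch calls, which is why the previous ColumnarBatch has to be closed before the next load rather than after it:

import java.nio.channels.ReadableByteChannel

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamReader

// Sketch: ArrowStreamReader keeps one VectorSchemaRoot for the whole stream and
// reloads it in place on every loadNextBatch() call.
def readAll(channel: ReadableByteChannel)(process: VectorSchemaRoot => Unit): Unit = {
  val allocator = new RootAllocator(Long.MaxValue)
  val reader = new ArrowStreamReader(channel, allocator)
  try {
    val root = reader.getVectorSchemaRoot // the one and only root for this stream
    while (reader.loadNextBatch()) {
      // `root` now holds the current batch; the buffers of the previous batch were
      // replaced by this load, so anything still wrapping them is stale and must
      // have been closed before this point.
      process(root)
    }
  } finally {
    reader.close()
    allocator.close()
  }
}
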

batch = nextBatch()
if (batch.isEmpty) {
return false
@@ -50,13 +57,6 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna

val nextBatch = batch.get

// Release the previous batch.
// If it is not released, when closing the reader, arrow library will complain about
// memory leak.
if (currentBatch != null) {
currentBatch.close()
}

currentBatch = nextBatch
batch = None
currentBatch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -335,6 +335,26 @@ class CometSparkSessionExtensions
op
}

case op: SortMergeJoinExec
if isCometOperatorEnabled(conf, "sort_merge_join") &&
op.children.forall(isCometNative(_)) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometSortMergeJoinExec(
nativeOp,
op,
op.leftKeys,
op.rightKeys,
op.joinType,
op.condition,
op.left,
op.right,
SerializedPlan(None))
case None =>
op
}

case op: ShuffledHashJoinExec
if isCometOperatorEnabled(conf, "hash_join") &&
op.children.forall(isCometNative(_)) =>
@@ -356,19 +376,20 @@
op
}

case op: SortMergeJoinExec
if isCometOperatorEnabled(conf, "sort_merge_join") &&
case op: BroadcastHashJoinExec
if isCometOperatorEnabled(conf, "broadcast_hash_join") &&
op.children.forall(isCometNative(_)) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometSortMergeJoinExec(
CometBroadcastHashJoinExec(
nativeOp,
op,
op.leftKeys,
op.rightKeys,
op.joinType,
op.condition,
op.buildSide,
op.left,
op.right,
SerializedPlan(None))
@@ -411,6 +432,16 @@ class CometSparkSessionExtensions
u
}

// For AQE broadcast stage on a Comet broadcast exchange
case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) =>
val newOp = transform1(s)
newOp match {
case Some(nativeOp) =>
CometSinkPlaceHolder(nativeOp, s, s)
case None =>
s
}

case b: BroadcastExchangeExec
if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") &&
isCometBroadCastEnabled(conf) =>
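
For orientation, the plan shape this new rule targets comes from an ordinary broadcast join. The snippet below is illustrative only and not from the PR: `factDf` and `dimDf` are placeholder DataFrames, `spark` is an active SparkSession, and the Comet config keys are guessed from the operator names passed to isCometOperatorEnabled above, so check them against CometConf before relying on them.

import org.apache.spark.sql.functions.broadcast

// Assumed config keys (derived from the operator names checked above); the real
// keys are defined in CometConf and may be spelled differently.
spark.conf.set("spark.comet.exec.enabled", "true")
spark.conf.set("spark.comet.exec.broadcast_hash_join.enabled", "true")
spark.conf.set("spark.comet.exec.broadcastExchangeExec.enabled", "true")

// Joining a fact table with a small (or explicitly hinted) dimension table makes Spark
// plan a BroadcastHashJoinExec; when both children are Comet-native, the rule above
// rewrites it into a CometBroadcastHashJoinExec on top of a CometBroadcastExchangeExec.
val joined = factDf.join(broadcast(dimDf), Seq("key"))
joined.explain()
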
18 changes: 14 additions & 4 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,14 +29,14 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1915,7 +1915,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
}
}

case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf, "hash_join") =>
case join: HashJoin =>
// `HashJoin` has only two implementations in Spark, but we check the type of the join to
// make sure we are handling the correct join type.
if (!(isCometOperatorEnabled(op.conf, "hash_join") &&
join.isInstanceOf[ShuffledHashJoinExec]) &&
!(isCometOperatorEnabled(op.conf, "broadcast_hash_join") &&
join.isInstanceOf[BroadcastHashJoinExec])) {
return None
}

if (join.buildSide == BuildRight) {
// DataFusion HashJoin assumes build side is always left.
// TODO: support BuildRight
@@ -2063,6 +2072,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
case _: TakeOrderedAndProjectExec => true
case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
case _: BroadcastExchangeExec => true
case _ => false
}
@@ -257,9 +257,28 @@ class CometBatchRDD(
}

override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
val partition = split.asInstanceOf[CometBatchPartition]
new Iterator[ColumnarBatch] {
val partition = split.asInstanceOf[CometBatchPartition]
val batchesIter = partition.value.value.map(CometExec.decodeBatches(_)).toIterator
var iter: Iterator[ColumnarBatch] = null

override def hasNext: Boolean = {
if (iter != null) {
if (iter.hasNext) {
return true
}
}
if (batchesIter.hasNext) {
iter = batchesIter.next()
return iter.hasNext
}
false
}

partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator
override def next(): ColumnarBatch = {
iter.next()
}
}
}
}

46 changes: 42 additions & 4 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
@@ -269,7 +269,8 @@ abstract class CometNativeExec extends CometExec {
plan match {
case _: CometScanExec | _: CometBatchScanExec | _: ShuffleQueryStageExec |
_: AQEShuffleReadExec | _: CometShuffleExchangeExec | _: CometUnionExec |
_: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec =>
_: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec |
_: CometBroadcastExchangeExec | _: BroadcastQueryStageExec =>
func(plan)
case _: CometPlan =>
// Other Comet operators, continue to traverse the tree.
@@ -588,6 +589,41 @@ case class CometHashAggregateExec(
Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child)
}

case class CometSortMergeJoinExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
override val left: SparkPlan,
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {
override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)

override def stringArgs: Iterator[Any] =
Iterator(leftKeys, rightKeys, joinType, condition, left, right)

override def equals(obj: Any): Boolean = {
obj match {
case other: CometSortMergeJoinExec =>
this.leftKeys == other.leftKeys &&
this.rightKeys == other.rightKeys &&
this.condition == other.condition &&
this.left == other.left &&
this.right == other.right &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int =
Objects.hashCode(leftKeys, rightKeys, condition, left, right)
}

case class CometHashJoinExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
@@ -625,13 +661,14 @@ case class CometHashJoinExec(
Objects.hashCode(leftKeys, rightKeys, condition, left, right)
}

case class CometSortMergeJoinExec(
case class CometBroadcastHashJoinExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
buildSide: BuildSide,
override val left: SparkPlan,
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
@@ -644,10 +681,11 @@ case class CometSortMergeJoinExec(

override def equals(obj: Any): Boolean = {
obj match {
case other: CometSortMergeJoinExec =>
case other: CometBroadcastHashJoinExec =>
this.leftKeys == other.leftKeys &&
this.rightKeys == other.rightKeys &&
this.condition == other.condition &&
this.buildSide == other.buildSide &&
this.left == other.left &&
this.right == other.right &&
this.serializedPlanOpt == other.serializedPlanOpt