feat: Support BroadcastHashJoin
viirya committed Mar 16, 2024
1 parent 895160b commit 1b50512
Showing 9 changed files with 238 additions and 29 deletions.
51 changes: 51 additions & 0 deletions common/src/main/java/org/apache/comet/CometArrowStreamWriter.java
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet;

import java.io.IOException;
import java.nio.channels.WritableByteChannel;

import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.compression.NoCompressionCodec;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;

/**
* A custom `ArrowStreamWriter` that allows writing batches from different roots to the same stream.
* Arrow's `ArrowStreamWriter` cannot change the root after initialization.
*/
public class CometArrowStreamWriter extends ArrowStreamWriter {
public CometArrowStreamWriter(
VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
super(root, provider, out);
}

public void writeMoreBatch(VectorSchemaRoot root) throws IOException {
VectorUnloader unloader =
new VectorUnloader(
root, /*includeNullCount*/ true, NoCompressionCodec.INSTANCE, /*alignBuffers*/ true);

try (ArrowRecordBatch batch = unloader.getRecordBatch()) {
writeRecordBatch(batch);
}
}
}
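
Below is a minimal usage sketch (not part of this commit) of the writer: the batch held by the root passed to the constructor is written first, and batches living in other `VectorSchemaRoot` instances with the same schema go through `writeMoreBatch`. The helper name `writeTwoRoots` and the roots/provider it receives are illustrative assumptions.

import java.io.ByteArrayOutputStream
import java.nio.channels.Channels

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.dictionary.DictionaryProvider

import org.apache.comet.CometArrowStreamWriter

// Illustrative only: root1 and root2 hold batches with the same schema but are
// distinct VectorSchemaRoot instances, which a plain ArrowStreamWriter cannot handle.
def writeTwoRoots(
    root1: VectorSchemaRoot,
    root2: VectorSchemaRoot,
    provider: DictionaryProvider): Array[Byte] = {
  val out = new ByteArrayOutputStream()
  val writer = new CometArrowStreamWriter(root1, provider, Channels.newChannel(out))
  writer.start()               // writes the stream schema taken from root1
  writer.writeBatch()          // writes the batch currently loaded into root1
  writer.writeMoreBatch(root2) // writes a batch from a different root
  writer.end()
  out.toByteArray
}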
18 changes: 9 additions & 9 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -20,6 +20,7 @@
package org.apache.comet.vector

import java.io.OutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -28,10 +29,11 @@ import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictiona
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.SparkException
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.CometArrowStreamWriter

class NativeUtil {
private val allocator = new RootAllocator(Long.MaxValue)
private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider
@@ -46,29 +48,27 @@ class NativeUtil {
* the output stream
*/
def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = {
var schemaRoot: Option[VectorSchemaRoot] = None
var writer: Option[ArrowStreamWriter] = None
var writer: Option[CometArrowStreamWriter] = None
var rowCount = 0

batches.foreach { batch =>
val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
val root = schemaRoot.getOrElse(new VectorSchemaRoot(fieldVectors.asJava))
val root = new VectorSchemaRoot(fieldVectors.asJava)
val provider = batchProviderOpt.getOrElse(dictionaryProvider)

if (writer.isEmpty) {
writer = Some(new ArrowStreamWriter(root, provider, out))
writer = Some(new CometArrowStreamWriter(root, provider, Channels.newChannel(out)))
writer.get.start()
writer.get.writeBatch()
} else {
writer.get.writeMoreBatch(root)
}
writer.get.writeBatch()

root.clear()
schemaRoot = Some(root)

rowCount += batch.numRows()
}

writer.map(_.end())
schemaRoot.map(_.close())

rowCount
}
@@ -23,7 +23,7 @@ import java.nio.channels.ReadableByteChannel

import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.vector.StreamReader
import org.apache.comet.vector._

class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] {

@@ -36,6 +36,13 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
return true
}

// Release the previous batch.
// If it is not released, when closing the reader, arrow library will complain about
// memory leak.
if (currentBatch != null) {
currentBatch.close()
}

batch = nextBatch()
if (batch.isEmpty) {
return false
@@ -50,13 +57,6 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna

val nextBatch = batch.get

// Release the previous batch.
// If it is not released, when closing the reader, arrow library will complain about
// memory leak.
if (currentBatch != null) {
currentBatch.close()
}

currentBatch = nextBatch
batch = None
currentBatch
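
To show how the serialization and deserialization sides above fit together, here is a hedged round-trip sketch (not code from this commit): `NativeUtil.serializeBatches` writes an Arrow IPC stream, and `ArrowReaderIterator` reads it back. The `roundTrip` helper, the in-memory byte buffer, and the incoming batch iterator are assumptions for illustration.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.channels.Channels

import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.vector.NativeUtil

// Illustrative round trip: serialize the batches into a byte array, then decode them.
def roundTrip(batches: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
  val out = new ByteArrayOutputStream()
  val totalRows = new NativeUtil().serializeBatches(batches, out) // total rows written
  val channel = Channels.newChannel(new ByteArrayInputStream(out.toByteArray))
  new ArrowReaderIterator(channel)
}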
2 changes: 1 addition & 1 deletion core/src/execution/operators/scan.rs
@@ -26,7 +26,7 @@ use futures::Stream;
use itertools::Itertools;

use arrow::compute::{cast_with_options, CastOptions};
use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
use arrow_array::{make_array, Array, ArrayRef, RecordBatch, RecordBatchOptions};
use arrow_data::ArrayData;
use arrow_schema::{DataType, Field, Schema, SchemaRef};

@@ -31,14 +31,14 @@ import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -359,6 +359,27 @@ class CometSparkSessionExtensions
op
}

case op: BroadcastHashJoinExec
if isCometOperatorEnabled(conf, "broadcast_hash_join") &&
op.children.forall(isCometNative(_)) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometBroadcastHashJoinExec(
nativeOp,
op,
op.leftKeys,
op.rightKeys,
op.joinType,
op.condition,
op.buildSide,
op.left,
op.right,
SerializedPlan(None))
case None =>
op
}

case c @ CoalesceExec(numPartitions, child)
if isCometOperatorEnabled(conf, "coalesce")
&& isCometNative(child) =>
@@ -394,6 +415,16 @@ class CometSparkSessionExtensions
u
}

// For AQE broadcast stage on a Comet broadcast exchange
case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) =>
val newOp = transform1(s)
newOp match {
case Some(nativeOp) =>
CometSinkPlaceHolder(nativeOp, s, s)
case None =>
s
}

case b: BroadcastExchangeExec
if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") &&
isCometBroadCastEnabled(conf) =>
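
For context (not part of the diff), the sketch below shows how the new rule would be exercised from a Spark session. The configuration keys are assumptions inferred from the `isCometOperatorEnabled(conf, "broadcast_hash_join")` and `isCometBroadCastEnabled(conf)` checks above; consult Comet's configuration definitions for the authoritative names.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

// Hypothetical setup; the config key names below are assumptions, not values taken
// from this commit.
val spark = SparkSession.builder().appName("comet-broadcast-hash-join-sketch").getOrCreate()
spark.conf.set("spark.comet.exec.enabled", "true")
spark.conf.set("spark.comet.exec.broadcast.enabled", "true")
spark.conf.set("spark.comet.exec.broadcast_hash_join.enabled", "true")

val large = spark.range(1000000L).toDF("id")
val small = spark.range(100L).toDF("id")
// If the rule applies, the physical plan should contain CometBroadcastHashJoinExec
// in place of Spark's BroadcastHashJoinExec.
large.join(broadcast(small), "id").explain()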
18 changes: 14 additions & 4 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,14 +29,14 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1840,7 +1840,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
}
}

case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf, "hash_join") =>
case join: HashJoin =>
// `HashJoin` has only two implementations in Spark, but we check the type of the join to
// make sure we are handling the correct join type.
if (!(isCometOperatorEnabled(op.conf, "hash_join") &&
join.isInstanceOf[ShuffledHashJoinExec]) &&
!(isCometOperatorEnabled(op.conf, "broadcast_hash_join") &&
join.isInstanceOf[BroadcastHashJoinExec])) {
return None
}

if (join.buildSide == BuildRight) {
// DataFusion HashJoin assumes build side is always left.
// TODO: support BuildRight
@@ -1932,6 +1941,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
case _: TakeOrderedAndProjectExec => true
case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
case _: BroadcastExchangeExec => true
case _ => false
}
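
The negated condition in the `HashJoin` case above is a little hard to read; the illustrative helper below (not part of the commit; the name and boolean parameters are assumptions standing in for the `isCometOperatorEnabled` checks) states the same gate positively: a `HashJoin` is converted only when the flag matching its concrete subclass is enabled, while the build-side check (`BuildRight` unsupported) remains a separate condition.

import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec}

// Illustrative restatement of the gate in QueryPlanSerde.
def isConvertibleHashJoin(
    hashJoinEnabled: Boolean, // stands for isCometOperatorEnabled(conf, "hash_join")
    broadcastHashJoinEnabled: Boolean, // stands for isCometOperatorEnabled(conf, "broadcast_hash_join")
    join: HashJoin): Boolean = join match {
  case _: ShuffledHashJoinExec => hashJoinEnabled
  case _: BroadcastHashJoinExec => broadcastHashJoinEnabled
  case _ => false
}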
@@ -257,9 +257,28 @@ class CometBatchRDD(
}

override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
val partition = split.asInstanceOf[CometBatchPartition]
new Iterator[ColumnarBatch] {
val partition = split.asInstanceOf[CometBatchPartition]
val batchesIter = partition.value.value.map(CometExec.decodeBatches(_)).toIterator
var iter: Iterator[ColumnarBatch] = null

override def hasNext: Boolean = {
if (iter != null) {
if (iter.hasNext) {
return true
}
}
if (batchesIter.hasNext) {
iter = batchesIter.next()
return iter.hasNext
}
false
}

partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator
override def next(): ColumnarBatch = {
iter.next()
}
}
}
}

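
A design note on the `compute` hunk above (illustrative, not from the commit): the old code flat-mapped over the broadcast value, which, for a strict collection, drains every decoded batch iterator up front, while the new hand-written iterator consumes them one batch at a time. Assuming the buffers form an ordinary Scala collection, the same laziness can also be expressed with `Iterator.flatMap`; the `flattenLazily` helper and its parameters are hypothetical.

import org.apache.spark.sql.vectorized.ColumnarBatch

// Illustrative: Iterator.flatMap defers consuming a buffer's batches until the previous
// buffer's batches have been exhausted, matching the hand-rolled iterator above.
def flattenLazily[B](
    buffers: Seq[B], // stands in for partition.value.value
    decode: B => Iterator[ColumnarBatch] // stands in for CometExec.decodeBatches
): Iterator[ColumnarBatch] =
  buffers.iterator.flatMap(decode)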
42 changes: 40 additions & 2 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
@@ -269,7 +269,8 @@ abstract class CometNativeExec extends CometExec {
plan match {
case _: CometScanExec | _: CometBatchScanExec | _: ShuffleQueryStageExec |
_: AQEShuffleReadExec | _: CometShuffleExchangeExec | _: CometUnionExec |
_: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec =>
_: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec |
_: CometBroadcastExchangeExec | _: BroadcastQueryStageExec =>
func(plan)
case _: CometPlan =>
// Other Comet operators, continue to traverse the tree.
@@ -625,6 +626,43 @@ case class CometHashJoinExec(
Objects.hashCode(leftKeys, rightKeys, condition, left, right)
}

case class CometBroadcastHashJoinExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
buildSide: BuildSide,
override val left: SparkPlan,
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {
override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)

override def stringArgs: Iterator[Any] =
Iterator(leftKeys, rightKeys, joinType, condition, left, right)

override def equals(obj: Any): Boolean = {
obj match {
case other: CometBroadcastHashJoinExec =>
this.leftKeys == other.leftKeys &&
this.rightKeys == other.rightKeys &&
this.condition == other.condition &&
this.buildSide == other.buildSide &&
this.left == other.left &&
this.right == other.right &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int =
Objects.hashCode(leftKeys, rightKeys, condition, left, right)
}

case class CometScanWrapper(override val nativeOp: Operator, override val originalPlan: SparkPlan)
extends CometNativeExec
with LeafExecNode {