
Commit 1fbdbc4: [GLUTEN-6067][CH] [Part 2] Support CH backend with Spark3.5 - Prepare for supporting sink transform (#6197)

[CH] [Part 2] Support CH backend with Spark3.5 - Prepare for supporting sink transform

* [Refactor] Remove duplicate code

* Add NativeWriteChecker

* [Prepare to commit] Get getExtendedColumnarPostRules from the Spark shim
baibaichen authored Jun 24, 2024
1 parent f07e348 commit 1fbdbc4
Showing 22 changed files with 1,379 additions and 1,055 deletions.
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.backendsapi.clickhouse
 
-import org.apache.gluten.{GlutenConfig, GlutenNumaBindingInfo}
+import org.apache.gluten.GlutenNumaBindingInfo
 import org.apache.gluten.backendsapi.IteratorApi
 import org.apache.gluten.execution._
 import org.apache.gluten.expression.ConverterUtils
@@ -61,6 +61,52 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
     StructType(dataSchema)
   }
 
+  private def createNativeIterator(
+      splitInfoByteArray: Array[Array[Byte]],
+      wsPlan: Array[Byte],
+      materializeInput: Boolean,
+      inputIterators: Seq[Iterator[ColumnarBatch]]): BatchIterator = {
+
+    /** Generate closeable ColumnBatch iterator. */
+    val listIterator =
+      inputIterators
+        .map {
+          case i: CloseableCHColumnBatchIterator => i
+          case it => new CloseableCHColumnBatchIterator(it)
+        }
+        .map(it => new ColumnarNativeIterator(it.asJava).asInstanceOf[GeneralInIterator])
+        .asJava
+    new CHNativeExpressionEvaluator().createKernelWithBatchIterator(
+      wsPlan,
+      splitInfoByteArray,
+      listIterator,
+      materializeInput
+    )
+  }
+
+  private def createCloseIterator(
+      context: TaskContext,
+      pipelineTime: SQLMetric,
+      updateNativeMetrics: IMetrics => Unit,
+      updateInputMetrics: Option[InputMetricsWrapper => Unit] = None,
+      nativeIter: BatchIterator): CloseableCHColumnBatchIterator = {
+
+    val iter = new CollectMetricIterator(
+      nativeIter,
+      updateNativeMetrics,
+      updateInputMetrics,
+      updateInputMetrics.map(_ => context.taskMetrics().inputMetrics).orNull)
+
+    context.addTaskFailureListener(
+      (ctx, _) => {
+        if (ctx.isInterrupted()) {
+          iter.cancel()
+        }
+      })
+    context.addTaskCompletionListener[Unit](_ => iter.close())
+    new CloseableCHColumnBatchIterator(iter, Some(pipelineTime))
+  }
+
   // only set file schema for text format table
   private def setFileSchemaForLocalFiles(
       localFilesNode: LocalFilesNode,
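
Aside: createNativeIterator preserves the "wrap exactly once" behavior of the old genCloseableColumnBatchIterator helper (removed further down) by matching on the concrete iterator type before wrapping. A standalone Scala sketch of that pattern, separate from the diff; CloseableIter and ensureCloseable are illustrative names, not Gluten APIs:

import scala.collection.JavaConverters._

// Stand-in for CloseableCHColumnBatchIterator: an iterator that owns a resource.
class CloseableIter[A](underlying: Iterator[A]) extends Iterator[A] {
  def hasNext: Boolean = underlying.hasNext
  def next(): A = underlying.next()
  def close(): Unit = () // the real class releases native memory here
}

// Reuse an already-closeable iterator instead of wrapping it a second time.
def ensureCloseable[A](it: Iterator[A]): CloseableIter[A] = it match {
  case c: CloseableIter[A] => c
  case plain => new CloseableIter(plain)
}

val wrapped = Seq(Iterator(1, 2), Iterator(3)).map(it => ensureCloseable(it)).asJava
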
@@ -198,45 +244,24 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
       inputIterators: Seq[Iterator[ColumnarBatch]] = Seq()
   ): Iterator[ColumnarBatch] = {
 
-    assert(
+    require(
       inputPartition.isInstanceOf[GlutenPartition],
       "CH backend only accepts GlutenPartition in GlutenWholeStageColumnarRDD.")
 
-    val transKernel = new CHNativeExpressionEvaluator()
-    val inBatchIters = new JArrayList[GeneralInIterator](inputIterators.map {
-      iter => new ColumnarNativeIterator(CHIteratorApi.genCloseableColumnBatchIterator(iter).asJava)
-    }.asJava)
-
     val splitInfoByteArray = inputPartition
       .asInstanceOf[GlutenPartition]
       .splitInfosByteArray
-    val nativeIter =
-      transKernel.createKernelWithBatchIterator(
-        inputPartition.plan,
-        splitInfoByteArray,
-        inBatchIters,
-        false)
-
-    val iter = new CollectMetricIterator(
-      nativeIter,
-      updateNativeMetrics,
-      updateInputMetrics,
-      context.taskMetrics().inputMetrics)
-
-    context.addTaskFailureListener(
-      (ctx, _) => {
-        if (ctx.isInterrupted()) {
-          iter.cancel()
-        }
-      })
-    context.addTaskCompletionListener[Unit](_ => iter.close())
+    val wsPlan = inputPartition.plan
+    val materializeInput = false
 
     // TODO: SPARK-25083 remove the type erasure hack in data source scan
     new InterruptibleIterator(
       context,
-      new CloseableCHColumnBatchIterator(
-        iter.asInstanceOf[Iterator[ColumnarBatch]],
-        Some(pipelineTime)))
+      createCloseIterator(
+        context,
+        pipelineTime,
+        updateNativeMetrics,
+        Some(updateInputMetrics),
+        createNativeIterator(splitInfoByteArray, wsPlan, materializeInput, inputIterators))
+    )
   }
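
Aside: createCloseIterator now owns the task-lifecycle wiring that this call site and the final-stage one below previously duplicated. A minimal standalone sketch of that wiring against Spark's real TaskContext listener API; NativeHandle and bindToTask are illustrative stand-ins, not Gluten types:

import org.apache.spark.TaskContext

class NativeHandle { // stand-in for the native BatchIterator resource
  def cancel(): Unit = println("cancelled")
  def close(): Unit = println("closed")
}

def bindToTask(context: TaskContext, handle: NativeHandle): Unit = {
  // Stop the native pipeline early if the task is killed or interrupted.
  context.addTaskFailureListener((ctx, _) => if (ctx.isInterrupted()) handle.cancel())
  // Always release native resources when the task finishes.
  context.addTaskCompletionListener[Unit](_ => handle.close())
}
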

// Generate Iterator[ColumnarBatch] for final stage.
@@ -252,52 +277,26 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
       partitionIndex: Int,
       materializeInput: Boolean): Iterator[ColumnarBatch] = {
     // scalastyle:on argcount
-    GlutenConfig.getConf
-
-    val transKernel = new CHNativeExpressionEvaluator()
-    val columnarNativeIterator =
-      new JArrayList[GeneralInIterator](inputIterators.map {
-        iter =>
-          new ColumnarNativeIterator(CHIteratorApi.genCloseableColumnBatchIterator(iter).asJava)
-      }.asJava)
-    // we need to complete dependency RDD's firstly
-    val nativeIterator = transKernel.createKernelWithBatchIterator(
-      rootNode.toProtobuf.toByteArray,
-      // Final iterator does not contain scan split, so pass empty split info to native here.
-      new Array[Array[Byte]](0),
-      columnarNativeIterator,
-      materializeInput
-    )
 
-    val iter = new CollectMetricIterator(nativeIterator, updateNativeMetrics, null, null)
-
-    context.addTaskFailureListener(
-      (ctx, _) => {
-        if (ctx.isInterrupted()) {
-          iter.cancel()
-        }
-      })
-    context.addTaskCompletionListener[Unit](_ => iter.close())
-    new CloseableCHColumnBatchIterator(iter, Some(pipelineTime))
+    // Final iterator does not contain scan split, so pass empty split info to native here.
+    val splitInfoByteArray = new Array[Array[Byte]](0)
+    val wsPlan = rootNode.toProtobuf.toByteArray
+
+    // we need to complete dependency RDD's firstly
+    createCloseIterator(
+      context,
+      pipelineTime,
+      updateNativeMetrics,
+      None,
+      createNativeIterator(splitInfoByteArray, wsPlan, materializeInput, inputIterators))
   }
 }
-
-object CHIteratorApi {
-
-  /** Generate closeable ColumnBatch iterator. */
-  def genCloseableColumnBatchIterator(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
-    iter match {
-      case _: CloseableCHColumnBatchIterator => iter
-      case _ => new CloseableCHColumnBatchIterator(iter)
-    }
-  }
-}

 class CollectMetricIterator(
     val nativeIterator: BatchIterator,
     val updateNativeMetrics: IMetrics => Unit,
-    val updateInputMetrics: InputMetricsWrapper => Unit,
-    val inputMetrics: InputMetrics
+    val updateInputMetrics: Option[InputMetricsWrapper => Unit] = None,
+    val inputMetrics: InputMetrics = null
 ) extends Iterator[ColumnarBatch] {
   private var outputRowCount = 0L
   private var outputVectorCount = 0L
@@ -329,9 +328,7 @@ class CollectMetricIterator(
       val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics]
       nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
       updateNativeMetrics(nativeMetrics)
-      if (updateInputMetrics != null) {
-        updateInputMetrics(inputMetrics)
-      }
+      updateInputMetrics.foreach(_(inputMetrics))
       metricsUpdated = true
     }
   }
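
Aside: the change above is the standard Option-for-null refactor. A tiny standalone illustration, separate from the diff, with illustrative names:

// Before: a nullable callback guarded by a null check.
def reportOld(update: String => Unit): Unit =
  if (update != null) update("input metrics")

// After: absence is encoded in the type, and foreach fires only when
// a callback was actually supplied.
def reportNew(update: Option[String => Unit]): Unit =
  update.foreach(f => f("input metrics"))

reportNew(None)                            // no-op
reportNew(Some((s: String) => println(s))) // prints "input metrics"
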
@@ -50,7 +50,6 @@ import org.apache.spark.sql.delta.files.TahoeFileIndex
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
 import org.apache.spark.sql.execution.datasources.{FileFormat, HadoopFsRelation}
-import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules.NativeWritePostRule
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
@@ -583,14 +582,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
   override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] =
     List()
 
-  /**
-   * Generate extended columnar post-rules.
-   *
-   * @return
-   */
-  override def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] =
-    List(spark => NativeWritePostRule(spark))
-
   override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = {
     List()
   }
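
Aside: with this override deleted, the CH backend presumably falls back to the base implementation, which per the commit message now obtains getExtendedColumnarPostRules from the Spark shim so the post-rule set can vary by Spark version. For reference, a sketch of the extension point's shape; NoopRule is a hypothetical rule for illustration, not a Gluten class:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Each registered entry builds a physical-plan rewrite rule from a session.
case class NoopRule(session: SparkSession) extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan // leave the plan unchanged
}

def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] =
  List(session => NoopRule(session))
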
@@ -16,7 +16,6 @@
  */
 package org.apache.gluten.execution
 
-import org.apache.gluten.backendsapi.clickhouse.CHIteratorApi
 import org.apache.gluten.extension.ValidationResult
 import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy}
 
@@ -75,7 +74,7 @@ case class CHBroadcastBuildSideRDD(
 
   override def genBroadcastBuildSideIterator(): Iterator[ColumnarBatch] = {
     CHBroadcastBuildSideCache.getOrBuildBroadcastHashTable(broadcasted, broadcastContext)
-    CHIteratorApi.genCloseableColumnBatchIterator(Iterator.empty)
+    Iterator.empty
   }
 }
(The remaining 19 of the 22 changed files are not shown in this view.)
