[CORE] Add InputIteratorTransformer to decouple ReadRel and iterator index (apache#3854)
ulysses-you authored Nov 30, 2023
1 parent 8f4712a commit d2980b7
Showing 40 changed files with 391 additions and 619 deletions.
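
Before the per-file hunks, a minimal sketch of what "decouple ReadRel and iterator index" means in practice. It uses only names that appear in the hunks below; the value names and the surrounding snippet are illustrative assumptions, not code from this commit.

// Hypothetical call site. Before this commit, an operator whose child was a raw
// iterator had to allocate an iterator index, register a LocalFiles node on the
// SubstraitContext, and then call RelBuilder.makeReadRel(..., iteratorIndex, ...)
// itself (see the removed lines in CHHashAggregateExecTransformer and PlanNodesUtil).
// After this commit, that bookkeeping sits behind InputIteratorTransformer and a
// dedicated builder, so the call site shrinks to:
import io.glutenproject.expression.ConverterUtils
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.rel.RelBuilder

val context = new SubstraitContext
val operatorId = context.nextOperatorId("IllustrativeRead")
// `output` is an assumed Seq[Attribute] describing the rows coming from the upstream iterator.
val typeList = ConverterUtils.collectAttributeTypeNodes(output)
val nameList = ConverterUtils.collectAttributeNamesWithExprId(output)
val readRel = RelBuilder.makeReadRelForInputIterator(typeList, nameList, context, operatorId)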
@@ -38,6 +38,21 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
MetricsUtil.updateNativeMetrics(child, relMap, joinParamsMap, aggParamsMap)
}

override def genInputIteratorTransformerMetrics(
sparkContext: SparkContext): Map[String, SQLMetric] = {
Map(
"iterReadTime" -> SQLMetrics.createTimingMetric(
sparkContext,
"time of reading from iterator"),
"outputVectors" -> SQLMetrics.createMetric(sparkContext, "number of output vectors")
)
}

override def genInputIteratorTransformerMetricsUpdater(
metrics: Map[String, SQLMetric]): MetricsUpdater = {
InputIteratorMetricsUpdater(metrics)
}

override def genBatchScanTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] =
Map(
"inputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"),
@@ -163,8 +178,6 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
SQLMetrics.createTimingMetric(sparkContext, "time of aggregating"),
"postProjectTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of postProjection"),
"iterReadTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of reading from iterator"),
"totalTime" -> SQLMetrics.createTimingMetric(sparkContext, "total time")
)

@@ -312,12 +325,8 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
"extraTime" -> SQLMetrics.createTimingMetric(sparkContext, "extra operators time"),
"inputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of waiting for data"),
"outputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of waiting for output"),
"streamIterReadTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of stream side read"),
"streamPreProjectionTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of stream side preProjection"),
"buildIterReadTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of build side read"),
"buildPreProjectionTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of build side preProjection"),
"postProjectTime" ->
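
The CHMetricsApi hunks above add a metrics factory and updater for the new operator. A hedged usage sketch: the method names and metric keys come from the hunks, while the call site itself is an assumption.

import io.glutenproject.metrics.MetricsUpdater
import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.metric.SQLMetric

// Hypothetical wiring for one InputIteratorTransformer node.
def iteratorMetricsFor(sparkContext: SparkContext): (Map[String, SQLMetric], MetricsUpdater) = {
  val metricsApi = new CHMetricsApi
  // Creates the "iterReadTime" and "outputVectors" SQL metrics, as defined above.
  val iterMetrics = metricsApi.genInputIteratorTransformerMetrics(sparkContext)
  // Pairs them with the updater that will later receive the native operator metrics.
  (iterMetrics, metricsApi.genInputIteratorTransformerMetricsUpdater(iterMetrics))
}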
@@ -284,7 +284,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
}

def wrapChild(child: SparkPlan): WholeStageTransformer = {
WholeStageTransformer(ProjectExecTransformer(child.output ++ appendedProjections, child))(
val childWithAdapter = ColumnarCollapseTransformStages.wrapInputIteratorTransformer(child)
WholeStageTransformer(
ProjectExecTransformer(child.output ++ appendedProjections, childWithAdapter))(
ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet()
)
}
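
The wrapChild hunk above is the call-site half of the change. A hedged sketch of the pattern: the helper name and constructor shape are taken from the hunk, while the behavior of wrapInputIteratorTransformer (adapting a non-transformer child into an InputIteratorTransformer leaf) is inferred from its name and usage, not confirmed by this diff.

import org.apache.spark.sql.execution.SparkPlan

// Hypothetical stand-alone version of the pattern used in wrapChild.
def buildStage(plan: SparkPlan): WholeStageTransformer = {
  // Presumably wraps a non-transformable child so the stage gets an explicit
  // InputIteratorTransformer leaf instead of each operator creating its own ReadRel.
  val adapted = ColumnarCollapseTransformStages.wrapInputIteratorTransformer(plan)
  WholeStageTransformer(adapted)(
    ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet())
}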
@@ -18,15 +18,10 @@ package io.glutenproject.execution

import io.glutenproject.extension.ValidationResult
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.rel.RelBuilder

import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.{And, Expression}
import org.apache.spark.sql.execution.SparkPlan

import java.util

import scala.collection.JavaConverters._

case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
extends FilterExecTransformerBase(condition, child) {

@@ -46,13 +41,8 @@ case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
val leftCondition = getLeftCondition
val childCtx = child match {
case c: TransformSupport =>
c.doTransform(context)
case _ =>
throw new IllegalStateException(s"child ${child.nodeName} doesn't support transform.");
}

val operatorId = context.nextOperatorId(this.nodeName)
if (leftCondition == null) {
@@ -63,34 +53,15 @@ case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
TransformContext(childCtx.inputAttributes, output, childCtx.root)
}

val currRel = if (childCtx != null) {
getRelNode(
context,
leftCondition,
child.output,
operatorId,
childCtx.root,
validation = false)
} else {
// This means the input is just an iterator, so an ReadRel will be created as child.
// Prepare the input schema.
val attrList = new util.ArrayList[Attribute](child.output.asJava)
getRelNode(
context,
leftCondition,
child.output,
operatorId,
RelBuilder.makeReadRel(attrList, context, operatorId),
validation = false)
}
val currRel = getRelNode(
context,
leftCondition,
child.output,
operatorId,
childCtx.root,
validation = false)
assert(currRel != null, "Filter rel should be valid.")
val inputAttributes = if (childCtx != null) {
// Use the outputAttributes of child context as inputAttributes.
childCtx.outputAttributes
} else {
child.output
}
TransformContext(inputAttributes, output, currRel)
TransformContext(childCtx.outputAttributes, output, currRel)
}

private def getLeftCondition: Expression = {
@@ -21,7 +21,7 @@ import io.glutenproject.expression._
import io.glutenproject.substrait.`type`.TypeNode
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode}
import io.glutenproject.substrait.rel.{LocalFilesBuilder, RelBuilder, RelNode}
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -80,17 +80,12 @@ case class CHHashAggregateExecTransformer(
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child match {
case c: TransformSupport =>
c.doTransform(context)
case _ =>
null
}

val aggParams = new AggregationParams
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
val operatorId = context.nextOperatorId(this.nodeName)

val (relNode, inputAttributes, outputAttributes) = if (childCtx != null) {
val aggParams = new AggregationParams
val isChildTransformSupported = !child.isInstanceOf[InputIteratorTransformer]
val (relNode, inputAttributes, outputAttributes) = if (isChildTransformSupported) {
// The final HashAggregateExecTransformer and partial HashAggregateExecTransformer
// are in the one WholeStageTransformer.
if (modes.isEmpty || !modes.contains(Partial)) {
@@ -110,7 +105,6 @@ case class CHHashAggregateExecTransformer(
// Notes: Currently, ClickHouse backend uses the output attributes of
// aggregateResultAttributes as Shuffle output,
// which is different from Velox backend.
aggParams.isReadRel = true
val typeList = new util.ArrayList[TypeNode]()
val nameList = new util.ArrayList[String]()
val (inputAttrs, outputAttrs) = {
@@ -152,14 +146,10 @@ case class CHHashAggregateExecTransformer(
}
}

// The iterator index will be added in the path of LocalFiles.
val iteratorIndex: Long = context.nextIteratorIndex
val inputIter = LocalFilesBuilder.makeLocalFiles(
ConverterUtils.ITERATOR_PREFIX.concat(iteratorIndex.toString))
context.setIteratorNode(iteratorIndex, inputIter)
// The output is different with child.output, so we can not use `childCtx.root` as the
// `ReadRel`. Here we re-generate the `ReadRel` with the special output list.
val readRel =
RelBuilder.makeReadRel(typeList, nameList, null, iteratorIndex, context, operatorId)

RelBuilder.makeReadRelForInputIteratorWithoutRegister(typeList, nameList, context)
(getAggRel(context, operatorId, aggParams, readRel), inputAttrs, outputAttrs)
}
TransformContext(inputAttributes, outputAttributes, relNode)
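
When the aggregate's output schema differs from child.output, the hunk above regenerates the ReadRel by hand. A condensed sketch of that path; the list contents and the in-scope `context` are placeholders, and the no-registration behavior is inferred from the method name rather than shown in this diff.

import java.util

import io.glutenproject.substrait.`type`.TypeNode
import io.glutenproject.substrait.rel.RelBuilder

// Placeholders: the real code fills these from the aggregate's input/output attributes
// rather than child.output, which is exactly why childCtx.root cannot be reused here.
val typeList = new util.ArrayList[TypeNode]()
val nameList = new util.ArrayList[String]()
// Build a ReadRel over the upstream iterator with this special schema; assumed not to
// register an extra iterator node, per the "WithoutRegister" suffix.
val readRel = RelBuilder.makeReadRelForInputIteratorWithoutRegister(typeList, nameList, context)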
@@ -34,15 +34,6 @@ class HashAggregateMetricsUpdater(val metrics: Map[String, SQLMetric])
var currentIdx = operatorMetrics.metricsList.size() - 1
var totalTime = 0L

// read rel
if (aggregationParams.isReadRel) {
metrics("iterReadTime") +=
(operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong
metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors
totalTime += operatorMetrics.metricsList.get(currentIdx).time
currentIdx -= 1
}

// pre projection
if (aggregationParams.preProjectionNeeded) {
metrics("preProjectTime") +=
@@ -40,15 +40,6 @@ class HashJoinMetricsUpdater(val metrics: Map[String, SQLMetric])
currentIdx -= 1
}

// build side read rel
if (joinParams.isBuildReadRel) {
val buildSideRealRel = operatorMetrics.metricsList.get(currentIdx)
metrics("buildIterReadTime") += (buildSideRealRel.time / 1000L).toLong
metrics("outputVectors") += buildSideRealRel.outputVectors
totalTime += buildSideRealRel.time
currentIdx -= 1
}

// stream side pre projection
if (joinParams.streamPreProjectionNeeded) {
metrics("streamPreProjectionTime") +=
Expand All @@ -58,25 +49,15 @@ class HashJoinMetricsUpdater(val metrics: Map[String, SQLMetric])
currentIdx -= 1
}

// stream side read rel
if (joinParams.isStreamedReadRel) {
metrics("streamIterReadTime") +=
(operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong
metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors
totalTime += operatorMetrics.metricsList.get(currentIdx).time

// update fillingRightJoinSideTime
MetricsUtil
.getAllProcessorList(operatorMetrics.metricsList.get(currentIdx))
.foreach(
processor => {
if (processor.name.equalsIgnoreCase("FillingRightJoinSide")) {
metrics("fillingRightJoinSideTime") += (processor.time / 1000L).toLong
}
})

currentIdx -= 1
}
// update fillingRightJoinSideTime
MetricsUtil
.getAllProcessorList(operatorMetrics.metricsList.get(currentIdx))
.foreach(
processor => {
if (processor.name.equalsIgnoreCase("FillingRightJoinSide")) {
metrics("fillingRightJoinSideTime") += (processor.time / 1000L).toLong
}
})

// joining
val joinMetricsData = operatorMetrics.metricsList.get(currentIdx)
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.metrics

import org.apache.spark.sql.execution.metric.SQLMetric

case class InputIteratorMetricsUpdater(metrics: Map[String, SQLMetric]) extends MetricsUpdater {
override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = {
if (opMetrics != null) {
val operatorMetrics = opMetrics.asInstanceOf[OperatorMetrics]
if (!operatorMetrics.metricsList.isEmpty) {
val metricsData = operatorMetrics.metricsList.get(0)
metrics("iterReadTime") += (metricsData.time / 1000L).toLong
metrics("outputVectors") += metricsData.outputVectors
}
}
}
}
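
A consumption sketch for the new updater; only the constructor and updateNativeMetrics mirror the class above, while obtaining IOperatorMetrics from the native side is assumed to happen elsewhere.

import io.glutenproject.metrics.{IOperatorMetrics, InputIteratorMetricsUpdater}
import org.apache.spark.sql.execution.metric.SQLMetric

// Hypothetical call site, e.g. after a native task reports metrics for an
// InputIteratorTransformer node.
def reportIteratorMetrics(iterMetrics: Map[String, SQLMetric], opMetrics: IOperatorMetrics): Unit = {
  val updater = InputIteratorMetricsUpdater(iterMetrics)
  // Adds the first native entry's scaled read time and output vector count
  // to "iterReadTime" and "outputVectors".
  updater.updateNativeMetrics(opMetrics)
}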
@@ -20,7 +20,7 @@ import io.glutenproject.expression.{ConverterUtils, ExpressionConverter}
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.expression.ExpressionNode
import io.glutenproject.substrait.plan.{PlanBuilder, PlanNode}
import io.glutenproject.substrait.rel.{LocalFilesBuilder, RelBuilder}
import io.glutenproject.substrait.rel.RelBuilder

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression}

@@ -31,17 +31,10 @@ object PlanNodesUtil {
def genProjectionsPlanNode(key: Expression, output: Seq[Attribute]): PlanNode = {
val context = new SubstraitContext

// input
val iteratorIndex: Long = context.nextIteratorIndex
var operatorId = context.nextOperatorId("ClickHouseBuildSideRelationReadIter")
val inputIter = LocalFilesBuilder.makeLocalFiles(
ConverterUtils.ITERATOR_PREFIX.concat(iteratorIndex.toString))
context.setIteratorNode(iteratorIndex, inputIter)

val typeList = ConverterUtils.collectAttributeTypeNodes(output)
val nameList = ConverterUtils.collectAttributeNamesWithExprId(output)
val readRel =
RelBuilder.makeReadRel(typeList, nameList, null, iteratorIndex, context, operatorId)
val readRel = RelBuilder.makeReadRelForInputIterator(typeList, nameList, context, operatorId)

// replace attribute to BoundRefernce according to the output
val newBoundRefKey = key.transformDown {
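
A usage sketch for genProjectionsPlanNode after this change; the single integer key column is a made-up example, and only the method itself comes from the hunk above.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

// Hypothetical: build the broadcast build-side projection plan for one int key column.
// Internally this now goes through RelBuilder.makeReadRelForInputIterator instead of
// allocating an iterator index and a LocalFiles node at this call site.
val output = Seq(AttributeReference("key", IntegerType)())
val planNode = PlanNodesUtil.genProjectionsPlanNode(output.head, output)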