[VL][Core] SampleExec Operator Native Support (apache#5856)
[VL] SampleExec Operator Native Support.
gaoyangxiaozhu authored May 28, 2024
1 parent 291f084 commit 729d345
Showing 15 changed files with 262 additions and 0 deletions.
@@ -361,6 +361,17 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
s"NestedLoopJoinTransformer metrics update is not supported in CH backend")
}

override def genSampleTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = {
throw new UnsupportedOperationException(
s"SampleTransformer metrics update is not supported in CH backend")
}

override def genSampleTransformerMetricsUpdater(
metrics: Map[String, SQLMetric]): MetricsUpdater = {
throw new UnsupportedOperationException(
s"SampleTransformer metrics update is not supported in CH backend")
}

def genWriteFilesTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = {
throw new UnsupportedOperationException(
s"WriteFilesTransformer metrics update is not supported in CH backend")
@@ -379,6 +379,14 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
throw new GlutenNotSupportException(
"BroadcastNestedLoopJoinExecTransformer is not supported in ch backend.")

override def genSampleExecTransformer(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
child: SparkPlan): SampleExecTransformer =
throw new GlutenNotSupportException("SampleExecTransformer is not supported in ch backend.")

/** Generate an expression transformer to transform GetMapValue to Substrait. */
def genGetMapValueTransformer(
substraitExprName: String,
@@ -499,6 +499,8 @@ object VeloxBackendSettings extends BackendSettingsApi {

override def supportBroadcastNestedLoopJoinExec(): Boolean = true

override def supportSampleExec(): Boolean = true

override def supportColumnarArrowUdf(): Boolean = true

override def generateHdfsConfForLibhdfs(): Boolean = true
@@ -540,4 +540,20 @@ class VeloxMetricsApi extends MetricsApi with Logging {

override def genNestedLoopJoinTransformerMetricsUpdater(
metrics: Map[String, SQLMetric]): MetricsUpdater = new NestedLoopJoinMetricsUpdater(metrics)

override def genSampleTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] =
Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"outputVectors" -> SQLMetrics.createMetric(sparkContext, "number of output vectors"),
"outputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of output bytes"),
"wallNanos" -> SQLMetrics.createNanoTimingMetric(sparkContext, "totaltime of sample"),
"cpuCount" -> SQLMetrics.createMetric(sparkContext, "cpu wall time count"),
"peakMemoryBytes" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory bytes"),
"numMemoryAllocations" -> SQLMetrics.createMetric(
sparkContext,
"number of memory allocations")
)

override def genSampleTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater =
new SampleMetricsUpdater(metrics)
}
@@ -411,6 +411,15 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
right,
isNullAwareAntiJoin)

override def genSampleExecTransformer(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
child: SparkPlan): SampleExecTransformer = {
SampleExecTransformer(lowerBound, upperBound, withReplacement, seed, child)
}

override def genSortMergeJoinExecTransformer(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
@@ -1050,6 +1050,18 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
}
}

test("Test sample op") {
withSQLConf("spark.gluten.sql.columnarSampleEnabled" -> "true") {
withTable("t") {
sql("create table t (id int, b boolean) using parquet")
sql("insert into t values (1, true), (2, false), (3, null), (4, true), (5, false)")
runQueryAndCompare("select * from t TABLESAMPLE(20 PERCENT)", false) {
checkGlutenOperatorMatch[SampleExecTransformer]
}
}
}
}

test("test cross join") {
withTable("t1", "t2") {
sql("""
@@ -144,6 +144,8 @@ trait BackendSettingsApi {

def supportBroadcastNestedLoopJoinExec(): Boolean = false

def supportSampleExec(): Boolean = false

/** Merge two-phase hash-based aggregates if needed */
def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false

@@ -113,6 +113,10 @@ trait MetricsApi extends Serializable {

def genNestedLoopJoinTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater

def genSampleTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric]

def genSampleTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater

def genColumnarInMemoryTableMetrics(sparkContext: SparkContext): Map[String, SQLMetric] =
Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
}
@@ -125,6 +125,13 @@ trait SparkPlanExecApi {
right: SparkPlan,
isNullAwareAntiJoin: Boolean = false): BroadcastHashJoinExecTransformerBase

def genSampleExecTransformer(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
child: SparkPlan): SampleExecTransformer

/** Generate ShuffledHashJoinExecTransformer. */
def genSortMergeJoinExecTransformer(
leftKeys: Seq[Expression],
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, LessThan, Literal, Rand}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.DoubleType

import scala.collection.JavaConverters._

/**
 * SampleExec supports two sampling methods: with replacement and without replacement. This
 * transformer currently supports only sampling without replacement. For sampling without
 * replacement, SampleExec uses `seed + partitionId` as the seed for each partition. The
 * `upperBound - lowerBound` value is used as the fraction, and the XORShiftRandom number
 * generator is employed. Each row undergoes a Bernoulli trial: if the generated random number
 * falls within the range [lowerBound, upperBound), the row is included; otherwise, it is skipped.
 *
 * This transformer converts SampleExec to a Substrait Filter relation, achieving a similar
 * sampling effect through the filter op with a rand sampling expression. Specifically, the
 * `upperBound - lowerBound` value is used as the fraction, and the node is translated to
 * `filter(rand(seed + partitionId) < fraction)` for random sampling.
 */
case class SampleExecTransformer(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
child: SparkPlan)
extends UnaryTransformSupport
with Logging {
def fraction: Double = upperBound - lowerBound

def condition: Expression = {
val randExpr: Expression = Rand(seed)
val sampleRateExpr: Expression = Literal(fraction, DoubleType)
LessThan(randExpr, sampleRateExpr)
}

override def output: Seq[Attribute] = child.output

// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@transient override lazy val metrics =
BackendsApiManager.getMetricsApiInstance.genSampleTransformerMetrics(sparkContext)

override def metricsUpdater(): MetricsUpdater =
BackendsApiManager.getMetricsApiInstance.genSampleTransformerMetricsUpdater(metrics)

def getRelNode(
context: SubstraitContext,
condExpr: Expression,
originalInputAttributes: Seq[Attribute],
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
assert(condExpr != null)
val args = context.registeredFunction
val condExprNode = ExpressionConverter
.replaceWithExpressionTransformer(condExpr, attributeSeq = originalInputAttributes)
.doTransform(args)

if (!validation) {
RelBuilder.makeFilterRel(input, condExprNode, context, operatorId)
} else {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeFilterRel(input, condExprNode, extensionNode, context, operatorId)
}
}

  override protected def doValidateInternal(): ValidationResult = {
    if (withReplacement) {
      return ValidationResult.notOk(
        s"Unsupported sample exec in native: withReplacement parameter is $withReplacement")
    }
    val substraitContext = new SubstraitContext
    val operatorId = substraitContext.nextOperatorId(this.nodeName)
    // First, check whether the Substrait plan for this operator can be successfully generated.
    val relNode =
      getRelNode(substraitContext, condition, child.output, operatorId, null, validation = true)
    // Then, validate the generated plan in the native engine.
    doNativeValidation(substraitContext, relNode)
  }

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
val operatorId = context.nextOperatorId(this.nodeName)
val currRel =
getRelNode(context, condition, child.output, operatorId, childCtx.root, validation = false)
assert(currRel != null, "Filter rel should be valid.")
TransformContext(childCtx.outputAttributes, output, currRel)
}

override protected def withNewChildInternal(newChild: SparkPlan): SampleExecTransformer =
copy(child = newChild)
}
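
As an aside on the rewrite described in the doc comment above: a Bernoulli sample without replacement is semantically close to filtering on a uniform random draw, which can be illustrated with plain Spark APIs. A minimal sketch, not part of this commit; the local SparkSession, the toy DataFrame, and all names in it are assumptions for illustration:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.rand

object SampleAsFilterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sample-sketch").getOrCreate()
    import spark.implicits._

    val df = (1 to 1000).toDF("id")
    val seed = 42L
    val fraction = 0.2 // corresponds to upperBound - lowerBound

    // What SampleExec computes: Bernoulli sampling without replacement.
    val sampled = df.sample(withReplacement = false, fraction = fraction, seed = seed)

    // What the transformer emits: filter(rand(seed) < fraction). Spark's Rand
    // expression already offsets the seed by the partition index at runtime.
    val filtered = df.filter(rand(seed) < fraction)

    // Both retain roughly `fraction` of the rows; exact row membership may
    // differ because the two paths draw their random numbers differently.
    println(s"sample: ${sampled.count()}, filter: ${filtered.count()}")
    spark.stop()
  }
}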
@@ -432,6 +432,15 @@ object OffloadOthers {
child,
plan.evalType)
}
case plan: SampleExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
val child = plan.child
BackendsApiManager.getSparkPlanExecApiInstance.genSampleExecTransformer(
plan.lowerBound,
plan.upperBound,
plan.withReplacement,
plan.seed,
child)
case p if !p.isInstanceOf[GlutenPlan] =>
logDebug(s"Transformation for ${p.getClass} is currently not supported.")
val children = plan.children
@@ -500,6 +500,15 @@ case class AddTransformHintRule() extends Rule[SparkPlan] {
plan.child,
offset)
transformer.doValidate().tagOnFallback(plan)
case plan: SampleExec =>
val transformer = BackendsApiManager.getSparkPlanExecApiInstance.genSampleExecTransformer(
plan.lowerBound,
plan.upperBound,
plan.withReplacement,
plan.seed,
plan.child
)
transformer.doValidate().tagOnFallback(plan)
case _ =>
// Currently we assume a plan to be transformable by default.
}
@@ -192,6 +192,9 @@ object Validators {
case p
if HiveTableScanExecTransformer.isHiveTableScan(p) && !conf.enableColumnarHiveTableScan =>
fail(p)
case p: SampleExec
if !(conf.enableColumnarSample && BackendsApiManager.getSettings.supportSampleExec()) =>
fail(p)
case _ => pass()
}
}
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.metrics

import org.apache.spark.sql.execution.metric.SQLMetric

class SampleMetricsUpdater(val metrics: Map[String, SQLMetric]) extends MetricsUpdater {

override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = {
if (opMetrics != null) {
val operatorMetrics = opMetrics.asInstanceOf[OperatorMetrics]
metrics("numOutputRows") += operatorMetrics.outputRows
metrics("outputVectors") += operatorMetrics.outputVectors
metrics("outputBytes") += operatorMetrics.outputBytes
metrics("cpuCount") += operatorMetrics.cpuCount
metrics("wallNanos") += operatorMetrics.wallNanos
metrics("peakMemoryBytes") += operatorMetrics.peakMemoryBytes
metrics("numMemoryAllocations") += operatorMetrics.numMemoryAllocations
}
}
}
@@ -85,6 +85,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {

def enableColumnarBroadcastJoin: Boolean = conf.getConf(COLUMNAR_BROADCAST_JOIN_ENABLED)

def enableColumnarSample: Boolean = conf.getConf(COLUMNAR_SAMPLE_ENABLED)

def enableColumnarArrowUDF: Boolean = conf.getConf(COLUMNAR_ARROW_UDF_ENABLED)

def enableColumnarCoalesce: Boolean = conf.getConf(COLUMNAR_COALESCE_ENABLED)
@@ -1772,6 +1774,13 @@ object GlutenConfig {
.booleanConf
.createWithDefault(true)

val COLUMNAR_SAMPLE_ENABLED =
buildConf("spark.gluten.sql.columnarSampleEnabled")
.internal()
.doc("Disable or enable columnar sample.")
.booleanConf
.createWithDefault(false)

val CACHE_WHOLE_STAGE_TRANSFORMER_CONTEXT =
buildConf("spark.gluten.sql.cacheWholeStageTransformerContext")
.internal()
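Since COLUMNAR_SAMPLE_ENABLED defaults to false, the operator must be opted into, as the TestOperator case above does via withSQLConf. A hypothetical session-level sketch, assuming a Gluten-enabled SparkSession named `spark` and the parquet table `t` from the test:

// Opt in for the current session; native sampling is off by default.
spark.conf.set("spark.gluten.sql.columnarSampleEnabled", "true")

// TABLESAMPLE can now be offloaded as a SampleExecTransformer, provided
// validation passes (sampling with replacement still falls back to vanilla Spark).
spark.sql("select * from t TABLESAMPLE(20 PERCENT)").show()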
