[CORE] Simplify code of offload scan (#8144)
zml1206 authored Dec 6, 2024
1 parent 731c5b5 commit f96105d
Showing 10 changed files with 47 additions and 143 deletions.
@@ -364,6 +364,8 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
 
   override def supportCartesianProductExec(): Boolean = true
 
+  override def supportCartesianProductExecWithCondition(): Boolean = false
+
   override def supportHashBuildJoinTypeOnLeft: JoinType => Boolean = {
     t =>
       if (super.supportHashBuildJoinTypeOnLeft(t)) {
@@ -363,15 +363,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
       left: SparkPlan,
       right: SparkPlan,
       condition: Option[Expression]): CartesianProductExecTransformer =
-    if (!condition.isEmpty) {
-      throw new GlutenNotSupportException(
-        "CartesianProductExecTransformer with condition is not supported in ch backend.")
-    } else {
-      CartesianProductExecTransformer(
-        ColumnarCartesianProductBridge(left),
-        ColumnarCartesianProductBridge(right),
-        condition)
-    }
+    CartesianProductExecTransformer(
+      ColumnarCartesianProductBridge(left),
+      ColumnarCartesianProductBridge(right),
+      condition)
 
   override def genBroadcastNestedLoopJoinExecTransformer(
       left: SparkPlan,
@@ -121,6 +121,8 @@ trait BackendSettingsApi {
 
   def supportCartesianProductExec(): Boolean = false
 
+  def supportCartesianProductExecWithCondition(): Boolean = true
+
   def supportBroadcastNestedLoopJoinExec(): Boolean = true
 
   def supportSampleExec(): Boolean = false
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.execution.window._
-import org.apache.spark.sql.hive.{HiveTableScanExecTransformer, HiveUDFTransformer}
+import org.apache.spark.sql.hive.HiveUDFTransformer
 import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -64,9 +64,6 @@ trait SparkPlanExecApi {
    */
   def genFilterExecTransformer(condition: Expression, child: SparkPlan): FilterExecTransformerBase
 
-  def genHiveTableScanExecTransformer(plan: SparkPlan): HiveTableScanExecTransformer =
-    HiveTableScanExecTransformer(plan)
-
   def genProjectExecTransformer(
       projectList: Seq[NamedExpression],
       child: SparkPlan): ProjectExecTransformer =
@@ -134,6 +134,10 @@ abstract class BatchScanExecTransformerBase(
   }
 
   override def doValidateInternal(): ValidationResult = {
+    if (!ScanTransformerFactory.supportedBatchScan(scan)) {
+      return ValidationResult.failed(s"Unsupported scan $scan")
+    }
+
     if (pushedAggregate.nonEmpty) {
       return ValidationResult.failed(s"Unsupported aggregation push down for $scan.")
     }
@@ -111,6 +111,14 @@ case class CartesianProductExecTransformer(
   }
 
   override protected def doValidateInternal(): ValidationResult = {
+    if (
+      !BackendsApiManager.getSettings.supportCartesianProductExecWithCondition() &&
+      condition.nonEmpty
+    ) {
+      return ValidationResult.failed(
+        "CartesianProductExecTransformer with condition is not supported in this backend.")
+    }
+
     if (!BackendsApiManager.getSettings.supportCartesianProductExec()) {
       return ValidationResult.failed("Cartesian product is not supported in this backend")
     }
@@ -16,12 +16,10 @@
  */
 package org.apache.gluten.execution
 
-import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.extension.columnar.FallbackTags
 import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.sql.connector.read.Scan
-import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
 
 import java.util.ServiceLoader
@@ -58,8 +56,7 @@ object ScanTransformerFactory {
     }
   }
 
-  private def lookupBatchScanTransformer(
-      batchScanExec: BatchScanExec): BatchScanExecTransformerBase = {
+  def createBatchScanTransformer(batchScanExec: BatchScanExec): BatchScanExecTransformerBase = {
     val scan = batchScanExec.scan
     lookupDataSourceScanTransformer(scan.getClass.getName) match {
       case Some(clz) =>
@@ -69,46 +66,16 @@
           .asInstanceOf[DataSourceScanTransformerRegister]
           .createDataSourceV2Transformer(batchScanExec)
       case _ =>
-        scan match {
-          case _: FileScan =>
-            BatchScanExecTransformer(
-              batchScanExec.output,
-              batchScanExec.scan,
-              batchScanExec.runtimeFilters,
-              table = SparkShimLoader.getSparkShims.getBatchScanExecTable(batchScanExec)
-            )
-          case _ =>
-            throw new GlutenNotSupportException(s"Unsupported scan $scan")
-        }
-    }
-  }
-
-  def createBatchScanTransformer(
-      batchScan: BatchScanExec,
-      validation: Boolean = false): SparkPlan = {
-    if (supportedBatchScan(batchScan.scan)) {
-      val transformer = lookupBatchScanTransformer(batchScan)
-      if (!validation) {
-        val validationResult = transformer.doValidate()
-        if (validationResult.ok()) {
-          transformer
-        } else {
-          FallbackTags.add(batchScan, validationResult.reason())
-          batchScan
-        }
-      } else {
-        transformer
-      }
-    } else {
-      if (validation) {
-        throw new GlutenNotSupportException(s"Unsupported scan ${batchScan.scan}")
-      }
-      FallbackTags.add(batchScan, "The scan in BatchScanExec is not supported.")
-      batchScan
+        BatchScanExecTransformer(
+          batchScanExec.output,
+          batchScanExec.scan,
+          batchScanExec.runtimeFilters,
+          table = SparkShimLoader.getSparkShims.getBatchScanExecTable(batchScanExec)
+        )
     }
   }
 
-  private def supportedBatchScan(scan: Scan): Boolean = scan match {
+  def supportedBatchScan(scan: Scan): Boolean = scan match {
     case _: FileScan => true
     case _ => lookupDataSourceScanTransformer(scan.getClass.getName).nonEmpty
   }
@@ -132,5 +99,4 @@
     )
     Option(clz)
   }
-
 }
@@ -18,7 +18,6 @@ package org.apache.gluten.extension.columnar.offload
 
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution._
 import org.apache.gluten.extension.columnar.FallbackTags
 import org.apache.gluten.logging.LogLevelUtil
@@ -178,7 +177,7 @@ object OffloadJoin {
 // Other transformations.
 case class OffloadOthers() extends OffloadSingleNode with LogLevelUtil {
   import OffloadOthers._
-  private val replace = new ReplaceSingleNode()
+  private val replace = new ReplaceSingleNode
 
   override def offload(plan: SparkPlan): SparkPlan = replace.doReplace(plan)
 }
@@ -190,7 +189,7 @@ object OffloadOthers {
   // Do not look up on children on the input node in this rule. Otherwise
   // it may break RAS which would group all the possible input nodes to
   // search for validate candidates.
-  private class ReplaceSingleNode() extends LogLevelUtil with Logging {
+  private class ReplaceSingleNode extends LogLevelUtil with Logging {
 
     def doReplace(p: SparkPlan): SparkPlan = {
       val plan = p
@@ -199,11 +198,15 @@
       }
       plan match {
         case plan: BatchScanExec =>
-          applyScanTransformer(plan)
+          logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+          ScanTransformerFactory.createBatchScanTransformer(plan)
         case plan: FileSourceScanExec =>
-          applyScanTransformer(plan)
+          logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+          ScanTransformerFactory.createFileSourceScanTransformer(plan)
         case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
-          applyScanTransformer(plan)
+          // TODO: Add DynamicPartitionPruningHiveScanSuite.scala
+          logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+          HiveTableScanExecTransformer(plan)
         case plan: CoalesceExec =>
           logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
           ColumnarCoalesceExec(plan.numPartitions, plan.child)
@@ -333,40 +336,5 @@ object OffloadOthers {
         case other => other
       }
     }
-
-    /**
-     * Apply scan transformer for file source and batch source,
-     * 1. create new filter and scan transformer, 2. validate, tag new scan as unsupported if
-     *    failed, 3. return new source.
-     */
-    private def applyScanTransformer(plan: SparkPlan): SparkPlan = plan match {
-      case plan: FileSourceScanExec =>
-        val transformer = ScanTransformerFactory.createFileSourceScanTransformer(plan)
-        val validationResult = transformer.doValidate()
-        if (validationResult.ok()) {
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-          transformer
-        } else {
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently unsupported.")
-          FallbackTags.add(plan, validationResult.reason())
-          plan
-        }
-      case plan: BatchScanExec =>
-        ScanTransformerFactory.createBatchScanTransformer(plan)
-      case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
-        // TODO: Add DynamicPartitionPruningHiveScanSuite.scala
-        val hiveTableScanExecTransformer =
-          BackendsApiManager.getSparkPlanExecApiInstance.genHiveTableScanExecTransformer(plan)
-        val validateResult = hiveTableScanExecTransformer.doValidate()
-        if (validateResult.ok()) {
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-          return hiveTableScanExecTransformer
-        }
-        logDebug(s"Columnar Processing for ${plan.getClass} is currently unsupported.")
-        FallbackTags.add(plan, validateResult.reason())
-        plan
-      case other =>
-        throw new GlutenNotSupportException(s"${other.getClass.toString} is not supported.")
-    }
   }
 }
@@ -18,7 +18,6 @@ package org.apache.gluten.extension.columnar.validator
 
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.{BackendsApiManager, BackendSettingsApi}
-import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution._
 import org.apache.gluten.expression.ExpressionUtils
 import org.apache.gluten.extension.columnar.FallbackTags
@@ -95,7 +94,7 @@ object Validators {
      * native validation failed.
      */
     def fallbackByNativeValidation(): Validator.Builder = {
-      builder.add(new FallbackByNativeValidation())
+      builder.add(new FallbackByNativeValidation)
     }
   }
 
@@ -223,34 +222,16 @@
     }
   }
 
-  private class FallbackByNativeValidation() extends Validator with Logging {
-    override def validate(plan: SparkPlan): Validator.OutCome = {
-      try {
-        validate0(plan)
-      } catch {
-        case e @ (_: GlutenNotSupportException | _: UnsupportedOperationException) =>
-          if (!e.isInstanceOf[GlutenNotSupportException]) {
-            logDebug("Just a warning. This exception perhaps needs to be fixed.", e)
-          }
-          fail(
-            s"${e.getMessage}, original Spark plan is " +
-              s"${plan.getClass}(${plan.children.toList.map(_.getClass)})")
-      }
-    }
-
-    private def validate0(plan: SparkPlan): Validator.OutCome = plan match {
+  private class FallbackByNativeValidation extends Validator with Logging {
+    override def validate(plan: SparkPlan): Validator.OutCome = plan match {
       case plan: BatchScanExec =>
-        val transformer =
-          ScanTransformerFactory
-            .createBatchScanTransformer(plan, validation = true)
-            .asInstanceOf[BasicScanExecTransformer]
+        val transformer = ScanTransformerFactory.createBatchScanTransformer(plan)
         transformer.doValidate().toValidatorOutcome()
       case plan: FileSourceScanExec =>
-        val transformer =
-          ScanTransformerFactory.createFileSourceScanTransformer(plan)
+        val transformer = ScanTransformerFactory.createFileSourceScanTransformer(plan)
        transformer.doValidate().toValidatorOutcome()
      case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
-        HiveTableScanExecTransformer.validate(plan).toValidatorOutcome()
+        HiveTableScanExecTransformer(plan).doValidate().toValidatorOutcome()
       case plan: ProjectExec =>
         val transformer = ProjectExecTransformer(plan.projectList, plan.child)
         transformer.doValidate().toValidatorOutcome()
@@ -18,7 +18,6 @@ package org.apache.spark.sql.hive
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.execution.BasicScanExecTransformer
-import org.apache.gluten.extension.ValidationResult
 import org.apache.gluten.metrics.MetricsUpdater
 import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
 
@@ -181,8 +180,8 @@ case class HiveTableScanExecTransformer(
 
 object HiveTableScanExecTransformer {
 
-  val NULL_VALUE: Char = 0x00
-  val DEFAULT_FIELD_DELIMITER: Char = 0x01
+  private val NULL_VALUE: Char = 0x00
+  private val DEFAULT_FIELD_DELIMITER: Char = 0x01
   val TEXT_INPUT_FORMAT_CLASS: Class[TextInputFormat] =
     Utils.classForName("org.apache.hadoop.mapred.TextInputFormat")
   val ORC_INPUT_FORMAT_CLASS: Class[OrcInputFormat] =
@@ -193,24 +192,6 @@ object HiveTableScanExecTransformer {
     plan.isInstanceOf[HiveTableScanExec]
   }
 
-  def copyWith(plan: SparkPlan, newPartitionFilters: Seq[Expression]): SparkPlan = {
-    val hiveTableScanExec = plan.asInstanceOf[HiveTableScanExec]
-    hiveTableScanExec.copy(partitionPruningPred = newPartitionFilters)(sparkSession =
-      hiveTableScanExec.session)
-  }
-
-  def validate(plan: SparkPlan): ValidationResult = {
-    plan match {
-      case hiveTableScan: HiveTableScanExec =>
-        val hiveTableScanTransformer = new HiveTableScanExecTransformer(
-          hiveTableScan.requestedAttributes,
-          hiveTableScan.relation,
-          hiveTableScan.partitionPruningPred)(hiveTableScan.session)
-        hiveTableScanTransformer.doValidate()
-      case _ => ValidationResult.failed("Is not a Hive scan")
-    }
-  }
-
   def apply(plan: SparkPlan): HiveTableScanExecTransformer = {
     plan match {
       case hiveTableScan: HiveTableScanExec =>
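For orientation (not part of the commit): a rough sketch of the offload-scan path after this change, assembled only from names that appear in the hunks above. The factory methods now only construct transformers; fallback decisions are made wherever doValidate() is called.

    // Sketch only, assuming the Gluten APIs shown in this diff.
    // OffloadOthers.ReplaceSingleNode builds the transformer directly:
    plan match {
      case p: BatchScanExec =>
        ScanTransformerFactory.createBatchScanTransformer(p) // construct only, no validation
      case p: FileSourceScanExec =>
        ScanTransformerFactory.createFileSourceScanTransformer(p)
      case p if HiveTableScanExecTransformer.isHiveTableScan(p) =>
        HiveTableScanExecTransformer(p)
      case other => other
    }
    // Validators.FallbackByNativeValidation then decides fallback by calling
    // transformer.doValidate(); BatchScanExecTransformerBase.doValidateInternal
    // itself rejects scans that fail ScanTransformerFactory.supportedBatchScan(scan).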
