From 917653a210cc60b31f794ae3719d1291bd2056be Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 3 Jul 2024 12:08:54 +0800 Subject: [PATCH] fixup fixup fixup fixup fixup fixup fixup --- .../backendsapi/clickhouse/CHBackend.scala | 18 +-- .../execution/CHGenerateExecTransformer.scala | 2 +- .../execution/CHHashJoinExecTransformer.scala | 6 +- .../CHSortMergeJoinExecTransformer.scala | 2 +- .../FallbackBroadcaseHashJoinRules.scala | 10 +- .../sql/execution/CHColumnarToRowExec.scala | 2 +- .../backendsapi/velox/VeloxBackend.scala | 14 +- .../velox/VeloxSparkPlanExecApi.scala | 8 +- .../execution/GenerateExecTransformer.scala | 4 +- .../execution/VeloxColumnarToRowExec.scala | 2 +- .../validate/NativePlanValidationInfo.java | 16 ++- .../backendsapi/BackendSettingsApi.scala | 4 +- .../BasicPhysicalOperatorTransformer.scala | 7 +- .../execution/BasicScanExecTransformer.scala | 2 +- .../execution/BatchScanExecTransformer.scala | 6 +- ...oadcastNestedLoopJoinExecTransformer.scala | 7 +- .../CartesianProductExecTransformer.scala | 4 +- .../execution/ExpandExecTransformer.scala | 4 +- .../FileSourceScanExecTransformer.scala | 8 +- .../GenerateExecTransformerBase.scala | 2 +- .../HashAggregateExecBaseTransformer.scala | 4 +- .../execution/JoinExecTransformer.scala | 2 +- .../execution/SampleExecTransformer.scala | 2 +- .../execution/ScanTransformerFactory.scala | 4 +- .../execution/SortExecTransformer.scala | 2 +- .../SortMergeJoinExecTransformer.scala | 4 +- ...TakeOrderedAndProjectExecTransformer.scala | 6 +- .../execution/WindowExecTransformer.scala | 2 +- .../WindowGroupLimitExecTransformer.scala | 2 +- .../execution/WriteFilesExecTransformer.scala | 4 +- .../apache/gluten/extension/GlutenPlan.scala | 39 ++--- .../CollapseProjectExecTransformer.scala | 4 +- .../EnsureLocalSortRequirements.scala | 2 +- .../columnar/ExpandFallbackPolicy.scala | 8 +- ...lbackTagRule.scala => FallbackRules.scala} | 135 ++++++++++-------- .../columnar/OffloadSingleNode.scala | 10 +- ...RemoveNativeWriteFilesSortAndProject.scala | 2 +- .../enumerated/PushFilterToScan.scala | 2 +- .../columnar/enumerated/RasOffload.scala | 2 +- .../RewriteSparkPlanRulesManager.scala | 8 +- .../columnar/validator/Validators.scala | 6 +- .../ColumnarBroadcastExchangeExec.scala | 6 +- .../ColumnarShuffleExchangeExec.scala | 7 +- .../sql/execution/GlutenExplainUtils.scala | 6 +- .../execution/GlutenFallbackReporter.scala | 52 ++----- .../python/EvalPythonExecTransformer.scala | 2 +- .../hive/HiveTableScanExecTransformer.scala | 2 +- 47 files changed, 223 insertions(+), 230 deletions(-) rename gluten-core/src/main/scala/org/apache/gluten/extension/columnar/{FallbackTagRule.scala => FallbackRules.scala} (88%) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index cdca1b031a915..99db20d756dcd 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -172,20 +172,20 @@ object CHBackendSettings extends BackendSettingsApi with Logging { format match { case ParquetReadFormat => if (validateFilePath) { - ValidationResult.ok + ValidationResult.succeeded } else { - ValidationResult.notOk("Validate file path failed.") + ValidationResult.failed("Validate file path failed.") } - case OrcReadFormat => ValidationResult.ok - case MergeTreeReadFormat => ValidationResult.ok + case OrcReadFormat => ValidationResult.succeeded + case MergeTreeReadFormat => ValidationResult.succeeded case TextReadFormat => if (!hasComplexType) { - ValidationResult.ok + ValidationResult.succeeded } else { - ValidationResult.notOk("Has complex type.") + ValidationResult.failed("Has complex type.") } - case JsonReadFormat => ValidationResult.ok - case _ => ValidationResult.notOk(s"Unsupported file format $format") + case JsonReadFormat => ValidationResult.succeeded + case _ => ValidationResult.failed(s"Unsupported file format $format") } } @@ -290,7 +290,7 @@ object CHBackendSettings extends BackendSettingsApi with Logging { fields: Array[StructField], bucketSpec: Option[BucketSpec], options: Map[String, String]): ValidationResult = - ValidationResult.notOk("CH backend is unsupported.") + ValidationResult.failed("CH backend is unsupported.") override def enableNativeWriteFiles(): Boolean = { GlutenConfig.getConf.enableNativeWriter.getOrElse(false) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHGenerateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHGenerateExecTransformer.scala index 733c0a472814d..44cb0deca5230 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHGenerateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHGenerateExecTransformer.scala @@ -64,7 +64,7 @@ case class CHGenerateExecTransformer( override protected def doGeneratorValidate( generator: Generator, outer: Boolean): ValidationResult = - ValidationResult.ok + ValidationResult.succeeded override protected def getRelNode( context: SubstraitContext, diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index da9d9c7586c05..48870892d290e 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -60,7 +60,7 @@ case class CHShuffledHashJoinExecTransformer( right.outputSet, condition) if (shouldFallback) { - return ValidationResult.notOk("ch join validate fail") + return ValidationResult.failed("ch join validate fail") } super.doValidateInternal() } @@ -118,10 +118,10 @@ case class CHBroadcastHashJoinExecTransformer( condition) if (shouldFallback) { - return ValidationResult.notOk("ch join validate fail") + return ValidationResult.failed("ch join validate fail") } if (isNullAwareAntiJoin) { - return ValidationResult.notOk("ch does not support NAAJ") + return ValidationResult.failed("ch does not support NAAJ") } super.doValidateInternal() } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala index e2b5865517391..670fbed693003 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala @@ -50,7 +50,7 @@ case class CHSortMergeJoinExecTransformer( right.outputSet, condition) if (shouldFallback) { - return ValidationResult.notOk("ch join validate fail") + return ValidationResult.failed("ch join validate fail") } super.doValidateInternal() } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala index 59c2d6494bdba..2628665b47583 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala @@ -57,12 +57,12 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend !columnarConf.enableColumnarBroadcastExchange || !columnarConf.enableColumnarBroadcastJoin ) { - ValidationResult.notOk( + ValidationResult.failed( "columnar broadcast exchange is disabled or " + "columnar broadcast join is disabled") } else { if (FallbackTags.nonEmpty(bhj)) { - ValidationResult.notOk("broadcast join is already tagged as not transformable") + ValidationResult.failed("broadcast join is already tagged as not transformable") } else { val bhjTransformer = BackendsApiManager.getSparkPlanExecApiInstance .genBroadcastHashJoinExecTransformer( @@ -75,7 +75,7 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend bhj.right, bhj.isNullAwareAntiJoin) val isBhjTransformable = bhjTransformer.doValidate() - if (isBhjTransformable.isValid) { + if (isBhjTransformable.ok()) { val exchangeTransformer = ColumnarBroadcastExchangeExec(mode, child) exchangeTransformer.doValidate() } else { @@ -148,7 +148,7 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl maybeExchange match { case Some(exchange @ BroadcastExchangeExec(mode, child)) => isBhjTransformable.tagOnFallback(bhj) - if (!isBhjTransformable.isValid) { + if (!isBhjTransformable.ok()) { FallbackTags.add(exchange, isBhjTransformable) } case None => @@ -186,7 +186,7 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl bhj, "it's a materialized broadcast exchange or reused broadcast exchange") case ColumnarBroadcastExchangeExec(mode, child) => - if (!isBhjTransformable.isValid) { + if (!isBhjTransformable.ok()) { throw new IllegalStateException( s"BroadcastExchange has already been" + s" transformed to columnar version but BHJ is determined as" + diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarToRowExec.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarToRowExec.scala index 522c7d68fa68b..29fa0d0aba960 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarToRowExec.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarToRowExec.scala @@ -58,7 +58,7 @@ case class CHColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBase(c s"${field.dataType} is not supported in ColumnarToRowExecBase.") } } - ValidationResult.ok + ValidationResult.succeeded } override def doExecuteInternal(): RDD[InternalRow] = { diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index 0238508d96995..0d94cbef657d1 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -79,9 +79,9 @@ object VeloxBackendSettings extends BackendSettingsApi { // Collect unsupported types. val unsupportedDataTypeReason = fields.collect(validatorFunc) if (unsupportedDataTypeReason.isEmpty) { - ValidationResult.ok + ValidationResult.succeeded } else { - ValidationResult.notOk( + ValidationResult.failed( s"Found unsupported data type in $format: ${unsupportedDataTypeReason.mkString(", ")}.") } } @@ -135,10 +135,10 @@ object VeloxBackendSettings extends BackendSettingsApi { } else { validateTypes(parquetTypeValidatorWithComplexTypeFallback) } - case DwrfReadFormat => ValidationResult.ok + case DwrfReadFormat => ValidationResult.succeeded case OrcReadFormat => if (!GlutenConfig.getConf.veloxOrcScanEnabled) { - ValidationResult.notOk(s"Velox ORC scan is turned off.") + ValidationResult.failed(s"Velox ORC scan is turned off.") } else { val typeValidator: PartialFunction[StructField, String] = { case StructField(_, arrayType: ArrayType, _, _) @@ -164,7 +164,7 @@ object VeloxBackendSettings extends BackendSettingsApi { validateTypes(orcTypeValidatorWithComplexTypeFallback) } } - case _ => ValidationResult.notOk(s"Unsupported file format for $format.") + case _ => ValidationResult.failed(s"Unsupported file format for $format.") } } @@ -284,8 +284,8 @@ object VeloxBackendSettings extends BackendSettingsApi { .orElse(validateDataTypes()) .orElse(validateWriteFilesOptions()) .orElse(validateBucketSpec()) match { - case Some(reason) => ValidationResult.notOk(reason) - case _ => ValidationResult.ok + case Some(reason) => ValidationResult.failed(reason) + case _ => ValidationResult.succeeded } } diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index e13ebd971ef55..deaaa9273b708 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -367,7 +367,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ child.output val projectTransformer = ProjectExecTransformer(projectList, child) val validationResult = projectTransformer.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { val newChild = maybeAddAppendBatchesExec(projectTransformer) ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output.drop(1)) } else { @@ -393,7 +393,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { val projectTransformer = ProjectExecTransformer(projectList, child) val projectBeforeSortValidationResult = projectTransformer.doValidate() // Make sure we support offload hash expression - val projectBeforeSort = if (projectBeforeSortValidationResult.isValid) { + val projectBeforeSort = if (projectBeforeSortValidationResult.ok()) { projectTransformer } else { val project = ProjectExec(projectList, child) @@ -406,7 +406,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { val dropSortColumnTransformer = ProjectExecTransformer(projectList.drop(1), sortByHashCode) val validationResult = dropSortColumnTransformer.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { val newChild = maybeAddAppendBatchesExec(dropSortColumnTransformer) ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output) } else { @@ -891,7 +891,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { case p @ LimitTransformer(SortExecTransformer(sortOrder, _, child, _), 0, count) => val global = child.outputPartitioning.satisfies(AllTuples) val topN = TopNTransformer(count, sortOrder, global, child) - if (topN.doValidate().isValid) { + if (topN.doValidate().ok()) { topN } else { p diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala index 8ceea8c14f6ad..c7b81d55fa067 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala @@ -72,11 +72,11 @@ case class GenerateExecTransformer( generator: Generator, outer: Boolean): ValidationResult = { if (!supportsGenerate(generator, outer)) { - ValidationResult.notOk( + ValidationResult.failed( s"Velox backend does not support this generator: ${generator.getClass.getSimpleName}" + s", outer: $outer") } else { - ValidationResult.ok + ValidationResult.succeeded } } diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala index 1a54255208ea4..2c46893e4576a 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala @@ -66,7 +66,7 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBas s"VeloxColumnarToRowExec.") } } - ValidationResult.ok + ValidationResult.succeeded } override def doExecuteInternal(): RDD[InternalRow] = { diff --git a/gluten-core/src/main/java/org/apache/gluten/validate/NativePlanValidationInfo.java b/gluten-core/src/main/java/org/apache/gluten/validate/NativePlanValidationInfo.java index 12f050c660f43..9cfad44d60f52 100644 --- a/gluten-core/src/main/java/org/apache/gluten/validate/NativePlanValidationInfo.java +++ b/gluten-core/src/main/java/org/apache/gluten/validate/NativePlanValidationInfo.java @@ -16,6 +16,8 @@ */ package org.apache.gluten.validate; +import org.apache.gluten.extension.ValidationResult; + import java.util.Vector; public class NativePlanValidationInfo { @@ -30,11 +32,13 @@ public NativePlanValidationInfo(int isSupported, String fallbackInfo) { } } - public boolean isSupported() { - return isSupported == 1; - } - - public Vector getFallbackInfo() { - return fallbackInfo; + public ValidationResult asResult() { + if (isSupported == 1) { + return ValidationResult.succeeded(); + } + return ValidationResult.failed( + String.format( + "Native validation failed: %n%s", + fallbackInfo.stream().reduce((l, r) -> l + "\n" + r))); } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index d159486373ace..874d62a53ea43 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala @@ -35,12 +35,12 @@ trait BackendSettingsApi { format: ReadFileFormat, fields: Array[StructField], partTable: Boolean, - paths: Seq[String]): ValidationResult = ValidationResult.ok + paths: Seq[String]): ValidationResult = ValidationResult.succeeded def supportWriteFilesExec( format: FileFormat, fields: Array[StructField], bucketSpec: Option[BucketSpec], - options: Map[String, String]): ValidationResult = ValidationResult.ok + options: Map[String, String]): ValidationResult = ValidationResult.succeeded def supportNativeWrite(fields: Array[StructField]): Boolean = true def supportNativeMetadataColumns(): Boolean = false def supportNativeRowIndexColumn(): Boolean = false diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala index 0b792d52e0561..7ba9895e91f07 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala @@ -113,7 +113,7 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP if (remainingCondition == null) { // All the filters can be pushed down and the computing of this Filter // is not needed. - return ValidationResult.ok + return ValidationResult.succeeded } val substraitContext = new SubstraitContext val operatorId = substraitContext.nextOperatorId(this.nodeName) @@ -316,9 +316,10 @@ case class ColumnarUnionExec(children: Seq[SparkPlan]) extends SparkPlan with Gl BackendsApiManager.getValidatorApiInstance .doSchemaValidate(schema) .map { - reason => ValidationResult.notOk(s"Found schema check failure for $schema, due to: $reason") + reason => + ValidationResult.failed(s"Found schema check failure for $schema, due to: $reason") } - .getOrElse(ValidationResult.ok) + .getOrElse(ValidationResult.succeeded) } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala index 64071fb14c0c0..99f145eeab1c9 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala @@ -88,7 +88,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource val validationResult = BackendsApiManager.getSettings .supportFileFormatRead(fileFormat, fields, getPartitionSchema.nonEmpty, getInputFilePaths) - if (!validationResult.isValid) { + if (!validationResult.ok()) { return validationResult } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala index 6bff68895a249..4860847de9acd 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala @@ -133,18 +133,18 @@ abstract class BatchScanExecTransformerBase( override def doValidateInternal(): ValidationResult = { if (pushedAggregate.nonEmpty) { - return ValidationResult.notOk(s"Unsupported aggregation push down for $scan.") + return ValidationResult.failed(s"Unsupported aggregation push down for $scan.") } if ( SparkShimLoader.getSparkShims.findRowIndexColumnIndexInSchema(schema) > 0 && !BackendsApiManager.getSettings.supportNativeRowIndexColumn() ) { - return ValidationResult.notOk("Unsupported row index column scan in native.") + return ValidationResult.failed("Unsupported row index column scan in native.") } if (hasUnsupportedColumns) { - return ValidationResult.notOk(s"Unsupported columns scan in native.") + return ValidationResult.failed(s"Unsupported columns scan in native.") } super.doValidateInternal() diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala index 092612ea73407..83b7967813213 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala @@ -147,14 +147,15 @@ abstract class BroadcastNestedLoopJoinExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportBroadcastNestedLoopJoinExec()) { - return ValidationResult.notOk("Broadcast Nested Loop join is not supported in this backend") + return ValidationResult.failed("Broadcast Nested Loop join is not supported in this backend") } if (substraitJoinType == CrossRel.JoinType.UNRECOGNIZED) { - return ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin") + return ValidationResult.failed( + s"$joinType join is not supported with BroadcastNestedLoopJoin") } (joinType, buildSide) match { case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) => - return ValidationResult.notOk(s"$joinType join is not supported with $buildSide") + return ValidationResult.failed(s"$joinType join is not supported with $buildSide") case _ => // continue } val substraitContext = new SubstraitContext diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala index 91831f18493ad..c914d6e1c576e 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala @@ -24,7 +24,7 @@ import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.rel.RelBuilder import org.apache.gluten.utils.SubstraitUtil -import org.apache.spark.{Dependency, NarrowDependency, Partition, SparkContext, TaskContext} +import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} @@ -112,7 +112,7 @@ case class CartesianProductExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportCartesianProductExec()) { - return ValidationResult.notOk("Cartesian product is not supported in this backend") + return ValidationResult.failed("Cartesian product is not supported in this backend") } val substraitContext = new SubstraitContext val expressionNode = condition.map { diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala index 362debb531ee6..63f76a25a2318 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala @@ -95,10 +95,10 @@ case class ExpandExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportExpandExec()) { - return ValidationResult.notOk("Current backend does not support expand") + return ValidationResult.failed("Current backend does not support expand") } if (projections.isEmpty) { - return ValidationResult.notOk("Current backend does not support empty projections in expand") + return ValidationResult.failed("Current backend does not support empty projections in expand") } val substraitContext = new SubstraitContext diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala index 4f120488c2fb5..3b8ed1167afcf 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala @@ -130,23 +130,23 @@ abstract class FileSourceScanExecTransformerBase( if ( !metadataColumns.isEmpty && !BackendsApiManager.getSettings.supportNativeMetadataColumns() ) { - return ValidationResult.notOk(s"Unsupported metadata columns scan in native.") + return ValidationResult.failed(s"Unsupported metadata columns scan in native.") } if ( SparkShimLoader.getSparkShims.findRowIndexColumnIndexInSchema(schema) > 0 && !BackendsApiManager.getSettings.supportNativeRowIndexColumn() ) { - return ValidationResult.notOk("Unsupported row index column scan in native.") + return ValidationResult.failed("Unsupported row index column scan in native.") } if (hasUnsupportedColumns) { - return ValidationResult.notOk(s"Unsupported columns scan in native.") + return ValidationResult.failed(s"Unsupported columns scan in native.") } if (hasFieldIds) { // Spark read schema expects field Ids , the case didn't support yet by native. - return ValidationResult.notOk( + return ValidationResult.failed( s"Unsupported matching schema column names " + s"by field ids in native scan.") } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala index b5c9b85aeb0d5..af4a92f194c1b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala @@ -67,7 +67,7 @@ abstract class GenerateExecTransformerBase( override protected def doValidateInternal(): ValidationResult = { val validationResult = doGeneratorValidate(generator, outer) - if (!validationResult.isValid) { + if (!validationResult.ok()) { return validationResult } val context = new SubstraitContext diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala index 9345b3a3636fc..d29734d4d5091 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala @@ -117,7 +117,7 @@ abstract class HashAggregateExecBaseTransformer( val unsupportedAggExprs = aggregateAttributes.filterNot(attr => checkType(attr.dataType)) if (unsupportedAggExprs.nonEmpty) { - return ValidationResult.notOk( + return ValidationResult.failed( "Found unsupported data type in aggregation expression: " + unsupportedAggExprs .map(attr => s"${attr.name}#${attr.exprId.id}:${attr.dataType}") @@ -125,7 +125,7 @@ abstract class HashAggregateExecBaseTransformer( } val unsupportedGroupExprs = groupingExpressions.filterNot(attr => checkType(attr.dataType)) if (unsupportedGroupExprs.nonEmpty) { - return ValidationResult.notOk( + return ValidationResult.failed( "Found unsupported data type in grouping expression: " + unsupportedGroupExprs .map(attr => s"${attr.name}#${attr.exprId.id}:${attr.dataType}") diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala index cd22c578594c6..86e6c1f412656 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala @@ -210,7 +210,7 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport { // Firstly, need to check if the Substrait plan for this operator can be successfully generated. if (substraitJoinType == JoinRel.JoinType.UNRECOGNIZED) { return ValidationResult - .notOk(s"Unsupported join type of $hashJoinType for substrait: $substraitJoinType") + .failed(s"Unsupported join type of $hashJoinType for substrait: $substraitJoinType") } val relNode = JoinUtils.createJoinRel( streamedKeyExprs, diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala index 6f9ef34282bf0..bed59b913a1e9 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala @@ -99,7 +99,7 @@ case class SampleExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (withReplacement) { - return ValidationResult.notOk( + return ValidationResult.failed( "Unsupported sample exec in native with " + s"withReplacement parameter is $withReplacement") } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/ScanTransformerFactory.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/ScanTransformerFactory.scala index 44a823834f926..a05a5e72bfe1d 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/ScanTransformerFactory.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/ScanTransformerFactory.scala @@ -95,11 +95,11 @@ object ScanTransformerFactory { transformer.setPushDownFilters(allPushDownFilters.get) // Validate again if allPushDownFilters is defined. val validationResult = transformer.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { transformer } else { val newSource = batchScan.copy(runtimeFilters = transformer.runtimeFilters) - FallbackTags.add(newSource, validationResult.reason.get) + FallbackTags.add(newSource, validationResult.reason()) newSource } } else { diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala index f79dc69e680b5..b69925d60fd21 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala @@ -91,7 +91,7 @@ case class SortExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportSortExec()) { - return ValidationResult.notOk("Current backend does not support sort") + return ValidationResult.failed("Current backend does not support sort") } val substraitContext = new SubstraitContext val operatorId = substraitContext.nextOperatorId(this.nodeName) diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala index f032c4ca00879..c96789569f9ac 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala @@ -164,7 +164,7 @@ abstract class SortMergeJoinExecTransformerBase( // Firstly, need to check if the Substrait plan for this operator can be successfully generated. if (substraitJoinType == JoinRel.JoinType.UNRECOGNIZED) { return ValidationResult - .notOk(s"Found unsupported join type of $joinType for substrait: $substraitJoinType") + .failed(s"Found unsupported join type of $joinType for substrait: $substraitJoinType") } val relNode = JoinUtils.createJoinRel( streamedKeys, @@ -253,7 +253,7 @@ case class SortMergeJoinExecTransformer( // Firstly, need to check if the Substrait plan for this operator can be successfully generated. if (substraitJoinType == JoinRel.JoinType.JOIN_TYPE_OUTER) { return ValidationResult - .notOk(s"Found unsupported join type of $joinType for velox smj: $substraitJoinType") + .failed(s"Found unsupported join type of $joinType for velox smj: $substraitJoinType") } super.doValidateInternal() } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala index 74158d6332dc5..b31471e21397a 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala @@ -67,7 +67,7 @@ case class TakeOrderedAndProjectExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (offset != 0) { - return ValidationResult.notOk(s"Native TopK does not support offset: $offset") + return ValidationResult.failed(s"Native TopK does not support offset: $offset") } var tagged: ValidationResult = null @@ -83,14 +83,14 @@ case class TakeOrderedAndProjectExecTransformer( ColumnarCollapseTransformStages.wrapInputIteratorTransformer(child) val sortPlan = SortExecTransformer(sortOrder, false, inputTransformer) val sortValidation = sortPlan.doValidate() - if (!sortValidation.isValid) { + if (!sortValidation.ok()) { return sortValidation } val limitPlan = LimitTransformer(sortPlan, offset, limit) tagged = limitPlan.doValidate() } - if (tagged.isValid) { + if (tagged.ok()) { val projectPlan = ProjectExecTransformer(projectList, child) tagged = projectPlan.doValidate() } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala index 6832221a404d9..eb66a972157b8 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala @@ -169,7 +169,7 @@ case class WindowExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportWindowExec(windowExpression)) { return ValidationResult - .notOk(s"Found unsupported window expression: ${windowExpression.mkString(", ")}") + .failed(s"Found unsupported window expression: ${windowExpression.mkString(", ")}") } val substraitContext = new SubstraitContext val operatorId = substraitContext.nextOperatorId(this.nodeName) diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala index 46a4e1aa4eeec..59c9853375488 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala @@ -135,7 +135,7 @@ case class WindowGroupLimitExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportWindowGroupLimitExec(rankLikeFunction)) { return ValidationResult - .notOk(s"Found unsupported rank like function: $rankLikeFunction") + .failed(s"Found unsupported rank like function: $rankLikeFunction") } val substraitContext = new SubstraitContext val operatorId = substraitContext.nextOperatorId(this.nodeName) diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala index 14d58bfa83771..d78f21beaabfe 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala @@ -150,8 +150,8 @@ case class WriteFilesExecTransformer( finalChildOutput.toStructType.fields, bucketSpec, caseInsensitiveOptions) - if (!validationResult.isValid) { - return ValidationResult.notOk("Unsupported native write: " + validationResult.reason.get) + if (!validationResult.ok()) { + return ValidationResult.failed("Unsupported native write: " + validationResult.reason()) } val substraitContext = new SubstraitContext diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala index 8f1004be4aaae..788f320b39843 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala @@ -26,28 +26,29 @@ import org.apache.gluten.substrait.plan.PlanBuilder import org.apache.gluten.substrait.rel.RelNode import org.apache.gluten.test.TestStats import org.apache.gluten.utils.LogLevelUtil -import org.apache.gluten.validate.NativePlanValidationInfo import org.apache.spark.sql.execution.SparkPlan import com.google.common.collect.Lists -import scala.collection.JavaConverters._ - -case class ValidationResult(isValid: Boolean, reason: Option[String]) +sealed trait ValidationResult { + def ok(): Boolean + def reason(): String +} object ValidationResult { - def ok: ValidationResult = ValidationResult(isValid = true, None) - def notOk(reason: String): ValidationResult = ValidationResult(isValid = false, Option(reason)) - def convertFromValidationInfo(info: NativePlanValidationInfo): ValidationResult = { - if (info.isSupported) { - ok - } else { - val fallbackInfo = info.getFallbackInfo.asScala - .mkString("Native validation failed:\n ", "\n ", "") - notOk(fallbackInfo) - } + private case object Succeeded extends ValidationResult { + override def ok(): Boolean = true + override def reason(): String = throw new UnsupportedOperationException( + "Succeeded validation doesn't have failure details") } + + private case class Failed(override val reason: String) extends ValidationResult { + override def ok(): Boolean = false + } + + def succeeded: ValidationResult = Succeeded + def failed(reason: String): ValidationResult = Failed(reason) } /** Every Gluten Operator should extend this trait. */ @@ -66,7 +67,7 @@ trait GlutenPlan extends SparkPlan with Convention.KnownBatchType with LogLevelU try { TransformerState.enterValidation val res = doValidateInternal() - if (!res.isValid) { + if (!res.ok()) { TestStats.addFallBackClassName(this.getClass.toString) } res @@ -80,7 +81,7 @@ trait GlutenPlan extends SparkPlan with Convention.KnownBatchType with LogLevelU logValidationMessage( s"Validation failed with exception for plan: $nodeName, due to: ${e.getMessage}", e) - ValidationResult.notOk(e.getMessage) + ValidationResult.failed(e.getMessage) } finally { TransformerState.finishValidation } @@ -99,16 +100,16 @@ trait GlutenPlan extends SparkPlan with Convention.KnownBatchType with LogLevelU BackendsApiManager.getSparkPlanExecApiInstance.batchType } - protected def doValidateInternal(): ValidationResult = ValidationResult.ok + protected def doValidateInternal(): ValidationResult = ValidationResult.succeeded protected def doNativeValidation(context: SubstraitContext, node: RelNode): ValidationResult = { if (node != null && enableNativeValidation) { val planNode = PlanBuilder.makePlan(context, Lists.newArrayList(node)) val info = BackendsApiManager.getValidatorApiInstance .doNativeValidateWithFailureReason(planNode) - ValidationResult.convertFromValidationInfo(info) + info.asResult() } else { - ValidationResult.ok + ValidationResult.succeeded } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala index 25674fd17147f..bfb926706cf7b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala @@ -40,11 +40,11 @@ object CollapseProjectExecTransformer extends Rule[SparkPlan] { val collapsedProject = p2.copy(projectList = CollapseProjectShim.buildCleanedProjectList(p1.projectList, p2.projectList)) val validationResult = collapsedProject.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { logDebug(s"Collapse project $p1 and $p2.") collapsedProject } else { - logDebug(s"Failed to collapse project, due to ${validationResult.reason.getOrElse("")}") + logDebug(s"Failed to collapse project, due to ${validationResult.reason()}") p1 } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/EnsureLocalSortRequirements.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/EnsureLocalSortRequirements.scala index afc29a51e19a7..2c369e993f356 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/EnsureLocalSortRequirements.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/EnsureLocalSortRequirements.scala @@ -47,7 +47,7 @@ object EnsureLocalSortRequirements extends Rule[SparkPlan] { newChild.child, newChild.testSpillFrequency) val validationResult = newChildWithTransformer.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { newChildWithTransformer } else { FallbackTags.add(newChild, validationResult) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala index e334fcfbce889..d48b36b5a2cca 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala @@ -243,7 +243,7 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP originalPlan .find(_.logicalLink.exists(_.fastEquals(p.logicalLink.get))) .filterNot(FallbackTags.nonEmpty) - .foreach(origin => FallbackTags.tag(origin, FallbackTags.getTag(p))) + .foreach(origin => FallbackTags.add(origin, FallbackTags.get(p))) case _ => } @@ -274,13 +274,11 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP val vanillaSparkTransitionCost = countTransitionCostForVanillaSparkPlan(vanillaSparkPlan) if ( GlutenConfig.getConf.fallbackPreferColumnar && - fallbackInfo.netTransitionCost <= vanillaSparkTransitionCost + fallbackInfo.netTransitionCost <= vanillaSparkTransitionCost ) { plan } else { - FallbackTags.addRecursively( - vanillaSparkPlan, - TRANSFORM_UNSUPPORTED(fallbackInfo.reason, appendReasonIfExists = false)) + FallbackTags.addRecursively(vanillaSparkPlan, FallbackTag.Exclusive(fallbackInfo.reason.getOrElse("Unknown reason"))) FallbackNode(vanillaSparkPlan) } } else { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackTagRule.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala similarity index 88% rename from gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackTagRule.scala rename to gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala index d34cb0df4e7e4..324443eea08b9 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackTagRule.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, QueryStageExec} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} @@ -50,10 +50,51 @@ sealed trait FallbackTag { if (FallbackTags.DEBUG) { Some(ExceptionUtils.getStackTrace(new Throwable())) } else None + + def reason(): String } -case class TRANSFORM_UNSUPPORTED(reason: Option[String], appendReasonIfExists: Boolean = true) - extends FallbackTag +object FallbackTag { + + /** A tag that stores one reason text of fall back. */ + case class Appendable(override val reason: String) extends FallbackTag + + /** + * A tag that stores reason text of fall back. Other reasons will be discarded when this tag is + * added to plan. + */ + case class Exclusive(override val reason: String) extends FallbackTag + + trait Converter[T] { + def from(obj: T): Option[FallbackTag] + } + + object Converter { + implicit def asIs[T <: FallbackTag]: Converter[T] = (tag: T) => Some(tag) + + implicit object FromString extends Converter[String] { + override def from(reason: String): Option[FallbackTag] = Some(Appendable(reason)) + } + + implicit object FromStringOption extends Converter[Option[String]] { + override def from(reason: Option[String]): Option[FallbackTag] = { + reason match { + case Some(r) => Some(Appendable(r)) + case None => Some(Appendable("Unknown fallback reason")) + } + } + } + + implicit object FromValidationResult extends Converter[ValidationResult] { + override def from(result: ValidationResult): Option[FallbackTag] = { + if (result.ok()) { + return None + } + Some(Appendable(result.reason())) + } + } + } +} object FallbackTags { val TAG: TreeNodeTag[FallbackTag] = @@ -70,10 +111,7 @@ object FallbackTags { * rules are passed. */ def nonEmpty(plan: SparkPlan): Boolean = { - getTagOption(plan) match { - case Some(TRANSFORM_UNSUPPORTED(_, _)) => true - case _ => false - } + getOption(plan).nonEmpty } /** @@ -84,72 +122,55 @@ object FallbackTags { */ def maybeOffloadable(plan: SparkPlan): Boolean = !nonEmpty(plan) - def tag(plan: SparkPlan, hint: FallbackTag): Unit = { - val mergedHint = getTagOption(plan) - .map { - case originalHint @ TRANSFORM_UNSUPPORTED(Some(originalReason), originAppend) => - hint match { - case TRANSFORM_UNSUPPORTED(Some(newReason), append) => - if (originAppend && append) { - TRANSFORM_UNSUPPORTED(Some(originalReason + "; " + newReason)) - } else if (originAppend) { - TRANSFORM_UNSUPPORTED(Some(originalReason)) - } else if (append) { - TRANSFORM_UNSUPPORTED(Some(newReason)) - } else { - TRANSFORM_UNSUPPORTED(Some(originalReason), false) + def add[T](plan: TreeNode[_], t: T)(implicit converter: FallbackTag.Converter[T]): Unit = { + val tagOption = getOption(plan) + val newTagOption = converter.from(t) + + tagOption + .flatMap( + tag => + newTagOption.map( + newTag => { + // New tag comes while the plan was already tagged, merge. + (tag, newTag) match { + case (_, exclusive: FallbackTag.Exclusive) => + exclusive + case (exclusive: FallbackTag.Exclusive, _) => + exclusive + case (l: FallbackTag.Appendable, r: FallbackTag.Appendable) => + FallbackTag.Appendable(s"${l.reason}; ${r.reason}") } - case TRANSFORM_UNSUPPORTED(None, _) => - originalHint - case _ => - throw new GlutenNotSupportException( - "Plan was already tagged as non-transformable, " + - s"cannot mark it as transformable after that:\n${plan.toString()}") - } - case _ => - hint - } - .getOrElse(hint) - plan.setTagValue(TAG, mergedHint) - } - - def untag(plan: SparkPlan): Unit = { - plan.unsetTagValue(TAG) + })) + .foreach(mergedTag => plan.setTagValue(TAG, mergedTag)) } - def add(plan: SparkPlan, validationResult: ValidationResult): Unit = { - if (!validationResult.isValid) { - tag(plan, TRANSFORM_UNSUPPORTED(validationResult.reason)) - } - } - - def add(plan: SparkPlan, reason: String): Unit = { - tag(plan, TRANSFORM_UNSUPPORTED(Some(reason))) - } - - def addRecursively(plan: SparkPlan, hint: TRANSFORM_UNSUPPORTED): Unit = { + def addRecursively[T](plan: TreeNode[_], t: T)(implicit + converter: FallbackTag.Converter[T]): Unit = { plan.foreach { case _: GlutenPlan => // ignore - case other => tag(other, hint) + case other: TreeNode[_] => add(other, t) } } - def getTag(plan: SparkPlan): FallbackTag = { - getTagOption(plan).getOrElse( + def untag(plan: TreeNode[_]): Unit = { + plan.unsetTagValue(TAG) + } + + def get(plan: TreeNode[_]): FallbackTag = { + getOption(plan).getOrElse( throw new IllegalStateException("Transform hint tag not set in plan: " + plan.toString())) } - def getTagOption(plan: SparkPlan): Option[FallbackTag] = { + def getOption(plan: TreeNode[_]): Option[FallbackTag] = { plan.getTagValue(TAG) } - implicit class EncodeFallbackTagImplicits(validationResult: ValidationResult) { - def tagOnFallback(plan: SparkPlan): Unit = { - if (validationResult.isValid) { + implicit class EncodeFallbackTagImplicits(result: ValidationResult) { + def tagOnFallback(plan: TreeNode[_]): Unit = { + if (result.ok()) { return } - val newTag = TRANSFORM_UNSUPPORTED(validationResult.reason) - tag(plan, newTag) + add(plan, result) } } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala index 7a4222b5cb382..85d24d37cea35 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala @@ -379,7 +379,7 @@ case class OffloadFilter() extends OffloadSingleNode with LogLevelUtil { val newScan = FilterHandler.pushFilterToScan(filter.condition, scan) newScan match { - case ts: TransformSupport if ts.doValidate().isValid => ts + case ts: TransformSupport if ts.doValidate().ok() => ts case _ => scan } } else scan @@ -556,12 +556,12 @@ object OffloadOthers { case plan: FileSourceScanExec => val transformer = ScanTransformerFactory.createFileSourceScanTransformer(plan) val validationResult = transformer.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") transformer } else { logDebug(s"Columnar Processing for ${plan.getClass} is currently unsupported.") - FallbackTags.add(plan, validationResult.reason.get) + FallbackTags.add(plan, validationResult.reason()) plan } case plan: BatchScanExec => @@ -571,12 +571,12 @@ object OffloadOthers { val hiveTableScanExecTransformer = BackendsApiManager.getSparkPlanExecApiInstance.genHiveTableScanExecTransformer(plan) val validateResult = hiveTableScanExecTransformer.doValidate() - if (validateResult.isValid) { + if (validateResult.ok()) { logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") return hiveTableScanExecTransformer } logDebug(s"Columnar Processing for ${plan.getClass} is currently unsupported.") - FallbackTags.add(plan, validateResult.reason.get) + FallbackTags.add(plan, validateResult.reason()) plan case other => throw new GlutenNotSupportException(s"${other.getClass.toString} is not supported.") diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RemoveNativeWriteFilesSortAndProject.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RemoveNativeWriteFilesSortAndProject.scala index d32de32ebb322..ac35ac83bb21d 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RemoveNativeWriteFilesSortAndProject.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RemoveNativeWriteFilesSortAndProject.scala @@ -71,7 +71,7 @@ object NativeWriteFilesWithSkippingSortAndProject extends Logging { } val transformer = ProjectExecTransformer(newProjectList, p.child) val validationResult = transformer.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { Some(transformer) } else { // If we can not transform the project, then we fallback to origin plan which means diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala index 611d6db0bd483..4070a0a586120 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala @@ -37,7 +37,7 @@ class PushFilterToScan(validator: Validator) extends RasRule[SparkPlan] { val newScan = FilterHandler.pushFilterToScan(filter.condition, scan) newScan match { - case ts: TransformSupport if ts.doValidate().isValid => + case ts: TransformSupport if ts.doValidate().ok() => List(filter.withNewChildren(List(ts))) case _ => List.empty diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala index 8091127da0bfa..43f52a9e4758e 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala @@ -106,7 +106,7 @@ object RasOffload { case Validator.Passed => val offloaded = base.offload(from) val offloadedNodes = offloaded.collect[GlutenPlan] { case t: GlutenPlan => t } - if (offloadedNodes.exists(!_.doValidate().isValid)) { + if (offloadedNodes.exists(!_.doValidate().ok())) { // 4. If native validation fails on the offloaded node, return the // original one. from diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala index 2abd4d7d48074..e005a3dc81639 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala @@ -74,7 +74,7 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode] case p if !p.isInstanceOf[ProjectExec] && !p.isInstanceOf[RewrittenNodeWall] => p } assert(target.size == 1) - FallbackTags.getTagOption(target.head) + FallbackTags.getOption(target.head) } private def applyRewriteRules(origin: SparkPlan): (SparkPlan, Option[String]) = { @@ -112,10 +112,10 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode] origin } else { addHint.apply(rewrittenPlan) - val hint = getFallbackTagBack(rewrittenPlan) - if (hint.isDefined) { + val tag = getFallbackTagBack(rewrittenPlan) + if (tag.isDefined) { // If the rewritten plan is still not transformable, return the original plan. - FallbackTags.tag(origin, hint.get) + FallbackTags.add(origin, tag.get) origin } else { rewrittenPlan.transformUp { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala index 959bf808aba46..47dd0d546d03b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala @@ -19,7 +19,7 @@ package org.apache.gluten.extension.columnar.validator import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.{BackendsApiManager, BackendSettingsApi} import org.apache.gluten.expression.ExpressionUtils -import org.apache.gluten.extension.columnar.{FallbackTags, TRANSFORM_UNSUPPORTED} +import org.apache.gluten.extension.columnar.FallbackTags import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.execution._ @@ -109,8 +109,8 @@ object Validators { private object FallbackByHint extends Validator { override def validate(plan: SparkPlan): Validator.OutCome = { if (FallbackTags.nonEmpty(plan)) { - val hint = FallbackTags.getTag(plan).asInstanceOf[TRANSFORM_UNSUPPORTED] - return fail(hint.reason.getOrElse("Reason not recorded")) + val tag = FallbackTags.get(plan) + return fail(tag.reason()) } pass() } diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala index df1c87cb0ccc4..d6507e3e51d15 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala @@ -139,18 +139,18 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) mode == IdentityBroadcastMode && !BackendsApiManager.getSettings .supportBroadcastNestedLoopJoinExec() ) { - return ValidationResult.notOk("This backend does not support IdentityBroadcastMode and BNLJ") + return ValidationResult.failed("This backend does not support IdentityBroadcastMode and BNLJ") } BackendsApiManager.getValidatorApiInstance .doSchemaValidate(schema) .map { reason => { - ValidationResult.notOk( + ValidationResult.failed( s"Unsupported schema in broadcast exchange: $schema, reason: $reason") } } - .getOrElse(ValidationResult.ok) + .getOrElse(ValidationResult.succeeded) } override def doPrepare(): Unit = { diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index 85a4dd3878a30..31175a43fbaec 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager -import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.ValidationResult +import org.apache.gluten.extension.{GlutenPlan, ValidationResult} import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark._ @@ -119,10 +118,10 @@ case class ColumnarShuffleExchangeExec( .doColumnarShuffleExchangeExecValidate(outputPartitioning, child) .map { reason => - ValidationResult.notOk( + ValidationResult.failed( s"Found schema check failure for schema ${child.schema} due to: $reason") } - .getOrElse(ValidationResult.ok) + .getOrElse(ValidationResult.succeeded) } override def nodeName: String = "ColumnarExchange" diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala index 781dc6b6f717a..338136b6d8704 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.execution import org.apache.gluten.execution.WholeStageTransformer import org.apache.gluten.extension.GlutenPlan +import org.apache.gluten.extension.columnar.FallbackTags import org.apache.gluten.utils.PlanUtil import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.execution.GlutenFallbackReporter.FALLBACK_REASON_TAG import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} @@ -59,8 +59,8 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { p: SparkPlan, fallbackNodeToReason: mutable.HashMap[String, String] ): Unit = { - p.logicalLink.flatMap(_.getTagValue(FALLBACK_REASON_TAG)) match { - case Some(reason) => addFallbackNodeWithReason(p, reason, fallbackNodeToReason) + p.logicalLink.flatMap(FallbackTags.getOption) match { + case Some(tag) => addFallbackNodeWithReason(p, tag.reason(), fallbackNodeToReason) case _ => // If the SparkPlan does not have fallback reason, then there are two options: // 1. Gluten ignore that plan and it's a kind of fallback diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala index d41dce882602b..67ecf81b9ffda 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala @@ -19,14 +19,12 @@ package org.apache.spark.sql.execution import org.apache.gluten.GlutenConfig import org.apache.gluten.events.GlutenPlanFallbackEvent import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.{FallbackTags, TRANSFORM_UNSUPPORTED} +import org.apache.gluten.extension.columnar.FallbackTags import org.apache.gluten.utils.LogLevelUtil import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat -import org.apache.spark.sql.execution.GlutenFallbackReporter.FALLBACK_REASON_TAG import org.apache.spark.sql.execution.ui.GlutenEventUtils /** @@ -58,41 +56,14 @@ case class GlutenFallbackReporter(glutenConfig: GlutenConfig, spark: SparkSessio plan.foreachUp { case _: GlutenPlan => // ignore case p: SparkPlan if FallbackTags.nonEmpty(p) => - FallbackTags.getTag(p) match { - case TRANSFORM_UNSUPPORTED(Some(reason), append) => - logFallbackReason(validationLogLevel, p.nodeName, reason) - // With in next round stage in AQE, the physical plan would be a new instance that - // can not preserve the tag, so we need to set the fallback reason to logical plan. - // Then we can be aware of the fallback reason for the whole plan. - // If a logical plan mapping to several physical plan, we add all reason into - // that logical plan to make sure we do not lose any fallback reason. - p.logicalLink.foreach { - logicalPlan => - val newReason = logicalPlan - .getTagValue(FALLBACK_REASON_TAG) - .map { - lastReason => - if (!append) { - lastReason - } else if (lastReason.contains(reason)) { - // use the last reason, as the reason is redundant - lastReason - } else if (reason.contains(lastReason)) { - // overwrite the reason - reason - } else { - // add the new reason - lastReason + "; " + reason - } - } - .getOrElse(reason) - logicalPlan.setTagValue(FALLBACK_REASON_TAG, newReason) - } - case TRANSFORM_UNSUPPORTED(_, _) => - logFallbackReason(validationLogLevel, p.nodeName, "unknown reason") - case _ => - throw new IllegalStateException("Unreachable code") - } + val tag = FallbackTags.get(p) + logFallbackReason(validationLogLevel, p.nodeName, tag.reason()) + // With in next round stage in AQE, the physical plan would be a new instance that + // can not preserve the tag, so we need to set the fallback reason to logical plan. + // Then we can be aware of the fallback reason for the whole plan. + // If a logical plan mapping to several physical plan, we add all reason into + // that logical plan to make sure we do not lose any fallback reason. + p.logicalLink.foreach(logicalPlan => FallbackTags.add(logicalPlan, tag)) case _ => } } @@ -119,7 +90,4 @@ case class GlutenFallbackReporter(glutenConfig: GlutenConfig, spark: SparkSessio } } -object GlutenFallbackReporter { - // A tag used to inject to logical plan to preserve the fallback reason - val FALLBACK_REASON_TAG = new TreeNodeTag[String]("GLUTEN_FALLBACK_REASON") -} +object GlutenFallbackReporter {} diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala index ecedc1bae01c8..6a9da0a9cf92e 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala @@ -62,7 +62,7 @@ case class EvalPythonExecTransformer( // All udfs should be scalar python udf for (udf <- udfs) { if (!PythonUDF.isScalarPythonUDF(udf)) { - return ValidationResult.notOk(s"$udf is not scalar python udf") + return ValidationResult.failed(s"$udf is not scalar python udf") } } diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala index 95793e5dc9354..2a3ba79ebc2a5 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala @@ -202,7 +202,7 @@ object HiveTableScanExecTransformer { hiveTableScan.relation, hiveTableScan.partitionPruningPred)(hiveTableScan.session) hiveTableScanTransformer.doValidate() - case _ => ValidationResult.notOk("Is not a Hive scan") + case _ => ValidationResult.failed("Is not a Hive scan") } }