diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index 6b23c6f39c62..341a3e0f0a52 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -172,20 +172,20 @@ object CHBackendSettings extends BackendSettingsApi with Logging { format match { case ParquetReadFormat => if (validateFilePath) { - ValidationResult.ok + ValidationResult.succeeded } else { - ValidationResult.notOk("Validate file path failed.") + ValidationResult.failed("Validate file path failed.") } - case OrcReadFormat => ValidationResult.ok - case MergeTreeReadFormat => ValidationResult.ok + case OrcReadFormat => ValidationResult.succeeded + case MergeTreeReadFormat => ValidationResult.succeeded case TextReadFormat => if (!hasComplexType) { - ValidationResult.ok + ValidationResult.succeeded } else { - ValidationResult.notOk("Has complex type.") + ValidationResult.failed("Has complex type.") } - case JsonReadFormat => ValidationResult.ok - case _ => ValidationResult.notOk(s"Unsupported file format $format") + case JsonReadFormat => ValidationResult.succeeded + case _ => ValidationResult.failed(s"Unsupported file format $format") } } @@ -291,7 +291,7 @@ object CHBackendSettings extends BackendSettingsApi with Logging { fields: Array[StructField], bucketSpec: Option[BucketSpec], options: Map[String, String]): ValidationResult = - ValidationResult.notOk("CH backend is unsupported.") + ValidationResult.failed("CH backend is unsupported.") override def enableNativeWriteFiles(): Boolean = { GlutenConfig.getConf.enableNativeWriter.getOrElse(false) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala index 35be8ee0b13e..9c0f41361f02 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala @@ -98,12 +98,12 @@ case class CHBroadcastNestedLoopJoinExecTransformer( case _: InnerLike => case _ => if (joinType == LeftSemi || condition.isDefined) { - return ValidationResult.notOk( + return ValidationResult.failed( s"Broadcast Nested Loop join is not supported join type $joinType with conditions") } } - ValidationResult.ok + ValidationResult.succeeded } } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHGenerateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHGenerateExecTransformer.scala index 733c0a472814..44cb0deca523 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHGenerateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHGenerateExecTransformer.scala @@ -64,7 +64,7 @@ case class CHGenerateExecTransformer( override protected def doGeneratorValidate( generator: Generator, outer: Boolean): ValidationResult = - ValidationResult.ok + ValidationResult.succeeded override protected def getRelNode( context: SubstraitContext, diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index da9d9c7586c0..48870892d290 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -60,7 +60,7 @@ case class CHShuffledHashJoinExecTransformer( right.outputSet, 
condition) if (shouldFallback) { - return ValidationResult.notOk("ch join validate fail") + return ValidationResult.failed("ch join validate fail") } super.doValidateInternal() } @@ -118,10 +118,10 @@ case class CHBroadcastHashJoinExecTransformer( condition) if (shouldFallback) { - return ValidationResult.notOk("ch join validate fail") + return ValidationResult.failed("ch join validate fail") } if (isNullAwareAntiJoin) { - return ValidationResult.notOk("ch does not support NAAJ") + return ValidationResult.failed("ch does not support NAAJ") } super.doValidateInternal() } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala index e2b586551739..670fbed69300 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala @@ -50,7 +50,7 @@ case class CHSortMergeJoinExecTransformer( right.outputSet, condition) if (shouldFallback) { - return ValidationResult.notOk("ch join validate fail") + return ValidationResult.failed("ch join validate fail") } super.doValidateInternal() } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala index c7f9b47de642..842dc76153f3 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala @@ -57,12 +57,12 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend !columnarConf.enableColumnarBroadcastExchange || !columnarConf.enableColumnarBroadcastJoin ) { - ValidationResult.notOk( + 
ValidationResult.failed( "columnar broadcast exchange is disabled or " + "columnar broadcast join is disabled") } else { if (FallbackTags.nonEmpty(bhj)) { - ValidationResult.notOk("broadcast join is already tagged as not transformable") + ValidationResult.failed("broadcast join is already tagged as not transformable") } else { val bhjTransformer = BackendsApiManager.getSparkPlanExecApiInstance .genBroadcastHashJoinExecTransformer( @@ -75,7 +75,7 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend bhj.right, bhj.isNullAwareAntiJoin) val isBhjTransformable = bhjTransformer.doValidate() - if (isBhjTransformable.isValid) { + if (isBhjTransformable.ok()) { val exchangeTransformer = ColumnarBroadcastExchangeExec(mode, child) exchangeTransformer.doValidate() } else { @@ -111,12 +111,12 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend !GlutenConfig.getConf.enableColumnarBroadcastExchange || !GlutenConfig.getConf.enableColumnarBroadcastJoin ) { - ValidationResult.notOk( + ValidationResult.failed( "columnar broadcast exchange is disabled or " + "columnar broadcast join is disabled") } else { if (FallbackTags.nonEmpty(bnlj)) { - ValidationResult.notOk("broadcast join is already tagged as not transformable") + ValidationResult.failed("broadcast join is already tagged as not transformable") } else { val transformer = BackendsApiManager.getSparkPlanExecApiInstance .genBroadcastNestedLoopJoinExecTransformer( @@ -126,7 +126,7 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend bnlj.joinType, bnlj.condition) val isTransformable = transformer.doValidate() - if (isTransformable.isValid) { + if (isTransformable.ok()) { val exchangeTransformer = ColumnarBroadcastExchangeExec(mode, child) exchangeTransformer.doValidate() } else { @@ -242,7 +242,7 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl maybeExchange match { case Some(exchange @ 
BroadcastExchangeExec(_, _)) => isTransformable.tagOnFallback(plan) - if (!isTransformable.isValid) { + if (!isTransformable.ok) { FallbackTags.add(exchange, isTransformable) } case None => @@ -280,7 +280,7 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl plan, "it's a materialized broadcast exchange or reused broadcast exchange") case ColumnarBroadcastExchangeExec(mode, child) => - if (!isTransformable.isValid) { + if (!isTransformable.ok) { throw new IllegalStateException( s"BroadcastExchange has already been" + s" transformed to columnar version but BHJ is determined as" + diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarToRowExec.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarToRowExec.scala index 522c7d68fa68..29fa0d0aba96 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarToRowExec.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarToRowExec.scala @@ -58,7 +58,7 @@ case class CHColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBase(c s"${field.dataType} is not supported in ColumnarToRowExecBase.") } } - ValidationResult.ok + ValidationResult.succeeded } override def doExecuteInternal(): RDD[InternalRow] = { diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index 8d98c111af68..31f0324e32eb 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -79,9 +79,9 @@ object VeloxBackendSettings extends BackendSettingsApi { // Collect unsupported types. 
val unsupportedDataTypeReason = fields.collect(validatorFunc) if (unsupportedDataTypeReason.isEmpty) { - ValidationResult.ok + ValidationResult.succeeded } else { - ValidationResult.notOk( + ValidationResult.failed( s"Found unsupported data type in $format: ${unsupportedDataTypeReason.mkString(", ")}.") } } @@ -135,10 +135,10 @@ object VeloxBackendSettings extends BackendSettingsApi { } else { validateTypes(parquetTypeValidatorWithComplexTypeFallback) } - case DwrfReadFormat => ValidationResult.ok + case DwrfReadFormat => ValidationResult.succeeded case OrcReadFormat => if (!GlutenConfig.getConf.veloxOrcScanEnabled) { - ValidationResult.notOk(s"Velox ORC scan is turned off.") + ValidationResult.failed(s"Velox ORC scan is turned off.") } else { val typeValidator: PartialFunction[StructField, String] = { case StructField(_, arrayType: ArrayType, _, _) @@ -164,7 +164,7 @@ object VeloxBackendSettings extends BackendSettingsApi { validateTypes(orcTypeValidatorWithComplexTypeFallback) } } - case _ => ValidationResult.notOk(s"Unsupported file format for $format.") + case _ => ValidationResult.failed(s"Unsupported file format for $format.") } } @@ -284,8 +284,8 @@ object VeloxBackendSettings extends BackendSettingsApi { .orElse(validateDataTypes()) .orElse(validateWriteFilesOptions()) .orElse(validateBucketSpec()) match { - case Some(reason) => ValidationResult.notOk(reason) - case _ => ValidationResult.ok + case Some(reason) => ValidationResult.failed(reason) + case _ => ValidationResult.succeeded } } diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 2b9d0173846a..37b46df3e23d 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -366,7 +366,7 @@ class 
VeloxSparkPlanExecApi extends SparkPlanExecApi { val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ child.output val projectTransformer = ProjectExecTransformer(projectList, child) val validationResult = projectTransformer.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { val newChild = maybeAddAppendBatchesExec(projectTransformer) ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output.drop(1)) } else { @@ -392,7 +392,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { val projectTransformer = ProjectExecTransformer(projectList, child) val projectBeforeSortValidationResult = projectTransformer.doValidate() // Make sure we support offload hash expression - val projectBeforeSort = if (projectBeforeSortValidationResult.isValid) { + val projectBeforeSort = if (projectBeforeSortValidationResult.ok()) { projectTransformer } else { val project = ProjectExec(projectList, child) @@ -405,7 +405,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { val dropSortColumnTransformer = ProjectExecTransformer(projectList.drop(1), sortByHashCode) val validationResult = dropSortColumnTransformer.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { val newChild = maybeAddAppendBatchesExec(dropSortColumnTransformer) ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output) } else { @@ -888,7 +888,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { case p @ LimitTransformer(SortExecTransformer(sortOrder, _, child, _), 0, count) => val global = child.outputPartitioning.satisfies(AllTuples) val topN = TopNTransformer(count, sortOrder, global, child) - if (topN.doValidate().isValid) { + if (topN.doValidate().ok()) { topN } else { p diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala index 8ceea8c14f6a..c7b81d55fa06 100644 --- 
a/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala @@ -72,11 +72,11 @@ case class GenerateExecTransformer( generator: Generator, outer: Boolean): ValidationResult = { if (!supportsGenerate(generator, outer)) { - ValidationResult.notOk( + ValidationResult.failed( s"Velox backend does not support this generator: ${generator.getClass.getSimpleName}" + s", outer: $outer") } else { - ValidationResult.ok + ValidationResult.succeeded } } diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala index 1a54255208ea..2c46893e4576 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala @@ -66,7 +66,7 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBas s"VeloxColumnarToRowExec.") } } - ValidationResult.ok + ValidationResult.succeeded } override def doExecuteInternal(): RDD[InternalRow] = { diff --git a/gluten-core/src/main/java/org/apache/gluten/validate/NativePlanValidationInfo.java b/gluten-core/src/main/java/org/apache/gluten/validate/NativePlanValidationInfo.java index 12f050c660f4..9cfad44d60f5 100644 --- a/gluten-core/src/main/java/org/apache/gluten/validate/NativePlanValidationInfo.java +++ b/gluten-core/src/main/java/org/apache/gluten/validate/NativePlanValidationInfo.java @@ -16,6 +16,8 @@ */ package org.apache.gluten.validate; +import org.apache.gluten.extension.ValidationResult; + import java.util.Vector; public class NativePlanValidationInfo { @@ -30,11 +32,13 @@ public NativePlanValidationInfo(int isSupported, String fallbackInfo) { } } - public boolean isSupported() { - return isSupported == 1; - } - - public Vector getFallbackInfo() { - 
return fallbackInfo; + /** Converts the native outcome to a ValidationResult, one fallback reason per line. */ + public ValidationResult asResult() { + if (isSupported == 1) { + return ValidationResult.succeeded(); + } + /* String.join, not Stream.reduce: formatting its Optional would print "Optional[...]". */ + return ValidationResult.failed( + String.format("Native validation failed: %n%s", String.join("\n", fallbackInfo))); } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index 07ead88601ff..8b4c18b01970 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala @@ -35,12 +35,12 @@ trait BackendSettingsApi { format: ReadFileFormat, fields: Array[StructField], partTable: Boolean, - paths: Seq[String]): ValidationResult = ValidationResult.ok + paths: Seq[String]): ValidationResult = ValidationResult.succeeded def supportWriteFilesExec( format: FileFormat, fields: Array[StructField], bucketSpec: Option[BucketSpec], - options: Map[String, String]): ValidationResult = ValidationResult.ok + options: Map[String, String]): ValidationResult = ValidationResult.succeeded def supportNativeWrite(fields: Array[StructField]): Boolean = true def supportNativeMetadataColumns(): Boolean = false def supportNativeRowIndexColumn(): Boolean = false diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala index 97b4c3a3f807..8e87baf5381d 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala @@ -113,7 +113,7 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP if (remainingCondition == null) { // All the filters can be pushed down and the computing of this Filter 
// is not needed. - return ValidationResult.ok + return ValidationResult.succeeded } val substraitContext = new SubstraitContext val operatorId = substraitContext.nextOperatorId(this.nodeName) diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala index 64071fb14c0c..99f145eeab1c 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala @@ -88,7 +88,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource val validationResult = BackendsApiManager.getSettings .supportFileFormatRead(fileFormat, fields, getPartitionSchema.nonEmpty, getInputFilePaths) - if (!validationResult.isValid) { + if (!validationResult.ok()) { return validationResult } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala index 6bff68895a24..4860847de9ac 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala @@ -133,18 +133,18 @@ abstract class BatchScanExecTransformerBase( override def doValidateInternal(): ValidationResult = { if (pushedAggregate.nonEmpty) { - return ValidationResult.notOk(s"Unsupported aggregation push down for $scan.") + return ValidationResult.failed(s"Unsupported aggregation push down for $scan.") } if ( SparkShimLoader.getSparkShims.findRowIndexColumnIndexInSchema(schema) > 0 && !BackendsApiManager.getSettings.supportNativeRowIndexColumn() ) { - return ValidationResult.notOk("Unsupported row index column scan in native.") + return ValidationResult.failed("Unsupported row index column scan in native.") } if 
(hasUnsupportedColumns) { - return ValidationResult.notOk(s"Unsupported columns scan in native.") + return ValidationResult.failed(s"Unsupported columns scan in native.") } super.doValidateInternal() diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala index b90c1ad8b6e7..ae407b3b3efa 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala @@ -163,34 +163,35 @@ abstract class BroadcastNestedLoopJoinExecTransformer( def validateJoinTypeAndBuildSide(): ValidationResult = { val result = joinType match { - case _: InnerLike | LeftOuter | RightOuter => ValidationResult.ok + case _: InnerLike | LeftOuter | RightOuter => ValidationResult.succeeded case _ => - ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin") + ValidationResult.failed(s"$joinType join is not supported with BroadcastNestedLoopJoin") } - if (!result.isValid) { + if (!result.ok()) { return result } (joinType, buildSide) match { case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) => - ValidationResult.notOk(s"$joinType join is not supported with $buildSide") - case _ => ValidationResult.ok // continue + ValidationResult.failed(s"$joinType join is not supported with $buildSide") + case _ => ValidationResult.succeeded // continue } } override protected def doValidateInternal(): ValidationResult = { if (!GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled) { - return ValidationResult.notOk( + return ValidationResult.failed( s"Config ${GlutenConfig.BROADCAST_NESTED_LOOP_JOIN_TRANSFORMER_ENABLED.key} not enabled") } if (substraitJoinType == CrossRel.JoinType.UNRECOGNIZED) { - return ValidationResult.notOk(s"$joinType join is not 
supported with BroadcastNestedLoopJoin") + return ValidationResult.failed( + s"$joinType join is not supported with BroadcastNestedLoopJoin") } val validateResult = validateJoinTypeAndBuildSide() - if (!validateResult.isValid) { + if (!validateResult.ok()) { return validateResult } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala index 91831f18493a..0dd110fa542f 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala @@ -112,7 +112,7 @@ case class CartesianProductExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportCartesianProductExec()) { - return ValidationResult.notOk("Cartesian product is not supported in this backend") + return ValidationResult.failed("Cartesian product is not supported in this backend") } val substraitContext = new SubstraitContext val expressionNode = condition.map { diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala index 362debb531ee..63f76a25a231 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala @@ -95,10 +95,10 @@ case class ExpandExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportExpandExec()) { - return ValidationResult.notOk("Current backend does not support expand") + return ValidationResult.failed("Current backend does not support expand") } if (projections.isEmpty) { - return ValidationResult.notOk("Current backend does not 
support empty projections in expand") + return ValidationResult.failed("Current backend does not support empty projections in expand") } val substraitContext = new SubstraitContext diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala index 4f120488c2fb..3b8ed1167afc 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala @@ -130,23 +130,23 @@ abstract class FileSourceScanExecTransformerBase( if ( !metadataColumns.isEmpty && !BackendsApiManager.getSettings.supportNativeMetadataColumns() ) { - return ValidationResult.notOk(s"Unsupported metadata columns scan in native.") + return ValidationResult.failed(s"Unsupported metadata columns scan in native.") } if ( SparkShimLoader.getSparkShims.findRowIndexColumnIndexInSchema(schema) > 0 && !BackendsApiManager.getSettings.supportNativeRowIndexColumn() ) { - return ValidationResult.notOk("Unsupported row index column scan in native.") + return ValidationResult.failed("Unsupported row index column scan in native.") } if (hasUnsupportedColumns) { - return ValidationResult.notOk(s"Unsupported columns scan in native.") + return ValidationResult.failed(s"Unsupported columns scan in native.") } if (hasFieldIds) { // Spark read schema expects field Ids , the case didn't support yet by native. 
- return ValidationResult.notOk( + return ValidationResult.failed( s"Unsupported matching schema column names " + s"by field ids in native scan.") } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala index b5c9b85aeb0d..af4a92f194c1 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala @@ -67,7 +67,7 @@ abstract class GenerateExecTransformerBase( override protected def doValidateInternal(): ValidationResult = { val validationResult = doGeneratorValidate(generator, outer) - if (!validationResult.isValid) { + if (!validationResult.ok()) { return validationResult } val context = new SubstraitContext diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala index b200426d91ce..9a28af801d83 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala @@ -117,7 +117,7 @@ abstract class HashAggregateExecBaseTransformer( val unsupportedAggExprs = aggregateAttributes.filterNot(attr => checkType(attr.dataType)) if (unsupportedAggExprs.nonEmpty) { - return ValidationResult.notOk( + return ValidationResult.failed( "Found unsupported data type in aggregation expression: " + unsupportedAggExprs .map(attr => s"${attr.name}#${attr.exprId.id}:${attr.dataType}") @@ -125,7 +125,7 @@ abstract class HashAggregateExecBaseTransformer( } val unsupportedGroupExprs = groupingExpressions.filterNot(attr => checkType(attr.dataType)) if (unsupportedGroupExprs.nonEmpty) { - return ValidationResult.notOk( + return 
ValidationResult.failed( "Found unsupported data type in grouping expression: " + unsupportedGroupExprs .map(attr => s"${attr.name}#${attr.exprId.id}:${attr.dataType}") diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala index cd22c578594c..86e6c1f41265 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala @@ -210,7 +210,7 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport { // Firstly, need to check if the Substrait plan for this operator can be successfully generated. if (substraitJoinType == JoinRel.JoinType.UNRECOGNIZED) { return ValidationResult - .notOk(s"Unsupported join type of $hashJoinType for substrait: $substraitJoinType") + .failed(s"Unsupported join type of $hashJoinType for substrait: $substraitJoinType") } val relNode = JoinUtils.createJoinRel( streamedKeyExprs, diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala index 6f9ef34282bf..bed59b913a1e 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala @@ -99,7 +99,7 @@ case class SampleExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (withReplacement) { - return ValidationResult.notOk( + return ValidationResult.failed( "Unsupported sample exec in native with " + s"withReplacement parameter is $withReplacement") } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/ScanTransformerFactory.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/ScanTransformerFactory.scala index 44a823834f92..a05a5e72bfe1 100644 --- 
a/gluten-core/src/main/scala/org/apache/gluten/execution/ScanTransformerFactory.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/ScanTransformerFactory.scala @@ -95,11 +95,11 @@ object ScanTransformerFactory { transformer.setPushDownFilters(allPushDownFilters.get) // Validate again if allPushDownFilters is defined. val validationResult = transformer.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { transformer } else { val newSource = batchScan.copy(runtimeFilters = transformer.runtimeFilters) - FallbackTags.add(newSource, validationResult.reason.get) + FallbackTags.add(newSource, validationResult.reason()) newSource } } else { diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala index f79dc69e680b..b69925d60fd2 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala @@ -91,7 +91,7 @@ case class SortExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportSortExec()) { - return ValidationResult.notOk("Current backend does not support sort") + return ValidationResult.failed("Current backend does not support sort") } val substraitContext = new SubstraitContext val operatorId = substraitContext.nextOperatorId(this.nodeName) diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala index f032c4ca0087..c96789569f9a 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala @@ -164,7 +164,7 @@ abstract class SortMergeJoinExecTransformerBase( 
// Firstly, need to check if the Substrait plan for this operator can be successfully generated. if (substraitJoinType == JoinRel.JoinType.UNRECOGNIZED) { return ValidationResult - .notOk(s"Found unsupported join type of $joinType for substrait: $substraitJoinType") + .failed(s"Found unsupported join type of $joinType for substrait: $substraitJoinType") } val relNode = JoinUtils.createJoinRel( streamedKeys, @@ -253,7 +253,7 @@ case class SortMergeJoinExecTransformer( // Firstly, need to check if the Substrait plan for this operator can be successfully generated. if (substraitJoinType == JoinRel.JoinType.JOIN_TYPE_OUTER) { return ValidationResult - .notOk(s"Found unsupported join type of $joinType for velox smj: $substraitJoinType") + .failed(s"Found unsupported join type of $joinType for velox smj: $substraitJoinType") } super.doValidateInternal() } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala index 74158d6332dc..b31471e21397 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala @@ -67,7 +67,7 @@ case class TakeOrderedAndProjectExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (offset != 0) { - return ValidationResult.notOk(s"Native TopK does not support offset: $offset") + return ValidationResult.failed(s"Native TopK does not support offset: $offset") } var tagged: ValidationResult = null @@ -83,14 +83,14 @@ case class TakeOrderedAndProjectExecTransformer( ColumnarCollapseTransformStages.wrapInputIteratorTransformer(child) val sortPlan = SortExecTransformer(sortOrder, false, inputTransformer) val sortValidation = sortPlan.doValidate() - if (!sortValidation.isValid) { + if (!sortValidation.ok()) { return 
sortValidation } val limitPlan = LimitTransformer(sortPlan, offset, limit) tagged = limitPlan.doValidate() } - if (tagged.isValid) { + if (tagged.ok()) { val projectPlan = ProjectExecTransformer(projectList, child) tagged = projectPlan.doValidate() } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala index 628c08f290eb..4902b6c6cf1b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala @@ -165,7 +165,7 @@ case class WindowExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportWindowExec(windowExpression)) { return ValidationResult - .notOk(s"Found unsupported window expression: ${windowExpression.mkString(", ")}") + .failed(s"Found unsupported window expression: ${windowExpression.mkString(", ")}") } val substraitContext = new SubstraitContext val operatorId = substraitContext.nextOperatorId(this.nodeName) diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala index c93d01e7a12e..6068412fbad3 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala @@ -145,7 +145,7 @@ case class WindowGroupLimitExecTransformer( override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportWindowGroupLimitExec(rankLikeFunction)) { return ValidationResult - .notOk(s"Found unsupported rank like function: $rankLikeFunction") + .failed(s"Found unsupported rank like function: $rankLikeFunction") } val substraitContext = new 
SubstraitContext val operatorId = substraitContext.nextOperatorId(this.nodeName) diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala index 14d58bfa8377..d78f21beaabf 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala @@ -150,8 +150,8 @@ case class WriteFilesExecTransformer( finalChildOutput.toStructType.fields, bucketSpec, caseInsensitiveOptions) - if (!validationResult.isValid) { - return ValidationResult.notOk("Unsupported native write: " + validationResult.reason.get) + if (!validationResult.ok()) { + return ValidationResult.failed("Unsupported native write: " + validationResult.reason()) } val substraitContext = new SubstraitContext diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala index 71a76ff63dd1..0c70e1ea7a7b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala @@ -26,28 +26,29 @@ import org.apache.gluten.substrait.plan.PlanBuilder import org.apache.gluten.substrait.rel.RelNode import org.apache.gluten.test.TestStats import org.apache.gluten.utils.LogLevelUtil -import org.apache.gluten.validate.NativePlanValidationInfo import org.apache.spark.sql.execution.SparkPlan import com.google.common.collect.Lists -import scala.collection.JavaConverters._ - -case class ValidationResult(isValid: Boolean, reason: Option[String]) +sealed trait ValidationResult { + def ok(): Boolean + def reason(): String +} object ValidationResult { - def ok: ValidationResult = ValidationResult(isValid = true, None) - def notOk(reason: String): ValidationResult = ValidationResult(isValid = false, 
Option(reason)) - def convertFromValidationInfo(info: NativePlanValidationInfo): ValidationResult = { - if (info.isSupported) { - ok - } else { - val fallbackInfo = info.getFallbackInfo.asScala - .mkString("Native validation failed:\n ", "\n ", "") - notOk(fallbackInfo) - } + private case object Succeeded extends ValidationResult { + override def ok(): Boolean = true + override def reason(): String = throw new UnsupportedOperationException( + "Succeeded validation doesn't have failure details") } + + private case class Failed(override val reason: String) extends ValidationResult { + override def ok(): Boolean = false + } + + def succeeded: ValidationResult = Succeeded + def failed(reason: String): ValidationResult = Failed(reason) } /** Every Gluten Operator should extend this trait. */ @@ -66,17 +67,18 @@ trait GlutenPlan extends SparkPlan with Convention.KnownBatchType with LogLevelU val schemaVaidationResult = BackendsApiManager.getValidatorApiInstance .doSchemaValidate(schema) .map { - reason => ValidationResult.notOk(s"Found schema check failure for $schema, due to: $reason") + reason => + ValidationResult.failed(s"Found schema check failure for $schema, due to: $reason") } - .getOrElse(ValidationResult.ok) - if (!schemaVaidationResult.isValid) { + .getOrElse(ValidationResult.succeeded) + if (!schemaVaidationResult.ok()) { TestStats.addFallBackClassName(this.getClass.toString) return schemaVaidationResult } try { TransformerState.enterValidation val res = doValidateInternal() - if (!res.isValid) { + if (!res.ok()) { TestStats.addFallBackClassName(this.getClass.toString) } res @@ -90,7 +92,7 @@ trait GlutenPlan extends SparkPlan with Convention.KnownBatchType with LogLevelU logValidationMessage( s"Validation failed with exception for plan: $nodeName, due to: ${e.getMessage}", e) - ValidationResult.notOk(e.getMessage) + ValidationResult.failed(e.getMessage) } finally { TransformerState.finishValidation } @@ -109,16 +111,16 @@ trait GlutenPlan extends SparkPlan 
with Convention.KnownBatchType with LogLevelU BackendsApiManager.getSparkPlanExecApiInstance.batchType } - protected def doValidateInternal(): ValidationResult = ValidationResult.ok + protected def doValidateInternal(): ValidationResult = ValidationResult.succeeded protected def doNativeValidation(context: SubstraitContext, node: RelNode): ValidationResult = { if (node != null && enableNativeValidation) { val planNode = PlanBuilder.makePlan(context, Lists.newArrayList(node)) val info = BackendsApiManager.getValidatorApiInstance .doNativeValidateWithFailureReason(planNode) - ValidationResult.convertFromValidationInfo(info) + info.asResult() } else { - ValidationResult.ok + ValidationResult.succeeded } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala index 25674fd17147..bfb926706cf7 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala @@ -40,11 +40,11 @@ object CollapseProjectExecTransformer extends Rule[SparkPlan] { val collapsedProject = p2.copy(projectList = CollapseProjectShim.buildCleanedProjectList(p1.projectList, p2.projectList)) val validationResult = collapsedProject.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { logDebug(s"Collapse project $p1 and $p2.") collapsedProject } else { - logDebug(s"Failed to collapse project, due to ${validationResult.reason.getOrElse("")}") + logDebug(s"Failed to collapse project, due to ${validationResult.reason()}") p1 } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala index e334fcfbce88..491b54443d67 100644 --- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala @@ -243,7 +243,7 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP originalPlan .find(_.logicalLink.exists(_.fastEquals(p.logicalLink.get))) .filterNot(FallbackTags.nonEmpty) - .foreach(origin => FallbackTags.tag(origin, FallbackTags.getTag(p))) + .foreach(origin => FallbackTags.add(origin, FallbackTags.get(p))) case _ => } @@ -280,7 +280,7 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP } else { FallbackTags.addRecursively( vanillaSparkPlan, - TRANSFORM_UNSUPPORTED(fallbackInfo.reason, appendReasonIfExists = false)) + FallbackTag.Exclusive(fallbackInfo.reason.getOrElse("Unknown reason"))) FallbackNode(vanillaSparkPlan) } } else { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackTagRule.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala similarity index 89% rename from gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackTagRule.scala rename to gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala index ddc6870e6c71..f9eaa4179c67 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackTagRule.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, QueryStageExec} import 
org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} @@ -50,10 +50,42 @@ sealed trait FallbackTag { if (FallbackTags.DEBUG) { Some(ExceptionUtils.getStackTrace(new Throwable())) } else None + + def reason(): String } -case class TRANSFORM_UNSUPPORTED(reason: Option[String], appendReasonIfExists: Boolean = true) - extends FallbackTag +object FallbackTag { + + /** A tag that stores one reason text of fall back. */ + case class Appendable(override val reason: String) extends FallbackTag + + /** + * A tag that stores reason text of fall back. Other reasons will be discarded when this tag is + * added to plan. + */ + case class Exclusive(override val reason: String) extends FallbackTag + + trait Converter[T] { + def from(obj: T): Option[FallbackTag] + } + + object Converter { + implicit def asIs[T <: FallbackTag]: Converter[T] = (tag: T) => Some(tag) + + implicit object FromString extends Converter[String] { + override def from(reason: String): Option[FallbackTag] = Some(Appendable(reason)) + } + + implicit object FromValidationResult extends Converter[ValidationResult] { + override def from(result: ValidationResult): Option[FallbackTag] = { + if (result.ok()) { + return None + } + Some(Appendable(result.reason())) + } + } + } +} object FallbackTags { val TAG: TreeNodeTag[FallbackTag] = @@ -70,10 +102,7 @@ object FallbackTags { * rules are passed. 
*/ def nonEmpty(plan: SparkPlan): Boolean = { - getTagOption(plan) match { - case Some(TRANSFORM_UNSUPPORTED(_, _)) => true - case _ => false - } + getOption(plan).nonEmpty } /** @@ -84,72 +113,51 @@ object FallbackTags { */ def maybeOffloadable(plan: SparkPlan): Boolean = !nonEmpty(plan) - def tag(plan: SparkPlan, hint: FallbackTag): Unit = { - val mergedHint = getTagOption(plan) - .map { - case originalHint @ TRANSFORM_UNSUPPORTED(Some(originalReason), originAppend) => - hint match { - case TRANSFORM_UNSUPPORTED(Some(newReason), append) => - if (originAppend && append) { - TRANSFORM_UNSUPPORTED(Some(originalReason + "; " + newReason)) - } else if (originAppend) { - TRANSFORM_UNSUPPORTED(Some(originalReason)) - } else if (append) { - TRANSFORM_UNSUPPORTED(Some(newReason)) - } else { - TRANSFORM_UNSUPPORTED(Some(originalReason), false) - } - case TRANSFORM_UNSUPPORTED(None, _) => - originalHint - case _ => - throw new GlutenNotSupportException( - "Plan was already tagged as non-transformable, " + - s"cannot mark it as transformable after that:\n${plan.toString()}") - } - case _ => - hint + def add[T](plan: TreeNode[_], t: T)(implicit converter: FallbackTag.Converter[T]): Unit = { + val tagOption = getOption(plan) + val newTagOption = converter.from(t) + + val mergedTagOption: Option[FallbackTag] = + (tagOption ++ newTagOption).reduceOption[FallbackTag] { + // New tag comes while the plan was already tagged, merge. 
+ case (_, exclusive: FallbackTag.Exclusive) => + exclusive + case (exclusive: FallbackTag.Exclusive, _) => + exclusive + case (l: FallbackTag.Appendable, r: FallbackTag.Appendable) => + FallbackTag.Appendable(s"${l.reason}; ${r.reason}") } - .getOrElse(hint) - plan.setTagValue(TAG, mergedHint) - } - - def untag(plan: SparkPlan): Unit = { - plan.unsetTagValue(TAG) + mergedTagOption + .foreach(mergedTag => plan.setTagValue(TAG, mergedTag)) } - def add(plan: SparkPlan, validationResult: ValidationResult): Unit = { - if (!validationResult.isValid) { - tag(plan, TRANSFORM_UNSUPPORTED(validationResult.reason)) - } - } - - def add(plan: SparkPlan, reason: String): Unit = { - tag(plan, TRANSFORM_UNSUPPORTED(Some(reason))) - } - - def addRecursively(plan: SparkPlan, hint: TRANSFORM_UNSUPPORTED): Unit = { + def addRecursively[T](plan: TreeNode[_], t: T)(implicit + converter: FallbackTag.Converter[T]): Unit = { plan.foreach { case _: GlutenPlan => // ignore - case other => tag(other, hint) + case other: TreeNode[_] => add(other, t) } } - def getTag(plan: SparkPlan): FallbackTag = { - getTagOption(plan).getOrElse( + def untag(plan: TreeNode[_]): Unit = { + plan.unsetTagValue(TAG) + } + + def get(plan: TreeNode[_]): FallbackTag = { + getOption(plan).getOrElse( throw new IllegalStateException("Transform hint tag not set in plan: " + plan.toString())) } - def getTagOption(plan: SparkPlan): Option[FallbackTag] = { + def getOption(plan: TreeNode[_]): Option[FallbackTag] = { plan.getTagValue(TAG) } - implicit class EncodeFallbackTagImplicits(validationResult: ValidationResult) { - def tagOnFallback(plan: SparkPlan): Unit = { - if (validationResult.isValid) { + implicit class EncodeFallbackTagImplicits(result: ValidationResult) { + def tagOnFallback(plan: TreeNode[_]): Unit = { + if (result.ok()) { return } - val newTag = TRANSFORM_UNSUPPORTED(validationResult.reason) - tag(plan, newTag) + add(plan, result) } } } diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala index 742c353410d6..26c70293c7b3 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala @@ -379,7 +379,7 @@ case class OffloadFilter() extends OffloadSingleNode with LogLevelUtil { val newScan = FilterHandler.pushFilterToScan(filter.condition, scan) newScan match { - case ts: TransformSupport if ts.doValidate().isValid => ts + case ts: TransformSupport if ts.doValidate().ok() => ts case _ => scan } } else scan @@ -550,12 +550,12 @@ object OffloadOthers { case plan: FileSourceScanExec => val transformer = ScanTransformerFactory.createFileSourceScanTransformer(plan) val validationResult = transformer.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") transformer } else { logDebug(s"Columnar Processing for ${plan.getClass} is currently unsupported.") - FallbackTags.add(plan, validationResult.reason.get) + FallbackTags.add(plan, validationResult.reason()) plan } case plan: BatchScanExec => @@ -565,12 +565,12 @@ object OffloadOthers { val hiveTableScanExecTransformer = BackendsApiManager.getSparkPlanExecApiInstance.genHiveTableScanExecTransformer(plan) val validateResult = hiveTableScanExecTransformer.doValidate() - if (validateResult.isValid) { + if (validateResult.ok()) { logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") return hiveTableScanExecTransformer } logDebug(s"Columnar Processing for ${plan.getClass} is currently unsupported.") - FallbackTags.add(plan, validateResult.reason.get) + FallbackTags.add(plan, validateResult.reason()) plan case other => throw new GlutenNotSupportException(s"${other.getClass.toString} is 
not supported.") diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RemoveNativeWriteFilesSortAndProject.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RemoveNativeWriteFilesSortAndProject.scala index d32de32ebb32..ac35ac83bb21 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RemoveNativeWriteFilesSortAndProject.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RemoveNativeWriteFilesSortAndProject.scala @@ -71,7 +71,7 @@ object NativeWriteFilesWithSkippingSortAndProject extends Logging { } val transformer = ProjectExecTransformer(newProjectList, p.child) val validationResult = transformer.doValidate() - if (validationResult.isValid) { + if (validationResult.ok()) { Some(transformer) } else { // If we can not transform the project, then we fallback to origin plan which means diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala index 611d6db0bd48..4070a0a58612 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala @@ -37,7 +37,7 @@ class PushFilterToScan(validator: Validator) extends RasRule[SparkPlan] { val newScan = FilterHandler.pushFilterToScan(filter.condition, scan) newScan match { - case ts: TransformSupport if ts.doValidate().isValid => + case ts: TransformSupport if ts.doValidate().ok() => List(filter.withNewChildren(List(ts))) case _ => List.empty diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala index 8091127da0bf..43f52a9e4758 100644 --- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala @@ -106,7 +106,7 @@ object RasOffload { case Validator.Passed => val offloaded = base.offload(from) val offloadedNodes = offloaded.collect[GlutenPlan] { case t: GlutenPlan => t } - if (offloadedNodes.exists(!_.doValidate().isValid)) { + if (offloadedNodes.exists(!_.doValidate().ok())) { // 4. If native validation fails on the offloaded node, return the // original one. from diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala index 2abd4d7d4807..e005a3dc8163 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala @@ -74,7 +74,7 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode] case p if !p.isInstanceOf[ProjectExec] && !p.isInstanceOf[RewrittenNodeWall] => p } assert(target.size == 1) - FallbackTags.getTagOption(target.head) + FallbackTags.getOption(target.head) } private def applyRewriteRules(origin: SparkPlan): (SparkPlan, Option[String]) = { @@ -112,10 +112,10 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode] origin } else { addHint.apply(rewrittenPlan) - val hint = getFallbackTagBack(rewrittenPlan) - if (hint.isDefined) { + val tag = getFallbackTagBack(rewrittenPlan) + if (tag.isDefined) { // If the rewritten plan is still not transformable, return the original plan. 
- FallbackTags.tag(origin, hint.get) + FallbackTags.add(origin, tag.get) origin } else { rewrittenPlan.transformUp { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala index b6236ae9a536..903723ccb56b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala @@ -19,7 +19,7 @@ package org.apache.gluten.extension.columnar.validator import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.{BackendsApiManager, BackendSettingsApi} import org.apache.gluten.expression.ExpressionUtils -import org.apache.gluten.extension.columnar.{FallbackTags, TRANSFORM_UNSUPPORTED} +import org.apache.gluten.extension.columnar.FallbackTags import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.execution._ @@ -109,8 +109,8 @@ object Validators { private object FallbackByHint extends Validator { override def validate(plan: SparkPlan): Validator.OutCome = { if (FallbackTags.nonEmpty(plan)) { - val hint = FallbackTags.getTag(plan).asInstanceOf[TRANSFORM_UNSUPPORTED] - return fail(hint.reason.getOrElse("Reason not recorded")) + val tag = FallbackTags.get(plan) + return fail(tag.reason()) } pass() } diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index 85a4dd3878a3..31175a43fbae 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution import org.apache.gluten.GlutenConfig import 
org.apache.gluten.backendsapi.BackendsApiManager -import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.ValidationResult +import org.apache.gluten.extension.{GlutenPlan, ValidationResult} import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark._ @@ -119,10 +118,10 @@ case class ColumnarShuffleExchangeExec( .doColumnarShuffleExchangeExecValidate(outputPartitioning, child) .map { reason => - ValidationResult.notOk( + ValidationResult.failed( s"Found schema check failure for schema ${child.schema} due to: $reason") } - .getOrElse(ValidationResult.ok) + .getOrElse(ValidationResult.succeeded) } override def nodeName: String = "ColumnarExchange" diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala index 781dc6b6f717..338136b6d870 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.execution import org.apache.gluten.execution.WholeStageTransformer import org.apache.gluten.extension.GlutenPlan +import org.apache.gluten.extension.columnar.FallbackTags import org.apache.gluten.utils.PlanUtil import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.execution.GlutenFallbackReporter.FALLBACK_REASON_TAG import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} @@ -59,8 +59,8 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { p: SparkPlan, 
fallbackNodeToReason: mutable.HashMap[String, String] ): Unit = { - p.logicalLink.flatMap(_.getTagValue(FALLBACK_REASON_TAG)) match { - case Some(reason) => addFallbackNodeWithReason(p, reason, fallbackNodeToReason) + p.logicalLink.flatMap(FallbackTags.getOption) match { + case Some(tag) => addFallbackNodeWithReason(p, tag.reason(), fallbackNodeToReason) case _ => // If the SparkPlan does not have fallback reason, then there are two options: // 1. Gluten ignore that plan and it's a kind of fallback diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala index d41dce882602..67ecf81b9ffd 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala @@ -19,14 +19,12 @@ package org.apache.spark.sql.execution import org.apache.gluten.GlutenConfig import org.apache.gluten.events.GlutenPlanFallbackEvent import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.{FallbackTags, TRANSFORM_UNSUPPORTED} +import org.apache.gluten.extension.columnar.FallbackTags import org.apache.gluten.utils.LogLevelUtil import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat -import org.apache.spark.sql.execution.GlutenFallbackReporter.FALLBACK_REASON_TAG import org.apache.spark.sql.execution.ui.GlutenEventUtils /** @@ -58,41 +56,14 @@ case class GlutenFallbackReporter(glutenConfig: GlutenConfig, spark: SparkSessio plan.foreachUp { case _: GlutenPlan => // ignore case p: SparkPlan if FallbackTags.nonEmpty(p) => - FallbackTags.getTag(p) match { - case TRANSFORM_UNSUPPORTED(Some(reason), append) => - logFallbackReason(validationLogLevel, 
p.nodeName, reason) - // With in next round stage in AQE, the physical plan would be a new instance that - // can not preserve the tag, so we need to set the fallback reason to logical plan. - // Then we can be aware of the fallback reason for the whole plan. - // If a logical plan mapping to several physical plan, we add all reason into - // that logical plan to make sure we do not lose any fallback reason. - p.logicalLink.foreach { - logicalPlan => - val newReason = logicalPlan - .getTagValue(FALLBACK_REASON_TAG) - .map { - lastReason => - if (!append) { - lastReason - } else if (lastReason.contains(reason)) { - // use the last reason, as the reason is redundant - lastReason - } else if (reason.contains(lastReason)) { - // overwrite the reason - reason - } else { - // add the new reason - lastReason + "; " + reason - } - } - .getOrElse(reason) - logicalPlan.setTagValue(FALLBACK_REASON_TAG, newReason) - } - case TRANSFORM_UNSUPPORTED(_, _) => - logFallbackReason(validationLogLevel, p.nodeName, "unknown reason") - case _ => - throw new IllegalStateException("Unreachable code") - } + val tag = FallbackTags.get(p) + logFallbackReason(validationLogLevel, p.nodeName, tag.reason()) + // With in next round stage in AQE, the physical plan would be a new instance that + // can not preserve the tag, so we need to set the fallback reason to logical plan. + // Then we can be aware of the fallback reason for the whole plan. + // If a logical plan mapping to several physical plan, we add all reason into + // that logical plan to make sure we do not lose any fallback reason. 
+ p.logicalLink.foreach(logicalPlan => FallbackTags.add(logicalPlan, tag)) case _ => } } @@ -119,7 +90,4 @@ case class GlutenFallbackReporter(glutenConfig: GlutenConfig, spark: SparkSessio } } -object GlutenFallbackReporter { - // A tag used to inject to logical plan to preserve the fallback reason - val FALLBACK_REASON_TAG = new TreeNodeTag[String]("GLUTEN_FALLBACK_REASON") -} +object GlutenFallbackReporter {} diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala index ecedc1bae01c..6a9da0a9cf92 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala @@ -62,7 +62,7 @@ case class EvalPythonExecTransformer( // All udfs should be scalar python udf for (udf <- udfs) { if (!PythonUDF.isScalarPythonUDF(udf)) { - return ValidationResult.notOk(s"$udf is not scalar python udf") + return ValidationResult.failed(s"$udf is not scalar python udf") } } diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala index 95793e5dc935..2a3ba79ebc2a 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala @@ -202,7 +202,7 @@ object HiveTableScanExecTransformer { hiveTableScan.relation, hiveTableScan.partitionPruningPred)(hiveTableScan.session) hiveTableScanTransformer.doValidate() - case _ => ValidationResult.notOk("Is not a Hive scan") + case _ => ValidationResult.failed("Is not a Hive scan") } } diff --git a/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala 
b/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala index 1cd735cf7ee7..31e6c6940cd9 100644 --- a/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala +++ b/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala @@ -57,7 +57,7 @@ case class DeltaScanTransformer( _.name == "__delta_internal_is_row_deleted") || requiredSchema.fields.exists( _.name == "__delta_internal_row_index") ) { - return ValidationResult.notOk(s"Deletion vector is not supported in native.") + return ValidationResult.failed(s"Deletion vector is not supported in native.") } super.doValidateInternal() diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala index fe37da206a56..2ee1573ea07a 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala @@ -35,7 +35,7 @@ case class CustomerColumnarPreRules(session: SparkSession) extends Rule[SparkPla fileSourceScan.tableIdentifier, fileSourceScan.disableBucketedScan ) - if (transformer.doValidate().isValid) { + if (transformer.doValidate().ok()) { transformer } else { plan diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index b9c9d8a270bf..54d7596b602c 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.gluten.backendsapi.BackendsApiManager import 
org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags, TRANSFORM_UNSUPPORTED} +import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags} import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.InsertTransitions import org.apache.gluten.utils.QueryPlanSelector @@ -124,17 +124,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Tag not transformable more than once") { val originalPlan = UnaryOp1(LeafOp(supportsColumnar = true)) - FallbackTags.tag(originalPlan, TRANSFORM_UNSUPPORTED(Some("fake reason"))) + FallbackTags.add(originalPlan, "fake reason") val rule = FallbackEmptySchemaRelation() val newPlan = rule.apply(originalPlan) - val reason = FallbackTags.getTag(newPlan).asInstanceOf[TRANSFORM_UNSUPPORTED].reason - assert(reason.isDefined) + val reason = FallbackTags.get(newPlan).reason() if (BackendsApiManager.getSettings.fallbackOnEmptySchema(newPlan)) { assert( - reason.get.contains("fake reason") && - reason.get.contains("at least one of its children has empty output")) + reason.contains("fake reason") && + reason.contains("at least one of its children has empty output")) } else { - assert(reason.get.contains("fake reason")) + assert(reason.contains("fake reason")) } } diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala index fe37da206a56..2ee1573ea07a 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala @@ -35,7 +35,7 @@ case class CustomerColumnarPreRules(session: SparkSession) extends Rule[SparkPla 
fileSourceScan.tableIdentifier, fileSourceScan.disableBucketedScan ) - if (transformer.doValidate().isValid) { + if (transformer.doValidate().ok()) { transformer } else { plan diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index 8ce0af8df051..5150a4768851 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags, TRANSFORM_UNSUPPORTED} +import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags} import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.InsertTransitions import org.apache.gluten.utils.QueryPlanSelector @@ -125,17 +125,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Tag not transformable more than once") { val originalPlan = UnaryOp1(LeafOp(supportsColumnar = true)) - FallbackTags.tag(originalPlan, TRANSFORM_UNSUPPORTED(Some("fake reason"))) + FallbackTags.add(originalPlan, "fake reason") val rule = FallbackEmptySchemaRelation() val newPlan = rule.apply(originalPlan) - val reason = FallbackTags.getTag(newPlan).asInstanceOf[TRANSFORM_UNSUPPORTED].reason - assert(reason.isDefined) + val reason = FallbackTags.get(newPlan).reason() if (BackendsApiManager.getSettings.fallbackOnEmptySchema(newPlan)) { assert( - reason.get.contains("fake reason") && - reason.get.contains("at least one of its children has empty output")) + 
reason.contains("fake reason") && + reason.contains("at least one of its children has empty output")) } else { - assert(reason.get.contains("fake reason")) + assert(reason.contains("fake reason")) } } diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala index fe37da206a56..2ee1573ea07a 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala @@ -35,7 +35,7 @@ case class CustomerColumnarPreRules(session: SparkSession) extends Rule[SparkPla fileSourceScan.tableIdentifier, fileSourceScan.disableBucketedScan ) - if (transformer.doValidate().isValid) { + if (transformer.doValidate().ok()) { transformer } else { plan diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index 8ce0af8df051..5150a4768851 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags, TRANSFORM_UNSUPPORTED} +import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags} import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.InsertTransitions import org.apache.gluten.utils.QueryPlanSelector @@ -125,17 +125,16 
@@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Tag not transformable more than once") { val originalPlan = UnaryOp1(LeafOp(supportsColumnar = true)) - FallbackTags.tag(originalPlan, TRANSFORM_UNSUPPORTED(Some("fake reason"))) + FallbackTags.add(originalPlan, "fake reason") val rule = FallbackEmptySchemaRelation() val newPlan = rule.apply(originalPlan) - val reason = FallbackTags.getTag(newPlan).asInstanceOf[TRANSFORM_UNSUPPORTED].reason - assert(reason.isDefined) + val reason = FallbackTags.get(newPlan).reason() if (BackendsApiManager.getSettings.fallbackOnEmptySchema(newPlan)) { assert( - reason.get.contains("fake reason") && - reason.get.contains("at least one of its children has empty output")) + reason.contains("fake reason") && + reason.contains("at least one of its children has empty output")) } else { - assert(reason.get.contains("fake reason")) + assert(reason.contains("fake reason")) } } diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala index fe37da206a56..2ee1573ea07a 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala @@ -35,7 +35,7 @@ case class CustomerColumnarPreRules(session: SparkSession) extends Rule[SparkPla fileSourceScan.tableIdentifier, fileSourceScan.disableBucketedScan ) - if (transformer.doValidate().isValid) { + if (transformer.doValidate().ok()) { transformer } else { plan