diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index 341a3e0f0a52c..6b23c6f39c62b 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -172,20 +172,20 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
     format match {
       case ParquetReadFormat =>
         if (validateFilePath) {
-          ValidationResult.succeeded
+          ValidationResult.ok
         } else {
-          ValidationResult.failed("Validate file path failed.")
+          ValidationResult.notOk("Validate file path failed.")
         }
-      case OrcReadFormat => ValidationResult.succeeded
-      case MergeTreeReadFormat => ValidationResult.succeeded
+      case OrcReadFormat => ValidationResult.ok
+      case MergeTreeReadFormat => ValidationResult.ok
       case TextReadFormat =>
         if (!hasComplexType) {
-          ValidationResult.succeeded
+          ValidationResult.ok
        } else {
-          ValidationResult.failed("Has complex type.")
+          ValidationResult.notOk("Has complex type.")
        }
-      case JsonReadFormat => ValidationResult.succeeded
-      case _ => ValidationResult.failed(s"Unsupported file format $format")
+      case JsonReadFormat => ValidationResult.ok
+      case _ => ValidationResult.notOk(s"Unsupported file format $format")
     }
   }
 
@@ -291,7 +291,7 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
       fields: Array[StructField],
       bucketSpec: Option[BucketSpec],
       options: Map[String, String]): ValidationResult =
-    ValidationResult.failed("CH backend is unsupported.")
+    ValidationResult.notOk("CH backend is unsupported.")
 
   override def enableNativeWriteFiles(): Boolean = {
     GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala
index 9c0f41361f020..35be8ee0b13ea 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala
@@ -98,12 +98,12 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
       case _: InnerLike =>
       case _ =>
         if (joinType == LeftSemi || condition.isDefined) {
-          return ValidationResult.failed(
+          return ValidationResult.notOk(
             s"Broadcast Nested Loop join is not supported join type $joinType with conditions")
         }
     }
 
-    ValidationResult.succeeded
+    ValidationResult.ok
   }
 }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHGenerateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHGenerateExecTransformer.scala
index 44cb0deca5230..733c0a472814d 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHGenerateExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHGenerateExecTransformer.scala
@@ -64,7 +64,7 @@ case class CHGenerateExecTransformer(
   override protected def doGeneratorValidate(
       generator: Generator,
       outer: Boolean): ValidationResult =
-    ValidationResult.succeeded
+    ValidationResult.ok
 
   override protected def getRelNode(
       context: SubstraitContext,
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index 48870892d290e..da9d9c7586c05 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -60,7 +60,7 @@ case class CHShuffledHashJoinExecTransformer(
         right.outputSet,
         condition)
     if (shouldFallback) {
-      return ValidationResult.failed("ch join validate fail")
+      return ValidationResult.notOk("ch join validate fail")
     }
     super.doValidateInternal()
   }
@@ -118,10 +118,10 @@ case class CHBroadcastHashJoinExecTransformer(
         condition)
 
     if (shouldFallback) {
-      return ValidationResult.failed("ch join validate fail")
+      return ValidationResult.notOk("ch join validate fail")
     }
     if (isNullAwareAntiJoin) {
-      return ValidationResult.failed("ch does not support NAAJ")
+      return ValidationResult.notOk("ch does not support NAAJ")
     }
     super.doValidateInternal()
   }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
index 670fbed693003..e2b5865517391 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
@@ -50,7 +50,7 @@ case class CHSortMergeJoinExecTransformer(
         right.outputSet,
         condition)
     if (shouldFallback) {
-      return ValidationResult.failed("ch join validate fail")
+      return ValidationResult.notOk("ch join validate fail")
     }
     super.doValidateInternal()
   }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala
index 842dc76153f30..c7f9b47de642d 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala
@@ -57,12 +57,12 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
             !columnarConf.enableColumnarBroadcastExchange ||
             !columnarConf.enableColumnarBroadcastJoin
           ) {
-            ValidationResult.failed(
+            ValidationResult.notOk(
               "columnar broadcast exchange is disabled or " +
                 "columnar broadcast join is disabled")
           } else {
             if (FallbackTags.nonEmpty(bhj)) {
-              ValidationResult.failed("broadcast join is already tagged as not transformable")
+              ValidationResult.notOk("broadcast join is already tagged as not transformable")
             } else {
               val bhjTransformer = BackendsApiManager.getSparkPlanExecApiInstance
                 .genBroadcastHashJoinExecTransformer(
@@ -75,7 +75,7 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
                   bhj.right,
                   bhj.isNullAwareAntiJoin)
               val isBhjTransformable = bhjTransformer.doValidate()
-              if (isBhjTransformable.ok()) {
+              if (isBhjTransformable.isValid) {
                 val exchangeTransformer = ColumnarBroadcastExchangeExec(mode, child)
                 exchangeTransformer.doValidate()
               } else {
@@ -111,12 +111,12 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
             !GlutenConfig.getConf.enableColumnarBroadcastExchange ||
             !GlutenConfig.getConf.enableColumnarBroadcastJoin
           ) {
-            ValidationResult.failed(
+            ValidationResult.notOk(
               "columnar broadcast exchange is disabled or " +
                 "columnar broadcast join is disabled")
           } else {
             if (FallbackTags.nonEmpty(bnlj)) {
-              ValidationResult.failed("broadcast join is already tagged as not transformable")
+              ValidationResult.notOk("broadcast join is already tagged as not transformable")
             } else {
               val transformer = BackendsApiManager.getSparkPlanExecApiInstance
                 .genBroadcastNestedLoopJoinExecTransformer(
@@ -126,7 +126,7 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
                   bnlj.joinType,
                   bnlj.condition)
               val isTransformable = transformer.doValidate()
-              if (isTransformable.ok()) {
+              if (isTransformable.isValid) {
                 val exchangeTransformer = ColumnarBroadcastExchangeExec(mode, child)
                 exchangeTransformer.doValidate()
               } else {
@@ -242,7 +242,7 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl
           maybeExchange match {
             case Some(exchange @ BroadcastExchangeExec(_, _)) =>
               isTransformable.tagOnFallback(plan)
-              if (!isTransformable.ok) {
+              if (!isTransformable.isValid) {
                 FallbackTags.add(exchange, isTransformable)
               }
             case None =>
@@ -280,7 +280,7 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl
                 plan,
                 "it's a materialized broadcast exchange or reused broadcast exchange")
             case ColumnarBroadcastExchangeExec(mode, child) =>
-              if (!isTransformable.ok) {
+              if (!isTransformable.isValid) {
                 throw new IllegalStateException(
                   s"BroadcastExchange has already been" +
                     s" transformed to columnar version but BHJ is determined as" +
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarToRowExec.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarToRowExec.scala
index 29fa0d0aba960..522c7d68fa68b 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarToRowExec.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarToRowExec.scala
@@ -58,7 +58,7 @@ case class CHColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBase(c
           s"${field.dataType} is not supported in ColumnarToRowExecBase.")
       }
     }
-    ValidationResult.succeeded
+    ValidationResult.ok
   }
 
   override def doExecuteInternal(): RDD[InternalRow] = {
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index 31f0324e32eb3..8d98c111af68f 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -79,9 +79,9 @@ object VeloxBackendSettings extends BackendSettingsApi {
       // Collect unsupported types.
       val unsupportedDataTypeReason = fields.collect(validatorFunc)
       if (unsupportedDataTypeReason.isEmpty) {
-        ValidationResult.succeeded
+        ValidationResult.ok
       } else {
-        ValidationResult.failed(
+        ValidationResult.notOk(
           s"Found unsupported data type in $format: ${unsupportedDataTypeReason.mkString(", ")}.")
       }
     }
@@ -135,10 +135,10 @@ object VeloxBackendSettings extends BackendSettingsApi {
         } else {
           validateTypes(parquetTypeValidatorWithComplexTypeFallback)
         }
-      case DwrfReadFormat => ValidationResult.succeeded
+      case DwrfReadFormat => ValidationResult.ok
       case OrcReadFormat =>
         if (!GlutenConfig.getConf.veloxOrcScanEnabled) {
-          ValidationResult.failed(s"Velox ORC scan is turned off.")
+          ValidationResult.notOk(s"Velox ORC scan is turned off.")
         } else {
           val typeValidator: PartialFunction[StructField, String] = {
             case StructField(_, arrayType: ArrayType, _, _)
@@ -164,7 +164,7 @@ object VeloxBackendSettings extends BackendSettingsApi {
             validateTypes(orcTypeValidatorWithComplexTypeFallback)
           }
         }
-      case _ => ValidationResult.failed(s"Unsupported file format for $format.")
+      case _ => ValidationResult.notOk(s"Unsupported file format for $format.")
     }
   }
 
@@ -284,8 +284,8 @@ object VeloxBackendSettings extends BackendSettingsApi {
       .orElse(validateDataTypes())
      .orElse(validateWriteFilesOptions())
      .orElse(validateBucketSpec()) match {
-      case Some(reason) => ValidationResult.failed(reason)
-      case _ => ValidationResult.succeeded
+      case Some(reason) => ValidationResult.notOk(reason)
+      case _ => ValidationResult.ok
     }
   }
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 37b46df3e23d9..2b9d0173846a8 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -366,7 +366,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
         val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ child.output
         val projectTransformer = ProjectExecTransformer(projectList, child)
         val validationResult = projectTransformer.doValidate()
-        if (validationResult.ok()) {
+        if (validationResult.isValid) {
           val newChild = maybeAddAppendBatchesExec(projectTransformer)
           ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output.drop(1))
         } else {
@@ -392,7 +392,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
         val projectTransformer = ProjectExecTransformer(projectList, child)
         val projectBeforeSortValidationResult = projectTransformer.doValidate()
         // Make sure we support offload hash expression
-        val projectBeforeSort = if (projectBeforeSortValidationResult.ok()) {
+        val projectBeforeSort = if (projectBeforeSortValidationResult.isValid) {
           projectTransformer
         } else {
           val project = ProjectExec(projectList, child)
@@ -405,7 +405,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
         val dropSortColumnTransformer = ProjectExecTransformer(projectList.drop(1), sortByHashCode)
         val validationResult = dropSortColumnTransformer.doValidate()
-        if (validationResult.ok()) {
+        if (validationResult.isValid) {
           val newChild = maybeAddAppendBatchesExec(dropSortColumnTransformer)
           ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output)
         } else {
@@ -888,7 +888,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       case p @ LimitTransformer(SortExecTransformer(sortOrder, _, child, _), 0, count) =>
         val global = child.outputPartitioning.satisfies(AllTuples)
         val topN = TopNTransformer(count, sortOrder, global, child)
-        if (topN.doValidate().ok()) {
+        if (topN.doValidate().isValid) {
           topN
         } else {
           p
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala
index c7b81d55fa067..8ceea8c14f6ad 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/GenerateExecTransformer.scala
@@ -72,11 +72,11 @@ case class GenerateExecTransformer(
       generator: Generator,
       outer: Boolean): ValidationResult = {
     if (!supportsGenerate(generator, outer)) {
-      ValidationResult.failed(
+      ValidationResult.notOk(
         s"Velox backend does not support this generator: ${generator.getClass.getSimpleName}" +
           s", outer: $outer")
     } else {
-      ValidationResult.succeeded
+      ValidationResult.ok
     }
   }
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala
index 2c46893e4576a..1a54255208ea4 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala
@@ -66,7 +66,7 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBas
           s"VeloxColumnarToRowExec.")
       }
     }
-    ValidationResult.succeeded
+    ValidationResult.ok
   }
 
   override def doExecuteInternal(): RDD[InternalRow] = {
diff --git a/gluten-core/src/main/java/org/apache/gluten/validate/NativePlanValidationInfo.java b/gluten-core/src/main/java/org/apache/gluten/validate/NativePlanValidationInfo.java
index 9cfad44d60f52..12f050c660f43 100644
--- a/gluten-core/src/main/java/org/apache/gluten/validate/NativePlanValidationInfo.java
+++ b/gluten-core/src/main/java/org/apache/gluten/validate/NativePlanValidationInfo.java
@@ -16,8 +16,6 @@
  */
 package org.apache.gluten.validate;
 
-import org.apache.gluten.extension.ValidationResult;
-
 import java.util.Vector;
 
 public class NativePlanValidationInfo {
@@ -32,13 +30,11 @@ public NativePlanValidationInfo(int isSupported, String fallbackInfo) {
     }
   }
 
-  public ValidationResult asResult() {
-    if (isSupported == 1) {
-      return ValidationResult.succeeded();
-    }
-    return ValidationResult.failed(
-        String.format(
-            "Native validation failed: %n%s",
-            fallbackInfo.stream().reduce((l, r) -> l + "\n" + r)));
+  public boolean isSupported() {
+    return isSupported == 1;
+  }
+
+  public Vector getFallbackInfo() {
+    return fallbackInfo;
   }
 }
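Note: the conversion from the native validation payload to a ValidationResult moves out of this Java class into Scala (see the GlutenPlan.scala hunk further down). A minimal sketch of the consuming side, with planNode assumed to be an already-built Substrait plan:

// Sketch only: mirrors ValidationResult.convertFromValidationInfo introduced below.
import scala.collection.JavaConverters._

val info = BackendsApiManager.getValidatorApiInstance
  .doNativeValidateWithFailureReason(planNode)
val result =
  if (info.isSupported) ValidationResult.ok
  else ValidationResult.notOk(
    info.getFallbackInfo.asScala.mkString("Native validation failed:\n ", "\n ", ""))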
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
index 8b4c18b01970d..07ead88601ff2 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
@@ -35,12 +35,12 @@ trait BackendSettingsApi {
       format: ReadFileFormat,
       fields: Array[StructField],
       partTable: Boolean,
-      paths: Seq[String]): ValidationResult = ValidationResult.succeeded
+      paths: Seq[String]): ValidationResult = ValidationResult.ok
   def supportWriteFilesExec(
       format: FileFormat,
       fields: Array[StructField],
       bucketSpec: Option[BucketSpec],
-      options: Map[String, String]): ValidationResult = ValidationResult.succeeded
+      options: Map[String, String]): ValidationResult = ValidationResult.ok
   def supportNativeWrite(fields: Array[StructField]): Boolean = true
   def supportNativeMetadataColumns(): Boolean = false
   def supportNativeRowIndexColumn(): Boolean = false
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
index 8e87baf5381d9..97b4c3a3f807a 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
@@ -113,7 +113,7 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP
     if (remainingCondition == null) {
       // All the filters can be pushed down and the computing of this Filter
       // is not needed.
-      return ValidationResult.succeeded
+      return ValidationResult.ok
     }
     val substraitContext = new SubstraitContext
     val operatorId = substraitContext.nextOperatorId(this.nodeName)
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala
index 99f145eeab1c9..64071fb14c0c0 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala
@@ -88,7 +88,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource
     val validationResult = BackendsApiManager.getSettings
       .supportFileFormatRead(fileFormat, fields, getPartitionSchema.nonEmpty, getInputFilePaths)
-    if (!validationResult.ok()) {
+    if (!validationResult.isValid) {
       return validationResult
     }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
index 4860847de9acd..6bff68895a249 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
@@ -133,18 +133,18 @@ abstract class BatchScanExecTransformerBase(
   override def doValidateInternal(): ValidationResult = {
     if (pushedAggregate.nonEmpty) {
-      return ValidationResult.failed(s"Unsupported aggregation push down for $scan.")
+      return ValidationResult.notOk(s"Unsupported aggregation push down for $scan.")
     }
 
     if (
       SparkShimLoader.getSparkShims.findRowIndexColumnIndexInSchema(schema) > 0 &&
       !BackendsApiManager.getSettings.supportNativeRowIndexColumn()
     ) {
-      return ValidationResult.failed("Unsupported row index column scan in native.")
+      return ValidationResult.notOk("Unsupported row index column scan in native.")
     }
 
     if (hasUnsupportedColumns) {
-      return ValidationResult.failed(s"Unsupported columns scan in native.")
+      return ValidationResult.notOk(s"Unsupported columns scan in native.")
     }
 
     super.doValidateInternal()
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
index ae407b3b3efa1..b90c1ad8b6e74 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
@@ -163,35 +163,34 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
   def validateJoinTypeAndBuildSide(): ValidationResult = {
     val result = joinType match {
-      case _: InnerLike | LeftOuter | RightOuter => ValidationResult.succeeded
+      case _: InnerLike | LeftOuter | RightOuter => ValidationResult.ok
       case _ =>
-        ValidationResult.failed(s"$joinType join is not supported with BroadcastNestedLoopJoin")
+        ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin")
     }
 
-    if (!result.ok()) {
+    if (!result.isValid) {
       return result
     }
 
     (joinType, buildSide) match {
       case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) =>
-        ValidationResult.failed(s"$joinType join is not supported with $buildSide")
-      case _ => ValidationResult.succeeded // continue
+        ValidationResult.notOk(s"$joinType join is not supported with $buildSide")
+      case _ => ValidationResult.ok // continue
     }
   }
 
   override protected def doValidateInternal(): ValidationResult = {
     if (!GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled) {
-      return ValidationResult.failed(
+      return ValidationResult.notOk(
         s"Config ${GlutenConfig.BROADCAST_NESTED_LOOP_JOIN_TRANSFORMER_ENABLED.key} not enabled")
     }
 
     if (substraitJoinType == CrossRel.JoinType.UNRECOGNIZED) {
-      return ValidationResult.failed(
-        s"$joinType join is not supported with BroadcastNestedLoopJoin")
+      return ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin")
     }
 
     val validateResult = validateJoinTypeAndBuildSide()
-    if (!validateResult.ok()) {
+    if (!validateResult.isValid) {
       return validateResult
     }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala
index 0dd110fa542f2..91831f18493ad 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala
@@ -112,7 +112,7 @@ case class CartesianProductExecTransformer(
   override protected def doValidateInternal(): ValidationResult = {
     if (!BackendsApiManager.getSettings.supportCartesianProductExec()) {
-      return ValidationResult.failed("Cartesian product is not supported in this backend")
+      return ValidationResult.notOk("Cartesian product is not supported in this backend")
     }
     val substraitContext = new SubstraitContext
     val expressionNode = condition.map {
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala
index 63f76a25a2318..362debb531ee6 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala
@@ -95,10 +95,10 @@ case class ExpandExecTransformer(
   override protected def doValidateInternal(): ValidationResult = {
     if (!BackendsApiManager.getSettings.supportExpandExec()) {
-      return ValidationResult.failed("Current backend does not support expand")
+      return ValidationResult.notOk("Current backend does not support expand")
     }
     if (projections.isEmpty) {
-      return ValidationResult.failed("Current backend does not support empty projections in expand")
+      return ValidationResult.notOk("Current backend does not support empty projections in expand")
     }
 
     val substraitContext = new SubstraitContext
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala
index 3b8ed1167afcf..4f120488c2fb5 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala
@@ -130,23 +130,23 @@ abstract class FileSourceScanExecTransformerBase(
     if (
       !metadataColumns.isEmpty && !BackendsApiManager.getSettings.supportNativeMetadataColumns()
     ) {
-      return ValidationResult.failed(s"Unsupported metadata columns scan in native.")
+      return ValidationResult.notOk(s"Unsupported metadata columns scan in native.")
     }
 
     if (
       SparkShimLoader.getSparkShims.findRowIndexColumnIndexInSchema(schema) > 0 &&
       !BackendsApiManager.getSettings.supportNativeRowIndexColumn()
     ) {
-      return ValidationResult.failed("Unsupported row index column scan in native.")
+      return ValidationResult.notOk("Unsupported row index column scan in native.")
     }
 
     if (hasUnsupportedColumns) {
-      return ValidationResult.failed(s"Unsupported columns scan in native.")
+      return ValidationResult.notOk(s"Unsupported columns scan in native.")
     }
 
     if (hasFieldIds) {
       // Spark read schema expects field Ids , the case didn't support yet by native.
-      return ValidationResult.failed(
+      return ValidationResult.notOk(
         s"Unsupported matching schema column names " +
           s"by field ids in native scan.")
     }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala
index af4a92f194c1b..b5c9b85aeb0d5 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala
@@ -67,7 +67,7 @@ abstract class GenerateExecTransformerBase(
   override protected def doValidateInternal(): ValidationResult = {
     val validationResult = doGeneratorValidate(generator, outer)
-    if (!validationResult.ok()) {
+    if (!validationResult.isValid) {
       return validationResult
     }
     val context = new SubstraitContext
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
index 9a28af801d831..b200426d91ce9 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
@@ -117,7 +117,7 @@ abstract class HashAggregateExecBaseTransformer(
     val unsupportedAggExprs = aggregateAttributes.filterNot(attr => checkType(attr.dataType))
     if (unsupportedAggExprs.nonEmpty) {
-      return ValidationResult.failed(
+      return ValidationResult.notOk(
         "Found unsupported data type in aggregation expression: " +
           unsupportedAggExprs
             .map(attr => s"${attr.name}#${attr.exprId.id}:${attr.dataType}")
@@ -125,7 +125,7 @@ abstract class HashAggregateExecBaseTransformer(
     val unsupportedGroupExprs = groupingExpressions.filterNot(attr => checkType(attr.dataType))
     if (unsupportedGroupExprs.nonEmpty) {
-      return ValidationResult.failed(
+      return ValidationResult.notOk(
         "Found unsupported data type in grouping expression: " +
           unsupportedGroupExprs
             .map(attr => s"${attr.name}#${attr.exprId.id}:${attr.dataType}")
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
index 86e6c1f412656..cd22c578594c6 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
@@ -210,7 +210,7 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport {
     // Firstly, need to check if the Substrait plan for this operator can be successfully generated.
     if (substraitJoinType == JoinRel.JoinType.UNRECOGNIZED) {
       return ValidationResult
-        .failed(s"Unsupported join type of $hashJoinType for substrait: $substraitJoinType")
+        .notOk(s"Unsupported join type of $hashJoinType for substrait: $substraitJoinType")
     }
     val relNode = JoinUtils.createJoinRel(
       streamedKeyExprs,
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala
index bed59b913a1e9..6f9ef34282bf0 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala
@@ -99,7 +99,7 @@ case class SampleExecTransformer(
   override protected def doValidateInternal(): ValidationResult = {
     if (withReplacement) {
-      return ValidationResult.failed(
+      return ValidationResult.notOk(
         "Unsupported sample exec in native with " +
           s"withReplacement parameter is $withReplacement")
     }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/ScanTransformerFactory.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/ScanTransformerFactory.scala
index a05a5e72bfe1d..44a823834f926 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/ScanTransformerFactory.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/ScanTransformerFactory.scala
@@ -95,11 +95,11 @@ object ScanTransformerFactory {
         transformer.setPushDownFilters(allPushDownFilters.get)
         // Validate again if allPushDownFilters is defined.
         val validationResult = transformer.doValidate()
-        if (validationResult.ok()) {
+        if (validationResult.isValid) {
           transformer
         } else {
           val newSource = batchScan.copy(runtimeFilters = transformer.runtimeFilters)
-          FallbackTags.add(newSource, validationResult.reason())
+          FallbackTags.add(newSource, validationResult.reason.get)
           newSource
         }
       } else {
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala
index b69925d60fd21..f79dc69e680b5 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala
@@ -91,7 +91,7 @@ case class SortExecTransformer(
   override protected def doValidateInternal(): ValidationResult = {
     if (!BackendsApiManager.getSettings.supportSortExec()) {
-      return ValidationResult.failed("Current backend does not support sort")
+      return ValidationResult.notOk("Current backend does not support sort")
     }
     val substraitContext = new SubstraitContext
     val operatorId = substraitContext.nextOperatorId(this.nodeName)
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala
index c96789569f9ac..f032c4ca00879 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala
@@ -164,7 +164,7 @@ abstract class SortMergeJoinExecTransformerBase(
     // Firstly, need to check if the Substrait plan for this operator can be successfully generated.
     if (substraitJoinType == JoinRel.JoinType.UNRECOGNIZED) {
       return ValidationResult
-        .failed(s"Found unsupported join type of $joinType for substrait: $substraitJoinType")
+        .notOk(s"Found unsupported join type of $joinType for substrait: $substraitJoinType")
     }
     val relNode = JoinUtils.createJoinRel(
       streamedKeys,
@@ -253,7 +253,7 @@ case class SortMergeJoinExecTransformer(
     // Firstly, need to check if the Substrait plan for this operator can be successfully generated.
     if (substraitJoinType == JoinRel.JoinType.JOIN_TYPE_OUTER) {
       return ValidationResult
-        .failed(s"Found unsupported join type of $joinType for velox smj: $substraitJoinType")
+        .notOk(s"Found unsupported join type of $joinType for velox smj: $substraitJoinType")
     }
     super.doValidateInternal()
   }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala
index b31471e21397a..74158d6332dc5 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala
@@ -67,7 +67,7 @@ case class TakeOrderedAndProjectExecTransformer(
   override protected def doValidateInternal(): ValidationResult = {
     if (offset != 0) {
-      return ValidationResult.failed(s"Native TopK does not support offset: $offset")
+      return ValidationResult.notOk(s"Native TopK does not support offset: $offset")
     }
 
     var tagged: ValidationResult = null
@@ -83,14 +83,14 @@ case class TakeOrderedAndProjectExecTransformer(
         ColumnarCollapseTransformStages.wrapInputIteratorTransformer(child)
       val sortPlan = SortExecTransformer(sortOrder, false, inputTransformer)
       val sortValidation = sortPlan.doValidate()
-      if (!sortValidation.ok()) {
+      if (!sortValidation.isValid) {
         return sortValidation
       }
       val limitPlan = LimitTransformer(sortPlan, offset, limit)
       tagged = limitPlan.doValidate()
     }
 
-    if (tagged.ok()) {
+    if (tagged.isValid) {
       val projectPlan = ProjectExecTransformer(projectList, child)
       tagged = projectPlan.doValidate()
     }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
index 4902b6c6cf1b7..628c08f290eb5 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
@@ -165,7 +165,7 @@ case class WindowExecTransformer(
   override protected def doValidateInternal(): ValidationResult = {
     if (!BackendsApiManager.getSettings.supportWindowExec(windowExpression)) {
       return ValidationResult
-        .failed(s"Found unsupported window expression: ${windowExpression.mkString(", ")}")
+        .notOk(s"Found unsupported window expression: ${windowExpression.mkString(", ")}")
     }
     val substraitContext = new SubstraitContext
     val operatorId = substraitContext.nextOperatorId(this.nodeName)
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
index 6068412fbad31..c93d01e7a12eb 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
@@ -145,7 +145,7 @@ case class WindowGroupLimitExecTransformer(
   override protected def doValidateInternal(): ValidationResult = {
     if (!BackendsApiManager.getSettings.supportWindowGroupLimitExec(rankLikeFunction)) {
       return ValidationResult
-        .failed(s"Found unsupported rank like function: $rankLikeFunction")
+        .notOk(s"Found unsupported rank like function: $rankLikeFunction")
     }
     val substraitContext = new SubstraitContext
     val operatorId = substraitContext.nextOperatorId(this.nodeName)
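Note: the hunks above are a mechanical rename; the guard-style shape of doValidateInternal is unchanged. A minimal sketch of that recurring pattern under the renamed API, where supportFoo() is a hypothetical backend capability check and getRelNode stands in for each operator's Substrait builder:

// Sketch only. `supportFoo()` is a hypothetical capability flag; the real
// operators call the specific BackendSettingsApi methods shown above.
override protected def doValidateInternal(): ValidationResult = {
  if (!BackendsApiManager.getSettings.supportFoo()) {
    // Fail fast; the reason string later lands in a fallback tag.
    return ValidationResult.notOk("Current backend does not support foo")
  }
  val substraitContext = new SubstraitContext
  val operatorId = substraitContext.nextOperatorId(this.nodeName)
  val relNode = getRelNode(substraitContext, operatorId) // operator-specific builder
  doNativeValidation(substraitContext, relNode)
}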
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala
index d78f21beaabfe..14d58bfa83771 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala
@@ -150,8 +150,8 @@ case class WriteFilesExecTransformer(
       finalChildOutput.toStructType.fields,
       bucketSpec,
       caseInsensitiveOptions)
-    if (!validationResult.ok()) {
-      return ValidationResult.failed("Unsupported native write: " + validationResult.reason())
+    if (!validationResult.isValid) {
+      return ValidationResult.notOk("Unsupported native write: " + validationResult.reason.get)
     }
 
     val substraitContext = new SubstraitContext
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
index 0c70e1ea7a7b1..71a76ff63dd1f 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
@@ -26,29 +26,28 @@ import org.apache.gluten.substrait.plan.PlanBuilder
 import org.apache.gluten.substrait.rel.RelNode
 import org.apache.gluten.test.TestStats
 import org.apache.gluten.utils.LogLevelUtil
+import org.apache.gluten.validate.NativePlanValidationInfo
 
 import org.apache.spark.sql.execution.SparkPlan
 
 import com.google.common.collect.Lists
 
-sealed trait ValidationResult {
-  def ok(): Boolean
-  def reason(): String
-}
+import scala.collection.JavaConverters._
 
-object ValidationResult {
-  private case object Succeeded extends ValidationResult {
-    override def ok(): Boolean = true
-    override def reason(): String = throw new UnsupportedOperationException(
-      "Succeeded validation doesn't have failure details")
-  }
+case class ValidationResult(isValid: Boolean, reason: Option[String])
 
-  private case class Failed(override val reason: String) extends ValidationResult {
-    override def ok(): Boolean = false
+object ValidationResult {
+  def ok: ValidationResult = ValidationResult(isValid = true, None)
+  def notOk(reason: String): ValidationResult = ValidationResult(isValid = false, Option(reason))
+  def convertFromValidationInfo(info: NativePlanValidationInfo): ValidationResult = {
+    if (info.isSupported) {
+      ok
+    } else {
+      val fallbackInfo = info.getFallbackInfo.asScala
+        .mkString("Native validation failed:\n ", "\n ", "")
+      notOk(fallbackInfo)
+    }
   }
-
-  def succeeded: ValidationResult = Succeeded
-  def failed(reason: String): ValidationResult = Failed(reason)
 }
 
 /** Every Gluten Operator should extend this trait. */
@@ -67,18 +66,17 @@ trait GlutenPlan extends SparkPlan with Convention.KnownBatchType with LogLevelU
     val schemaVaidationResult = BackendsApiManager.getValidatorApiInstance
       .doSchemaValidate(schema)
       .map {
-        reason =>
-          ValidationResult.failed(s"Found schema check failure for $schema, due to: $reason")
+        reason => ValidationResult.notOk(s"Found schema check failure for $schema, due to: $reason")
       }
-      .getOrElse(ValidationResult.succeeded)
-    if (!schemaVaidationResult.ok()) {
+      .getOrElse(ValidationResult.ok)
+    if (!schemaVaidationResult.isValid) {
       TestStats.addFallBackClassName(this.getClass.toString)
       return schemaVaidationResult
     }
     try {
       TransformerState.enterValidation
       val res = doValidateInternal()
-      if (!res.ok()) {
+      if (!res.isValid) {
         TestStats.addFallBackClassName(this.getClass.toString)
       }
       res
@@ -92,7 +90,7 @@ trait GlutenPlan extends SparkPlan with Convention.KnownBatchType with LogLevelU
         logValidationMessage(
           s"Validation failed with exception for plan: $nodeName, due to: ${e.getMessage}",
           e)
-        ValidationResult.failed(e.getMessage)
+        ValidationResult.notOk(e.getMessage)
       } finally {
         TransformerState.finishValidation
       }
@@ -111,16 +109,16 @@ trait GlutenPlan extends SparkPlan with Convention.KnownBatchType with LogLevelU
     BackendsApiManager.getSparkPlanExecApiInstance.batchType
   }
 
-  protected def doValidateInternal(): ValidationResult = ValidationResult.succeeded
+  protected def doValidateInternal(): ValidationResult = ValidationResult.ok
 
   protected def doNativeValidation(context: SubstraitContext, node: RelNode): ValidationResult = {
     if (node != null && enableNativeValidation) {
       val planNode = PlanBuilder.makePlan(context, Lists.newArrayList(node))
       val info = BackendsApiManager.getValidatorApiInstance
         .doNativeValidateWithFailureReason(planNode)
-      info.asResult()
+      ValidationResult.convertFromValidationInfo(info)
     } else {
-      ValidationResult.succeeded
+      ValidationResult.ok
    }
  }
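Note: this hunk is the heart of the patch. ValidationResult changes from a sealed trait whose reason() throws on success to a plain case class, so every call site moves from ok()/reason() to isValid and an Option[String]. A minimal sketch of the resulting contract, using only the API defined above:

// Sketch: the call-site contract after this hunk.
val r: ValidationResult = ValidationResult.notOk("ch does not support NAAJ")
if (!r.isValid) {
  // reason is an Option now, so a missing reason no longer throws; call sites
  // pick .get (WriteFilesExecTransformer above) or .getOrElse (CollapseProject below).
  println(r.reason.getOrElse("Reason not recorded"))
}
assert(ValidationResult.ok.isValid && ValidationResult.ok.reason.isEmpty)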
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala
index bfb926706cf7b..25674fd17147f 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala
@@ -40,11 +40,11 @@ object CollapseProjectExecTransformer extends Rule[SparkPlan] {
         val collapsedProject = p2.copy(projectList =
           CollapseProjectShim.buildCleanedProjectList(p1.projectList, p2.projectList))
         val validationResult = collapsedProject.doValidate()
-        if (validationResult.ok()) {
+        if (validationResult.isValid) {
           logDebug(s"Collapse project $p1 and $p2.")
           collapsedProject
         } else {
-          logDebug(s"Failed to collapse project, due to ${validationResult.reason()}")
+          logDebug(s"Failed to collapse project, due to ${validationResult.reason.getOrElse("")}")
           p1
         }
     }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala
index 491b54443d678..e334fcfbce889 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala
@@ -243,7 +243,7 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
           originalPlan
             .find(_.logicalLink.exists(_.fastEquals(p.logicalLink.get)))
             .filterNot(FallbackTags.nonEmpty)
-            .foreach(origin => FallbackTags.add(origin, FallbackTags.get(p)))
+            .foreach(origin => FallbackTags.tag(origin, FallbackTags.getTag(p)))
         case _ =>
       }
 
@@ -280,7 +280,7 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
       } else {
         FallbackTags.addRecursively(
           vanillaSparkPlan,
-          FallbackTag.Exclusive(fallbackInfo.reason.getOrElse("Unknown reason")))
+          TRANSFORM_UNSUPPORTED(fallbackInfo.reason, appendReasonIfExists = false))
         FallbackNode(vanillaSparkPlan)
       }
     } else {
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackTagRule.scala
similarity index 89%
rename from gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
rename to gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackTagRule.scala
index f9eaa4179c67b..ddc6870e6c71f 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackTagRule.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag}
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, QueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -50,42 +50,10 @@ sealed trait FallbackTag {
     if (FallbackTags.DEBUG) {
       Some(ExceptionUtils.getStackTrace(new Throwable()))
     } else None
-
-  def reason(): String
 }
 
-object FallbackTag {
-
-  /** A tag that stores one reason text of fall back. */
-  case class Appendable(override val reason: String) extends FallbackTag
-
-  /**
-   * A tag that stores reason text of fall back. Other reasons will be discarded when this tag is
-   * added to plan.
-   */
-  case class Exclusive(override val reason: String) extends FallbackTag
-
-  trait Converter[T] {
-    def from(obj: T): Option[FallbackTag]
-  }
-
-  object Converter {
-    implicit def asIs[T <: FallbackTag]: Converter[T] = (tag: T) => Some(tag)
-
-    implicit object FromString extends Converter[String] {
-      override def from(reason: String): Option[FallbackTag] = Some(Appendable(reason))
-    }
-
-    implicit object FromValidationResult extends Converter[ValidationResult] {
-      override def from(result: ValidationResult): Option[FallbackTag] = {
-        if (result.ok()) {
-          return None
-        }
-        Some(Appendable(result.reason()))
-      }
-    }
-  }
-}
+case class TRANSFORM_UNSUPPORTED(reason: Option[String], appendReasonIfExists: Boolean = true)
+  extends FallbackTag
 
 object FallbackTags {
   val TAG: TreeNodeTag[FallbackTag] =
@@ -102,7 +70,10 @@ object FallbackTags {
    *   rules are passed.
    */
   def nonEmpty(plan: SparkPlan): Boolean = {
-    getOption(plan).nonEmpty
+    getTagOption(plan) match {
+      case Some(TRANSFORM_UNSUPPORTED(_, _)) => true
+      case _ => false
+    }
   }
 
   /**
@@ -113,51 +84,72 @@
    */
   def maybeOffloadable(plan: SparkPlan): Boolean = !nonEmpty(plan)
 
-  def add[T](plan: TreeNode[_], t: T)(implicit converter: FallbackTag.Converter[T]): Unit = {
-    val tagOption = getOption(plan)
-    val newTagOption = converter.from(t)
-
-    val mergedTagOption: Option[FallbackTag] =
-      (tagOption ++ newTagOption).reduceOption[FallbackTag] {
-        // New tag comes while the plan was already tagged, merge.
-        case (_, exclusive: FallbackTag.Exclusive) =>
-          exclusive
-        case (exclusive: FallbackTag.Exclusive, _) =>
-          exclusive
-        case (l: FallbackTag.Appendable, r: FallbackTag.Appendable) =>
-          FallbackTag.Appendable(s"${l.reason}; ${r.reason}")
+  def tag(plan: SparkPlan, hint: FallbackTag): Unit = {
+    val mergedHint = getTagOption(plan)
+      .map {
+        case originalHint @ TRANSFORM_UNSUPPORTED(Some(originalReason), originAppend) =>
+          hint match {
+            case TRANSFORM_UNSUPPORTED(Some(newReason), append) =>
+              if (originAppend && append) {
+                TRANSFORM_UNSUPPORTED(Some(originalReason + "; " + newReason))
+              } else if (originAppend) {
+                TRANSFORM_UNSUPPORTED(Some(originalReason))
+              } else if (append) {
+                TRANSFORM_UNSUPPORTED(Some(newReason))
+              } else {
+                TRANSFORM_UNSUPPORTED(Some(originalReason), false)
+              }
+            case TRANSFORM_UNSUPPORTED(None, _) =>
+              originalHint
+            case _ =>
+              throw new GlutenNotSupportException(
+                "Plan was already tagged as non-transformable, " +
+                  s"cannot mark it as transformable after that:\n${plan.toString()}")
+          }
+        case _ =>
+          hint
       }
-    mergedTagOption
-      .foreach(mergedTag => plan.setTagValue(TAG, mergedTag))
+      .getOrElse(hint)
+    plan.setTagValue(TAG, mergedHint)
   }
 
-  def addRecursively[T](plan: TreeNode[_], t: T)(implicit
-      converter: FallbackTag.Converter[T]): Unit = {
-    plan.foreach {
-      case _: GlutenPlan => // ignore
-      case other: TreeNode[_] => add(other, t)
+  def untag(plan: SparkPlan): Unit = {
+    plan.unsetTagValue(TAG)
+  }
+
+  def add(plan: SparkPlan, validationResult: ValidationResult): Unit = {
+    if (!validationResult.isValid) {
+      tag(plan, TRANSFORM_UNSUPPORTED(validationResult.reason))
     }
   }
 
-  def untag(plan: TreeNode[_]): Unit = {
-    plan.unsetTagValue(TAG)
+  def add(plan: SparkPlan, reason: String): Unit = {
+    tag(plan, TRANSFORM_UNSUPPORTED(Some(reason)))
+  }
+
+  def addRecursively(plan: SparkPlan, hint: TRANSFORM_UNSUPPORTED): Unit = {
+    plan.foreach {
+      case _: GlutenPlan => // ignore
+      case other => tag(other, hint)
+    }
   }
 
-  def get(plan: TreeNode[_]): FallbackTag = {
-    getOption(plan).getOrElse(
+  def getTag(plan: SparkPlan): FallbackTag = {
+    getTagOption(plan).getOrElse(
       throw new IllegalStateException("Transform hint tag not set in plan: " + plan.toString()))
   }
 
-  def getOption(plan: TreeNode[_]): Option[FallbackTag] = {
+  def getTagOption(plan: SparkPlan): Option[FallbackTag] = {
     plan.getTagValue(TAG)
   }
 
-  implicit class EncodeFallbackTagImplicits(result: ValidationResult) {
-    def tagOnFallback(plan: TreeNode[_]): Unit = {
-      if (result.ok()) {
+  implicit class EncodeFallbackTagImplicits(validationResult: ValidationResult) {
+    def tagOnFallback(plan: SparkPlan): Unit = {
+      if (validationResult.isValid) {
         return
       }
-      add(plan, result)
+      val newTag = TRANSFORM_UNSUPPORTED(validationResult.reason)
+      tag(plan, newTag)
     }
   }
 }
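Note: the rename also replaces the Appendable/Exclusive tag pair with a single TRANSFORM_UNSUPPORTED tag whose appendReasonIfExists flag drives the merge in tag(). A small sketch of the default (appendable) behavior, with plan assumed to be any SparkPlan:

// Sketch: two appendable tags merge their reasons with "; ".
FallbackTags.tag(plan, TRANSFORM_UNSUPPORTED(Some("reason A")))
FallbackTags.tag(plan, TRANSFORM_UNSUPPORTED(Some("reason B")))
assert(FallbackTags.getTag(plan) == TRANSFORM_UNSUPPORTED(Some("reason A; reason B")))
// Combinations involving appendReasonIfExists = false follow the explicit
// branches in `tag` above rather than concatenating.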
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala index 26c70293c7b3a..742c353410d6c 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala @@ -379,7 +379,7 @@ case class OffloadFilter() extends OffloadSingleNode with LogLevelUtil { val newScan = FilterHandler.pushFilterToScan(filter.condition, scan) newScan match { - case ts: TransformSupport if ts.doValidate().ok() => ts + case ts: TransformSupport if ts.doValidate().isValid => ts case _ => scan } } else scan @@ -550,12 +550,12 @@ object OffloadOthers { case plan: FileSourceScanExec => val transformer = ScanTransformerFactory.createFileSourceScanTransformer(plan) val validationResult = transformer.doValidate() - if (validationResult.ok()) { + if (validationResult.isValid) { logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") transformer } else { logDebug(s"Columnar Processing for ${plan.getClass} is currently unsupported.") - FallbackTags.add(plan, validationResult.reason()) + FallbackTags.add(plan, validationResult.reason.get) plan } case plan: BatchScanExec => @@ -565,12 +565,12 @@ object OffloadOthers { val hiveTableScanExecTransformer = BackendsApiManager.getSparkPlanExecApiInstance.genHiveTableScanExecTransformer(plan) val validateResult = hiveTableScanExecTransformer.doValidate() - if (validateResult.ok()) { + if (validateResult.isValid) { logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") return hiveTableScanExecTransformer } logDebug(s"Columnar Processing for ${plan.getClass} is currently unsupported.") - FallbackTags.add(plan, validateResult.reason()) + FallbackTags.add(plan, validateResult.reason.get) plan case other => throw new GlutenNotSupportException(s"${other.getClass.toString} is not supported.") diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RemoveNativeWriteFilesSortAndProject.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RemoveNativeWriteFilesSortAndProject.scala index ac35ac83bb21d..d32de32ebb322 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RemoveNativeWriteFilesSortAndProject.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RemoveNativeWriteFilesSortAndProject.scala @@ -71,7 +71,7 @@ object NativeWriteFilesWithSkippingSortAndProject extends Logging { } val transformer = ProjectExecTransformer(newProjectList, p.child) val validationResult = transformer.doValidate() - if (validationResult.ok()) { + if (validationResult.isValid) { Some(transformer) } else { // If we can not transform the project, then we fallback to origin plan which means diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala index 4070a0a586120..611d6db0bd483 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala @@ -37,7 +37,7 @@ class PushFilterToScan(validator: Validator) extends RasRule[SparkPlan] { val newScan = FilterHandler.pushFilterToScan(filter.condition, scan) newScan match { - case ts: TransformSupport if ts.doValidate().ok() => + case ts: TransformSupport if ts.doValidate().isValid => 
List(filter.withNewChildren(List(ts))) case _ => List.empty diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala index 43f52a9e4758e..8091127da0bfa 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala @@ -106,7 +106,7 @@ object RasOffload { case Validator.Passed => val offloaded = base.offload(from) val offloadedNodes = offloaded.collect[GlutenPlan] { case t: GlutenPlan => t } - if (offloadedNodes.exists(!_.doValidate().ok())) { + if (offloadedNodes.exists(!_.doValidate().isValid)) { // 4. If native validation fails on the offloaded node, return the // original one. from diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala index e005a3dc81639..2abd4d7d48074 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala @@ -74,7 +74,7 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode] case p if !p.isInstanceOf[ProjectExec] && !p.isInstanceOf[RewrittenNodeWall] => p } assert(target.size == 1) - FallbackTags.getOption(target.head) + FallbackTags.getTagOption(target.head) } private def applyRewriteRules(origin: SparkPlan): (SparkPlan, Option[String]) = { @@ -112,10 +112,10 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode] origin } else { addHint.apply(rewrittenPlan) - val tag = getFallbackTagBack(rewrittenPlan) - if (tag.isDefined) { + val hint = getFallbackTagBack(rewrittenPlan) + if (hint.isDefined) { // If the rewritten plan is still not transformable, return the original plan. 
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
index 903723ccb56b5..b6236ae9a536e 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
@@ -19,7 +19,7 @@ package org.apache.gluten.extension.columnar.validator
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.{BackendsApiManager, BackendSettingsApi}
 import org.apache.gluten.expression.ExpressionUtils
-import org.apache.gluten.extension.columnar.FallbackTags
+import org.apache.gluten.extension.columnar.{FallbackTags, TRANSFORM_UNSUPPORTED}
 import org.apache.gluten.sql.shims.SparkShimLoader

 import org.apache.spark.sql.execution._
@@ -109,8 +109,8 @@ object Validators {
   private object FallbackByHint extends Validator {
     override def validate(plan: SparkPlan): Validator.OutCome = {
       if (FallbackTags.nonEmpty(plan)) {
-        val tag = FallbackTags.get(plan)
-        return fail(tag.reason())
+        val hint = FallbackTags.getTag(plan).asInstanceOf[TRANSFORM_UNSUPPORTED]
+        return fail(hint.reason.getOrElse("Reason not recorded"))
       }
       pass()
     }
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
index 31175a43fbaec..85a4dd3878a30 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
@@ -18,7 +18,8 @@ package org.apache.spark.sql.execution
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.extension.{GlutenPlan, ValidationResult}
+import org.apache.gluten.extension.GlutenPlan
+import org.apache.gluten.extension.ValidationResult
 import org.apache.gluten.sql.shims.SparkShimLoader

 import org.apache.spark._
@@ -118,10 +119,10 @@ case class ColumnarShuffleExchangeExec(
       .doColumnarShuffleExchangeExecValidate(outputPartitioning, child)
       .map {
         reason =>
-          ValidationResult.failed(
+          ValidationResult.notOk(
             s"Found schema check failure for schema ${child.schema} due to: $reason")
       }
-      .getOrElse(ValidationResult.succeeded)
+      .getOrElse(ValidationResult.ok)
   }

   override def nodeName: String = "ColumnarExchange"
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala
index 338136b6d8704..781dc6b6f717a 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala
@@ -18,12 +18,12 @@ package org.apache.spark.sql.execution
 import org.apache.gluten.execution.WholeStageTransformer
 import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.FallbackTags
 import org.apache.gluten.utils.PlanUtil

 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.execution.GlutenFallbackReporter.FALLBACK_REASON_TAG
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
@@ -59,8 +59,8 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper {
       p: SparkPlan,
       fallbackNodeToReason: mutable.HashMap[String, String]
   ): Unit = {
-    p.logicalLink.flatMap(FallbackTags.getOption) match {
-      case Some(tag) => addFallbackNodeWithReason(p, tag.reason(), fallbackNodeToReason)
+    p.logicalLink.flatMap(_.getTagValue(FALLBACK_REASON_TAG)) match {
+      case Some(reason) => addFallbackNodeWithReason(p, reason, fallbackNodeToReason)
       case _ =>
         // If the SparkPlan does not have fallback reason, then there are two options:
         // 1. Gluten ignore that plan and it's a kind of fallback
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala
index 67ecf81b9ffda..d41dce882602b 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenFallbackReporter.scala
@@ -19,12 +19,14 @@ package org.apache.spark.sql.execution
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.events.GlutenPlanFallbackEvent
 import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.FallbackTags
+import org.apache.gluten.extension.columnar.{FallbackTags, TRANSFORM_UNSUPPORTED}
 import org.apache.gluten.utils.LogLevelUtil

 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
+import org.apache.spark.sql.execution.GlutenFallbackReporter.FALLBACK_REASON_TAG
 import org.apache.spark.sql.execution.ui.GlutenEventUtils

 /**
@@ -56,14 +58,41 @@ case class GlutenFallbackReporter(glutenConfig: GlutenConfig, spark: SparkSessio
     plan.foreachUp {
       case _: GlutenPlan => // ignore
       case p: SparkPlan if FallbackTags.nonEmpty(p) =>
-        val tag = FallbackTags.get(p)
-        logFallbackReason(validationLogLevel, p.nodeName, tag.reason())
-        // With in next round stage in AQE, the physical plan would be a new instance that
-        // can not preserve the tag, so we need to set the fallback reason to logical plan.
-        // Then we can be aware of the fallback reason for the whole plan.
-        // If a logical plan mapping to several physical plan, we add all reason into
-        // that logical plan to make sure we do not lose any fallback reason.
-        p.logicalLink.foreach(logicalPlan => FallbackTags.add(logicalPlan, tag))
+        FallbackTags.getTag(p) match {
+          case TRANSFORM_UNSUPPORTED(Some(reason), append) =>
+            logFallbackReason(validationLogLevel, p.nodeName, reason)
+            // Within the next round stage in AQE, the physical plan would be a new
+            // instance that cannot preserve the tag, so we need to set the fallback
+            // reason on the logical plan. Then we can be aware of the fallback reason
+            // for the whole plan. If a logical plan maps to several physical plans, we
+            // add all reasons into that logical plan so we do not lose any of them.
+            p.logicalLink.foreach {
+              logicalPlan =>
+                val newReason = logicalPlan
+                  .getTagValue(FALLBACK_REASON_TAG)
+                  .map {
+                    lastReason =>
+                      if (!append) {
+                        lastReason
+                      } else if (lastReason.contains(reason)) {
+                        // Use the last reason; the new reason is redundant.
+                        lastReason
+                      } else if (reason.contains(lastReason)) {
+                        // Overwrite the last reason with the more complete new one.
+                        reason
+                      } else {
+                        // Append the new reason.
+                        lastReason + "; " + reason
+                      }
+                  }
+                  .getOrElse(reason)
+                logicalPlan.setTagValue(FALLBACK_REASON_TAG, newReason)
+            }
+          case TRANSFORM_UNSUPPORTED(_, _) =>
+            logFallbackReason(validationLogLevel, p.nodeName, "unknown reason")
+          case _ =>
+            throw new IllegalStateException("Unreachable code")
+        }
       case _ =>
     }
   }
@@ -90,4 +119,7 @@ case class GlutenFallbackReporter(glutenConfig: GlutenConfig, spark: SparkSessio
     }
   }
 }

-object GlutenFallbackReporter {}
+object GlutenFallbackReporter {
+  // A tag injected into the logical plan to preserve the fallback reason.
+  val FALLBACK_REASON_TAG = new TreeNodeTag[String]("GLUTEN_FALLBACK_REASON")
+}
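The merging logic added above deduplicates reasons before writing them back to the logical plan: an existing reason is kept when appending is disabled or when one string contains the other, and otherwise the two are joined. A standalone sketch of just that merge rule, with the tag plumbing stripped away (illustrative only, not Gluten's actual class):

object FallbackReasonMergeSketch {
  // Mirrors the decision tree in the hunk above.
  def merge(lastReason: Option[String], reason: String, append: Boolean): String =
    lastReason match {
      case None => reason
      case Some(last) if !append => last
      case Some(last) if last.contains(reason) => last // new reason is redundant
      case Some(last) if reason.contains(last) => reason // new reason subsumes the old one
      case Some(last) => last + "; " + reason
    }

  def main(args: Array[String]): Unit = {
    assert(merge(None, "a", append = true) == "a")
    assert(merge(Some("a; b"), "b", append = true) == "a; b") // redundant, kept as-is
    assert(merge(Some("a"), "c", append = true) == "a; c") // genuinely new, appended
    assert(merge(Some("a"), "c", append = false) == "a") // append disabled
  }
}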
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala
index 6a9da0a9cf92e..ecedc1bae01c8 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala
@@ -62,7 +62,7 @@ case class EvalPythonExecTransformer(
     // All udfs should be scalar python udf
     for (udf <- udfs) {
       if (!PythonUDF.isScalarPythonUDF(udf)) {
-        return ValidationResult.failed(s"$udf is not scalar python udf")
+        return ValidationResult.notOk(s"$udf is not scalar python udf")
       }
     }
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala
index 2a3ba79ebc2a5..95793e5dc9354 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala
@@ -202,7 +202,7 @@ object HiveTableScanExecTransformer {
           hiveTableScan.relation,
           hiveTableScan.partitionPruningPred)(hiveTableScan.session)
         hiveTableScanTransformer.doValidate()
-      case _ => ValidationResult.failed("Is not a Hive scan")
+      case _ => ValidationResult.notOk("Is not a Hive scan")
     }
   }
diff --git a/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala b/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala
index 31e6c6940cd97..1cd735cf7ee7d 100644
--- a/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala
+++ b/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala
@@ -57,7 +57,7 @@ case class DeltaScanTransformer(
       _.name == "__delta_internal_is_row_deleted") || requiredSchema.fields.exists(
       _.name == "__delta_internal_row_index")
     ) {
-      return ValidationResult.failed(s"Deletion vector is not supported in native.")
+      return ValidationResult.notOk(s"Deletion vector is not supported in native.")
     }

     super.doValidateInternal()
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala
index 2ee1573ea07ab..fe37da206a561 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala
@@ -35,7 +35,7 @@ case class CustomerColumnarPreRules(session: SparkSession) extends Rule[SparkPla
           fileSourceScan.tableIdentifier,
           fileSourceScan.disableBucketedScan
         )
-        if (transformer.doValidate().ok()) {
+        if (transformer.doValidate().isValid) {
          transformer
        } else {
          plan
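The FallbackStrategiesSuite changes below (repeated verbatim for the spark33, spark34, and spark35 shims) exercise the tag round trip: a plan is tagged with TRANSFORM_UNSUPPORTED and the reason is later read back as an Option. A rough sketch of that round trip; the map-backed store is a stand-in, since the real FallbackTags keeps the tag on the TreeNode itself:

object FallbackTagSketch {
  final case class TRANSFORM_UNSUPPORTED(reason: Option[String], appendReason: Boolean = true)

  // Stand-in tag store keyed by plan identity.
  private val tags = scala.collection.mutable.Map.empty[AnyRef, TRANSFORM_UNSUPPORTED]

  def tag(plan: AnyRef, t: TRANSFORM_UNSUPPORTED): Unit = tags(plan) = t
  def getTag(plan: AnyRef): TRANSFORM_UNSUPPORTED = tags(plan)

  def main(args: Array[String]): Unit = {
    val plan = new AnyRef
    tag(plan, TRANSFORM_UNSUPPORTED(Some("fake reason")))
    // Read the reason back as an Option, as the rewritten tests do.
    val reason = getTag(plan).reason
    assert(reason.isDefined && reason.get.contains("fake reason"))
  }
}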
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
index 54d7596b602c5..b9c9d8a270bf2 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.execution.BasicScanExecTransformer
 import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags}
+import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags, TRANSFORM_UNSUPPORTED}
 import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier
 import org.apache.gluten.extension.columnar.transition.InsertTransitions
 import org.apache.gluten.utils.QueryPlanSelector
@@ -124,16 +124,17 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {

   testGluten("Tag not transformable more than once") {
     val originalPlan = UnaryOp1(LeafOp(supportsColumnar = true))
-    FallbackTags.add(originalPlan, "fake reason")
+    FallbackTags.tag(originalPlan, TRANSFORM_UNSUPPORTED(Some("fake reason")))
     val rule = FallbackEmptySchemaRelation()
     val newPlan = rule.apply(originalPlan)
-    val reason = FallbackTags.get(newPlan).reason()
+    val reason = FallbackTags.getTag(newPlan).asInstanceOf[TRANSFORM_UNSUPPORTED].reason
+    assert(reason.isDefined)
     if (BackendsApiManager.getSettings.fallbackOnEmptySchema(newPlan)) {
       assert(
-        reason.contains("fake reason") &&
-          reason.contains("at least one of its children has empty output"))
+        reason.get.contains("fake reason") &&
+          reason.get.contains("at least one of its children has empty output"))
     } else {
-      assert(reason.contains("fake reason"))
+      assert(reason.get.contains("fake reason"))
     }
   }
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala
index 2ee1573ea07ab..fe37da206a561 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala
@@ -35,7 +35,7 @@ case class CustomerColumnarPreRules(session: SparkSession) extends Rule[SparkPla
           fileSourceScan.tableIdentifier,
           fileSourceScan.disableBucketedScan
         )
-        if (transformer.doValidate().ok()) {
+        if (transformer.doValidate().isValid) {
           transformer
         } else {
           plan
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
index 5150a47688519..8ce0af8df051e 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.execution.BasicScanExecTransformer
 import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags}
+import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags, TRANSFORM_UNSUPPORTED}
 import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier
 import org.apache.gluten.extension.columnar.transition.InsertTransitions
 import org.apache.gluten.utils.QueryPlanSelector
@@ -125,16 +125,17 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {

   testGluten("Tag not transformable more than once") {
     val originalPlan = UnaryOp1(LeafOp(supportsColumnar = true))
-    FallbackTags.add(originalPlan, "fake reason")
+    FallbackTags.tag(originalPlan, TRANSFORM_UNSUPPORTED(Some("fake reason")))
     val rule = FallbackEmptySchemaRelation()
     val newPlan = rule.apply(originalPlan)
-    val reason = FallbackTags.get(newPlan).reason()
+    val reason = FallbackTags.getTag(newPlan).asInstanceOf[TRANSFORM_UNSUPPORTED].reason
+    assert(reason.isDefined)
     if (BackendsApiManager.getSettings.fallbackOnEmptySchema(newPlan)) {
       assert(
-        reason.contains("fake reason") &&
-          reason.contains("at least one of its children has empty output"))
+        reason.get.contains("fake reason") &&
+          reason.get.contains("at least one of its children has empty output"))
     } else {
-      assert(reason.contains("fake reason"))
+      assert(reason.get.contains("fake reason"))
     }
   }
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala
index 2ee1573ea07ab..fe37da206a561 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala
@@ -35,7 +35,7 @@ case class CustomerColumnarPreRules(session: SparkSession) extends Rule[SparkPla
           fileSourceScan.tableIdentifier,
           fileSourceScan.disableBucketedScan
         )
-        if (transformer.doValidate().ok()) {
+        if (transformer.doValidate().isValid) {
           transformer
         } else {
           plan
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
index 5150a47688519..8ce0af8df051e 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.execution.BasicScanExecTransformer
 import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags}
+import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags, TRANSFORM_UNSUPPORTED}
 import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier
 import org.apache.gluten.extension.columnar.transition.InsertTransitions
 import org.apache.gluten.utils.QueryPlanSelector
@@ -125,16 +125,17 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {

   testGluten("Tag not transformable more than once") {
     val originalPlan = UnaryOp1(LeafOp(supportsColumnar = true))
-    FallbackTags.add(originalPlan, "fake reason")
+    FallbackTags.tag(originalPlan, TRANSFORM_UNSUPPORTED(Some("fake reason")))
     val rule = FallbackEmptySchemaRelation()
     val newPlan = rule.apply(originalPlan)
-    val reason = FallbackTags.get(newPlan).reason()
+    val reason = FallbackTags.getTag(newPlan).asInstanceOf[TRANSFORM_UNSUPPORTED].reason
+    assert(reason.isDefined)
     if (BackendsApiManager.getSettings.fallbackOnEmptySchema(newPlan)) {
       assert(
-        reason.contains("fake reason") &&
-          reason.contains("at least one of its children has empty output"))
+        reason.get.contains("fake reason") &&
+          reason.get.contains("at least one of its children has empty output"))
     } else {
-      assert(reason.contains("fake reason"))
+      assert(reason.get.contains("fake reason"))
     }
   }
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala
index 2ee1573ea07ab..fe37da206a561 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/CustomerColumnarPreRules.scala
@@ -35,7 +35,7 @@ case class CustomerColumnarPreRules(session: SparkSession) extends Rule[SparkPla
           fileSourceScan.tableIdentifier,
           fileSourceScan.disableBucketedScan
         )
-        if (transformer.doValidate().ok()) {
+        if (transformer.doValidate().isValid) {
           transformer
         } else {
           plan
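Taken together, the rename leaves callers with a small, uniform surface: `ValidationResult.ok` / `ValidationResult.notOk(reason)` on the producing side and `isValid` on the consuming side. A minimal sketch of what such a companion could look like; this is illustrative only, and Gluten's actual ValidationResult may carry additional state:

// Minimal illustrative model of the ValidationResult surface used throughout
// this patch: ok / notOk(reason) to construct, isValid to branch on.
sealed trait ValidationResult {
  def isValid: Boolean
  def reason: Option[String]
}

object ValidationResult {
  private case object Ok extends ValidationResult {
    val isValid = true
    val reason = None
  }
  private final case class NotOk(why: String) extends ValidationResult {
    val isValid = false
    val reason = Some(why)
  }

  def ok: ValidationResult = Ok
  def notOk(reason: String): ValidationResult = NotOk(reason)
}

// Typical call sites, mirroring the hunks above:
//   if (isNullAwareAntiJoin) return ValidationResult.notOk("ch does not support NAAJ")
//   if (transformer.doValidate().isValid) transformer else plan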