Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CORE] Minor code cleanups against fallback tagging #6320

Merged
merged 9 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -172,20 +172,20 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
format match {
case ParquetReadFormat =>
if (validateFilePath) {
ValidationResult.ok
ValidationResult.succeeded
} else {
ValidationResult.notOk("Validate file path failed.")
ValidationResult.failed("Validate file path failed.")
}
case OrcReadFormat => ValidationResult.ok
case MergeTreeReadFormat => ValidationResult.ok
case OrcReadFormat => ValidationResult.succeeded
case MergeTreeReadFormat => ValidationResult.succeeded
case TextReadFormat =>
if (!hasComplexType) {
ValidationResult.ok
ValidationResult.succeeded
} else {
ValidationResult.notOk("Has complex type.")
ValidationResult.failed("Has complex type.")
}
case JsonReadFormat => ValidationResult.ok
case _ => ValidationResult.notOk(s"Unsupported file format $format")
case JsonReadFormat => ValidationResult.succeeded
case _ => ValidationResult.failed(s"Unsupported file format $format")
}
}

Expand Down Expand Up @@ -291,7 +291,7 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
fields: Array[StructField],
bucketSpec: Option[BucketSpec],
options: Map[String, String]): ValidationResult =
ValidationResult.notOk("CH backend is unsupported.")
ValidationResult.failed("CH backend is unsupported.")

override def enableNativeWriteFiles(): Boolean = {
GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
case _: InnerLike =>
case _ =>
if (joinType == LeftSemi || condition.isDefined) {
return ValidationResult.notOk(
return ValidationResult.failed(
s"Broadcast Nested Loop join is not supported join type $joinType with conditions")
}
}

ValidationResult.ok
ValidationResult.succeeded
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ case class CHGenerateExecTransformer(
override protected def doGeneratorValidate(
generator: Generator,
outer: Boolean): ValidationResult =
ValidationResult.ok
ValidationResult.succeeded

override protected def getRelNode(
context: SubstraitContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ case class CHShuffledHashJoinExecTransformer(
right.outputSet,
condition)
if (shouldFallback) {
return ValidationResult.notOk("ch join validate fail")
return ValidationResult.failed("ch join validate fail")
}
super.doValidateInternal()
}
Expand Down Expand Up @@ -118,10 +118,10 @@ case class CHBroadcastHashJoinExecTransformer(
condition)

if (shouldFallback) {
return ValidationResult.notOk("ch join validate fail")
return ValidationResult.failed("ch join validate fail")
}
if (isNullAwareAntiJoin) {
return ValidationResult.notOk("ch does not support NAAJ")
return ValidationResult.failed("ch does not support NAAJ")
}
super.doValidateInternal()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ case class CHSortMergeJoinExecTransformer(
right.outputSet,
condition)
if (shouldFallback) {
return ValidationResult.notOk("ch join validate fail")
return ValidationResult.failed("ch join validate fail")
}
super.doValidateInternal()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
!columnarConf.enableColumnarBroadcastExchange ||
!columnarConf.enableColumnarBroadcastJoin
) {
ValidationResult.notOk(
ValidationResult.failed(
"columnar broadcast exchange is disabled or " +
"columnar broadcast join is disabled")
} else {
if (FallbackTags.nonEmpty(bhj)) {
ValidationResult.notOk("broadcast join is already tagged as not transformable")
ValidationResult.failed("broadcast join is already tagged as not transformable")
} else {
val bhjTransformer = BackendsApiManager.getSparkPlanExecApiInstance
.genBroadcastHashJoinExecTransformer(
Expand All @@ -75,7 +75,7 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
bhj.right,
bhj.isNullAwareAntiJoin)
val isBhjTransformable = bhjTransformer.doValidate()
if (isBhjTransformable.isValid) {
if (isBhjTransformable.ok()) {
val exchangeTransformer = ColumnarBroadcastExchangeExec(mode, child)
exchangeTransformer.doValidate()
} else {
Expand Down Expand Up @@ -111,12 +111,12 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
!GlutenConfig.getConf.enableColumnarBroadcastExchange ||
!GlutenConfig.getConf.enableColumnarBroadcastJoin
) {
ValidationResult.notOk(
ValidationResult.failed(
"columnar broadcast exchange is disabled or " +
"columnar broadcast join is disabled")
} else {
if (FallbackTags.nonEmpty(bnlj)) {
ValidationResult.notOk("broadcast join is already tagged as not transformable")
ValidationResult.failed("broadcast join is already tagged as not transformable")
} else {
val transformer = BackendsApiManager.getSparkPlanExecApiInstance
.genBroadcastNestedLoopJoinExecTransformer(
Expand All @@ -126,7 +126,7 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
bnlj.joinType,
bnlj.condition)
val isTransformable = transformer.doValidate()
if (isTransformable.isValid) {
if (isTransformable.ok()) {
val exchangeTransformer = ColumnarBroadcastExchangeExec(mode, child)
exchangeTransformer.doValidate()
} else {
Expand Down Expand Up @@ -242,7 +242,7 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl
maybeExchange match {
case Some(exchange @ BroadcastExchangeExec(_, _)) =>
isTransformable.tagOnFallback(plan)
if (!isTransformable.isValid) {
if (!isTransformable.ok) {
FallbackTags.add(exchange, isTransformable)
}
case None =>
Expand Down Expand Up @@ -280,7 +280,7 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl
plan,
"it's a materialized broadcast exchange or reused broadcast exchange")
case ColumnarBroadcastExchangeExec(mode, child) =>
if (!isTransformable.isValid) {
if (!isTransformable.ok) {
throw new IllegalStateException(
s"BroadcastExchange has already been" +
s" transformed to columnar version but BHJ is determined as" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ case class CHColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBase(c
s"${field.dataType} is not supported in ColumnarToRowExecBase.")
}
}
ValidationResult.ok
ValidationResult.succeeded
}

override def doExecuteInternal(): RDD[InternalRow] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ object VeloxBackendSettings extends BackendSettingsApi {
// Collect unsupported types.
val unsupportedDataTypeReason = fields.collect(validatorFunc)
if (unsupportedDataTypeReason.isEmpty) {
ValidationResult.ok
ValidationResult.succeeded
} else {
ValidationResult.notOk(
ValidationResult.failed(
s"Found unsupported data type in $format: ${unsupportedDataTypeReason.mkString(", ")}.")
}
}
Expand Down Expand Up @@ -135,10 +135,10 @@ object VeloxBackendSettings extends BackendSettingsApi {
} else {
validateTypes(parquetTypeValidatorWithComplexTypeFallback)
}
case DwrfReadFormat => ValidationResult.ok
case DwrfReadFormat => ValidationResult.succeeded
case OrcReadFormat =>
if (!GlutenConfig.getConf.veloxOrcScanEnabled) {
ValidationResult.notOk(s"Velox ORC scan is turned off.")
ValidationResult.failed(s"Velox ORC scan is turned off.")
} else {
val typeValidator: PartialFunction[StructField, String] = {
case StructField(_, arrayType: ArrayType, _, _)
Expand All @@ -164,7 +164,7 @@ object VeloxBackendSettings extends BackendSettingsApi {
validateTypes(orcTypeValidatorWithComplexTypeFallback)
}
}
case _ => ValidationResult.notOk(s"Unsupported file format for $format.")
case _ => ValidationResult.failed(s"Unsupported file format for $format.")
}
}

Expand Down Expand Up @@ -284,8 +284,8 @@ object VeloxBackendSettings extends BackendSettingsApi {
.orElse(validateDataTypes())
.orElse(validateWriteFilesOptions())
.orElse(validateBucketSpec()) match {
case Some(reason) => ValidationResult.notOk(reason)
case _ => ValidationResult.ok
case Some(reason) => ValidationResult.failed(reason)
case _ => ValidationResult.succeeded
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ child.output
val projectTransformer = ProjectExecTransformer(projectList, child)
val validationResult = projectTransformer.doValidate()
if (validationResult.isValid) {
if (validationResult.ok()) {
val newChild = maybeAddAppendBatchesExec(projectTransformer)
ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output.drop(1))
} else {
Expand All @@ -392,7 +392,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
val projectTransformer = ProjectExecTransformer(projectList, child)
val projectBeforeSortValidationResult = projectTransformer.doValidate()
// Make sure we support offload hash expression
val projectBeforeSort = if (projectBeforeSortValidationResult.isValid) {
val projectBeforeSort = if (projectBeforeSortValidationResult.ok()) {
projectTransformer
} else {
val project = ProjectExec(projectList, child)
Expand All @@ -405,7 +405,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
val dropSortColumnTransformer =
ProjectExecTransformer(projectList.drop(1), sortByHashCode)
val validationResult = dropSortColumnTransformer.doValidate()
if (validationResult.isValid) {
if (validationResult.ok()) {
val newChild = maybeAddAppendBatchesExec(dropSortColumnTransformer)
ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output)
} else {
Expand Down Expand Up @@ -888,7 +888,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
case p @ LimitTransformer(SortExecTransformer(sortOrder, _, child, _), 0, count) =>
val global = child.outputPartitioning.satisfies(AllTuples)
val topN = TopNTransformer(count, sortOrder, global, child)
if (topN.doValidate().isValid) {
if (topN.doValidate().ok()) {
topN
} else {
p
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ case class GenerateExecTransformer(
generator: Generator,
outer: Boolean): ValidationResult = {
if (!supportsGenerate(generator, outer)) {
ValidationResult.notOk(
ValidationResult.failed(
s"Velox backend does not support this generator: ${generator.getClass.getSimpleName}" +
s", outer: $outer")
} else {
ValidationResult.ok
ValidationResult.succeeded
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBas
s"VeloxColumnarToRowExec.")
}
}
ValidationResult.ok
ValidationResult.succeeded
}

override def doExecuteInternal(): RDD[InternalRow] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.gluten.validate;

import org.apache.gluten.extension.ValidationResult;

import java.util.Vector;

public class NativePlanValidationInfo {
Expand All @@ -30,11 +32,13 @@ public NativePlanValidationInfo(int isSupported, String fallbackInfo) {
}
}

public boolean isSupported() {
return isSupported == 1;
}

public Vector<String> getFallbackInfo() {
return fallbackInfo;
/**
 * Converts this native validation outcome into a {@link ValidationResult}.
 *
 * @return a succeeded result when {@code isSupported == 1}; otherwise a failed result whose
 *     reason is the header "Native validation failed:" followed by every entry of
 *     {@code fallbackInfo}, one per line.
 */
public ValidationResult asResult() {
  if (isSupported == 1) {
    return ValidationResult.succeeded();
  }
  // Join the messages directly. Stream.reduce(BinaryOperator) returns an Optional<String>,
  // which "%s" would render as "Optional[...]" (or "Optional.empty" when there are no
  // messages) instead of the fallback reasons themselves.
  return ValidationResult.failed(
      String.format("Native validation failed: %n%s", String.join("\n", fallbackInfo)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ trait BackendSettingsApi {
format: ReadFileFormat,
fields: Array[StructField],
partTable: Boolean,
paths: Seq[String]): ValidationResult = ValidationResult.ok
paths: Seq[String]): ValidationResult = ValidationResult.succeeded
def supportWriteFilesExec(
format: FileFormat,
fields: Array[StructField],
bucketSpec: Option[BucketSpec],
options: Map[String, String]): ValidationResult = ValidationResult.ok
options: Map[String, String]): ValidationResult = ValidationResult.succeeded
def supportNativeWrite(fields: Array[StructField]): Boolean = true
def supportNativeMetadataColumns(): Boolean = false
def supportNativeRowIndexColumn(): Boolean = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP
if (remainingCondition == null) {
// All the filters can be pushed down and the computing of this Filter
// is not needed.
return ValidationResult.ok
return ValidationResult.succeeded
}
val substraitContext = new SubstraitContext
val operatorId = substraitContext.nextOperatorId(this.nodeName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource

val validationResult = BackendsApiManager.getSettings
.supportFileFormatRead(fileFormat, fields, getPartitionSchema.nonEmpty, getInputFilePaths)
if (!validationResult.isValid) {
if (!validationResult.ok()) {
return validationResult
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,18 +133,18 @@ abstract class BatchScanExecTransformerBase(

override def doValidateInternal(): ValidationResult = {
if (pushedAggregate.nonEmpty) {
return ValidationResult.notOk(s"Unsupported aggregation push down for $scan.")
return ValidationResult.failed(s"Unsupported aggregation push down for $scan.")
}

if (
SparkShimLoader.getSparkShims.findRowIndexColumnIndexInSchema(schema) > 0 &&
!BackendsApiManager.getSettings.supportNativeRowIndexColumn()
) {
return ValidationResult.notOk("Unsupported row index column scan in native.")
return ValidationResult.failed("Unsupported row index column scan in native.")
}

if (hasUnsupportedColumns) {
return ValidationResult.notOk(s"Unsupported columns scan in native.")
return ValidationResult.failed(s"Unsupported columns scan in native.")
}

super.doValidateInternal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,34 +163,35 @@ abstract class BroadcastNestedLoopJoinExecTransformer(

def validateJoinTypeAndBuildSide(): ValidationResult = {
val result = joinType match {
case _: InnerLike | LeftOuter | RightOuter => ValidationResult.ok
case _: InnerLike | LeftOuter | RightOuter => ValidationResult.succeeded
case _ =>
ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin")
ValidationResult.failed(s"$joinType join is not supported with BroadcastNestedLoopJoin")
}

if (!result.isValid) {
if (!result.ok()) {
return result
}

(joinType, buildSide) match {
case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) =>
ValidationResult.notOk(s"$joinType join is not supported with $buildSide")
case _ => ValidationResult.ok // continue
ValidationResult.failed(s"$joinType join is not supported with $buildSide")
case _ => ValidationResult.succeeded // continue
}
}

override protected def doValidateInternal(): ValidationResult = {
if (!GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled) {
return ValidationResult.notOk(
return ValidationResult.failed(
s"Config ${GlutenConfig.BROADCAST_NESTED_LOOP_JOIN_TRANSFORMER_ENABLED.key} not enabled")
}

if (substraitJoinType == CrossRel.JoinType.UNRECOGNIZED) {
return ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin")
return ValidationResult.failed(
s"$joinType join is not supported with BroadcastNestedLoopJoin")
}

val validateResult = validateJoinTypeAndBuildSide()
if (!validateResult.isValid) {
if (!validateResult.ok()) {
return validateResult
}

Expand Down
Loading
Loading