From 398fafa2b4127d0a7a056c4754252027d7589fd5 Mon Sep 17 00:00:00 2001 From: yan ma Date: Mon, 5 Aug 2024 20:56:41 +0800 Subject: [PATCH] disable complex type fallback for parquet --- .../backendsapi/velox/VeloxBackend.scala | 56 +------------- .../velox/VeloxSparkPlanExecApi.scala | 2 +- .../expression/ExpressionTransformer.scala | 7 +- .../gluten/execution/TestOperator.scala | 2 +- .../VeloxParquetDataTypeValidationSuite.scala | 16 ---- .../expression/ExpressionConverter.scala | 76 ++++++++++++++++--- 6 files changed, 75 insertions(+), 84 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index d32911f4a4c76..ccc924a58670a 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -99,56 +99,8 @@ object VeloxBackendSettings extends BackendSettingsApi { } } - val parquetTypeValidatorWithComplexTypeFallback: PartialFunction[StructField, String] = { - case StructField(_, arrayType: ArrayType, _, _) => - arrayType.simpleString + " is forced to fallback." - case StructField(_, mapType: MapType, _, _) => - mapType.simpleString + " is forced to fallback." - case StructField(_, structType: StructType, _, _) => - structType.simpleString + " is forced to fallback." - case StructField(_, timestampType: TimestampType, _, _) - if GlutenConfig.getConf.forceParquetTimestampTypeScanFallbackEnabled => - timestampType.simpleString + " is forced to fallback." - } - val orcTypeValidatorWithComplexTypeFallback: PartialFunction[StructField, String] = { - case StructField(_, arrayType: ArrayType, _, _) => - arrayType.simpleString + " is forced to fallback." - case StructField(_, mapType: MapType, _, _) => - mapType.simpleString + " is forced to fallback." - case StructField(_, structType: StructType, _, _) => - structType.simpleString + " is forced to fallback." - case StructField(_, stringType: StringType, _, metadata) - if isCharType(stringType, metadata) => - CharVarcharUtils.getRawTypeString(metadata) + " not support" - case StructField(_, TimestampType, _, _) => "TimestampType not support" - } format match { - case ParquetReadFormat => - val typeValidator: PartialFunction[StructField, String] = { - // Parquet scan of nested array with struct/array as element type is unsupported in Velox. - case StructField(_, arrayType: ArrayType, _, _) - if arrayType.elementType.isInstanceOf[StructType] => - "StructType as element in ArrayType" - case StructField(_, arrayType: ArrayType, _, _) - if arrayType.elementType.isInstanceOf[ArrayType] => - "ArrayType as element in ArrayType" - // Parquet scan of nested map with struct as key type, - // or array type as value type is not supported in Velox. - case StructField(_, mapType: MapType, _, _) if mapType.keyType.isInstanceOf[StructType] => - "StructType as Key in MapType" - case StructField(_, mapType: MapType, _, _) - if mapType.valueType.isInstanceOf[ArrayType] => - "ArrayType as Value in MapType" - case StructField(_, TimestampType, _, _) - if GlutenConfig.getConf.forceParquetTimestampTypeScanFallbackEnabled => - "TimestampType" - } - if (!GlutenConfig.getConf.forceComplexTypeScanFallbackEnabled) { - validateTypes(typeValidator) - } else { - validateTypes(parquetTypeValidatorWithComplexTypeFallback) - } - case DwrfReadFormat => ValidationResult.succeeded + case ParquetReadFormat | DwrfReadFormat => ValidationResult.succeeded case OrcReadFormat => if (!GlutenConfig.getConf.veloxOrcScanEnabled) { ValidationResult.failed(s"Velox ORC scan is turned off.") @@ -171,11 +123,7 @@ object VeloxBackendSettings extends BackendSettingsApi { CharVarcharUtils.getRawTypeString(metadata) + " not support" case StructField(_, TimestampType, _, _) => "TimestampType not support" } - if (!GlutenConfig.getConf.forceComplexTypeScanFallbackEnabled) { - validateTypes(typeValidator) - } else { - validateTypes(orcTypeValidatorWithComplexTypeFallback) - } + validateTypes(typeValidator) } case _ => ValidationResult.failed(s"Unsupported file format for $format.") } diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 89d29781c0799..da2e69ffdfe27 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -724,7 +724,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { childTransformer: ExpressionTransformer, ordinal: Int, original: GetStructField): ExpressionTransformer = { - VeloxGetStructFieldTransformer(substraitExprName, childTransformer, original) + VeloxGetStructFieldTransformer(substraitExprName, childTransformer, ordinal, original) } /** diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala index 51b19ab140d9a..b668c5879ac94 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala @@ -54,19 +54,20 @@ case class VeloxNamedStructTransformer( case class VeloxGetStructFieldTransformer( substraitExprName: String, child: ExpressionTransformer, + ordinal: Int, original: GetStructField) extends UnaryExpressionTransformer { override def doTransform(args: Object): ExpressionNode = { val childNode = child.doTransform(args) childNode match { case node: StructLiteralNode => - node.getFieldLiteral(original.ordinal) + node.getFieldLiteral(ordinal) case node: SelectionNode => // Append the nested index to selection node. - node.addNestedChildIdx(JInteger.valueOf(original.ordinal)) + node.addNestedChildIdx(JInteger.valueOf(ordinal)) case node: NullLiteralNode => val nodeType = - node.getTypeNode.asInstanceOf[StructNode].getFieldTypes.get(original.ordinal) + node.getTypeNode.asInstanceOf[StructNode].getFieldTypes.get(ordinal) ExpressionBuilder.makeNullLiteral(nodeType) case other => throw new GlutenNotSupportException(s"$other is not supported.") diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala index a0ea7d7267b46..b37a7aa98c883 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala @@ -1713,7 +1713,7 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla sql("CREATE TABLE t2(id INT, l ARRAY>) USING PARQUET") sql("INSERT INTO t2 VALUES(1, ARRAY(STRUCT(1, 100))), (2, ARRAY(STRUCT(2, 200)))") - runQueryAndCompare("SELECT first(l) FROM t2")(df => checkFallbackOperators(df, 1)) + runQueryAndCompare("SELECT first(l) FROM t2")(df => checkFallbackOperators(df, 0)) } } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala index 85b3f32a76842..8b6cc63c954d2 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala @@ -427,22 +427,6 @@ class VeloxParquetDataTypeValidationSuite extends VeloxWholeStageTransformerSuit } } - test("Force complex type scan fallback") { - withSQLConf(("spark.gluten.sql.complexType.scan.fallback.enabled", "true")) { - val df = spark.sql("select struct from type1") - val executedPlan = getExecutedPlan(df) - assert(!executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer])) - } - } - - test("Force timestamp type scan fallback") { - withSQLConf(("spark.gluten.sql.parquet.timestampType.scan.fallback.enabled", "true")) { - val df = spark.sql("select timestamp from type1") - val executedPlan = getExecutedPlan(df) - assert(!executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer])) - } - } - test("Decimal type") { // Validation: BatchScan Project Aggregate Expand Sort Limit runQueryAndCompare( diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 3ca66b51897b0..a986183c983cf 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -33,6 +33,8 @@ import org.apache.spark.sql.hive.HiveUDFTransformer import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import scala.collection.mutable.ArrayBuffer + trait Transformable { def getTransformer(childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer } @@ -349,15 +351,26 @@ object ExpressionConverter extends SQLConfHelper with Logging { expr => replaceWithExpressionTransformerInternal(expr, attributeSeq, expressionsMap)), m) case getStructField: GetStructField => - // Different backends may have different result. - BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer( - substraitExprName, - replaceWithExpressionTransformerInternal( - getStructField.child, - attributeSeq, - expressionsMap), - getStructField.ordinal, - getStructField) + try { + val bindRef = + bindGetStructField(getStructField, attributeSeq) + // Different backends may have different result. + BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer( + substraitExprName, + replaceWithExpressionTransformerInternal( + getStructField.child, + attributeSeq, + expressionsMap), + bindRef.ordinal, + getStructField) + } catch { + case e: IllegalStateException => + // This situation may need developers to fix, although we just throw the below + // exception to let the corresponding operator fall back. + throw new UnsupportedOperationException( + s"Failed to bind reference for $getStructField: ${e.getMessage}") + } + case getArrayStructFields: GetArrayStructFields => GenericExpressionTransformer( substraitExprName, @@ -729,4 +742,49 @@ object ExpressionConverter extends SQLConfHelper with Logging { } substraitExprName } + + private def bindGetStructField( + structField: GetStructField, + input: AttributeSeq): BoundReference = { + // get the new ordinal base input + var newOrdinal: Int = -1 + val names = new ArrayBuffer[String] + var root: Expression = structField + while (root.isInstanceOf[GetStructField]) { + val curField = root.asInstanceOf[GetStructField] + val name = curField.childSchema.fields(curField.ordinal).name + names += name + root = root.asInstanceOf[GetStructField].child + } + // For map/array type, the reference is correct no matter NESTED_SCHEMA_PRUNING_ENABLED or not + if (!root.isInstanceOf[AttributeReference]) { + return BoundReference(structField.ordinal, structField.dataType, structField.nullable) + } + names += root.asInstanceOf[AttributeReference].name + input.attrs.foreach( + attribute => { + var level = names.size - 1 + if (names(level) == attribute.name) { + var candidateFields: Array[StructField] = null + var dtType = attribute.dataType + while (dtType.isInstanceOf[StructType] && level >= 1) { + candidateFields = dtType.asInstanceOf[StructType].fields + level -= 1 + val curName = names(level) + for (i <- 0 until candidateFields.length) { + if (candidateFields(i).name == curName) { + dtType = candidateFields(i).dataType + newOrdinal = i + } + } + } + } + }) + if (newOrdinal == -1) { + throw new IllegalStateException( + s"Couldn't find $structField in ${input.attrs.mkString("[", ",", "]")}") + } else { + BoundReference(newOrdinal, structField.dataType, structField.nullable) + } + } }