disable complex type fallback for parquet
yma11 committed Aug 14, 2024
1 parent fc7f9cd commit 398fafa
Showing 6 changed files with 75 additions and 84 deletions.
@@ -99,56 +99,8 @@ object VeloxBackendSettings extends BackendSettingsApi {
}
}

-  val parquetTypeValidatorWithComplexTypeFallback: PartialFunction[StructField, String] = {
-    case StructField(_, arrayType: ArrayType, _, _) =>
-      arrayType.simpleString + " is forced to fallback."
-    case StructField(_, mapType: MapType, _, _) =>
-      mapType.simpleString + " is forced to fallback."
-    case StructField(_, structType: StructType, _, _) =>
-      structType.simpleString + " is forced to fallback."
-    case StructField(_, timestampType: TimestampType, _, _)
-        if GlutenConfig.getConf.forceParquetTimestampTypeScanFallbackEnabled =>
-      timestampType.simpleString + " is forced to fallback."
-  }
-  val orcTypeValidatorWithComplexTypeFallback: PartialFunction[StructField, String] = {
-    case StructField(_, arrayType: ArrayType, _, _) =>
-      arrayType.simpleString + " is forced to fallback."
-    case StructField(_, mapType: MapType, _, _) =>
-      mapType.simpleString + " is forced to fallback."
-    case StructField(_, structType: StructType, _, _) =>
-      structType.simpleString + " is forced to fallback."
-    case StructField(_, stringType: StringType, _, metadata)
-        if isCharType(stringType, metadata) =>
-      CharVarcharUtils.getRawTypeString(metadata) + " not support"
-    case StructField(_, TimestampType, _, _) => "TimestampType not support"
-  }
format match {
-  case ParquetReadFormat =>
-    val typeValidator: PartialFunction[StructField, String] = {
-      // Parquet scan of nested array with struct/array as element type is unsupported in Velox.
-      case StructField(_, arrayType: ArrayType, _, _)
-          if arrayType.elementType.isInstanceOf[StructType] =>
-        "StructType as element in ArrayType"
-      case StructField(_, arrayType: ArrayType, _, _)
-          if arrayType.elementType.isInstanceOf[ArrayType] =>
-        "ArrayType as element in ArrayType"
-      // Parquet scan of nested map with struct as key type,
-      // or array type as value type is not supported in Velox.
-      case StructField(_, mapType: MapType, _, _) if mapType.keyType.isInstanceOf[StructType] =>
-        "StructType as Key in MapType"
-      case StructField(_, mapType: MapType, _, _)
-          if mapType.valueType.isInstanceOf[ArrayType] =>
-        "ArrayType as Value in MapType"
-      case StructField(_, TimestampType, _, _)
-          if GlutenConfig.getConf.forceParquetTimestampTypeScanFallbackEnabled =>
-        "TimestampType"
-    }
-    if (!GlutenConfig.getConf.forceComplexTypeScanFallbackEnabled) {
-      validateTypes(typeValidator)
-    } else {
-      validateTypes(parquetTypeValidatorWithComplexTypeFallback)
-    }
-  case DwrfReadFormat => ValidationResult.succeeded
+  case ParquetReadFormat | DwrfReadFormat => ValidationResult.succeeded
case OrcReadFormat =>
if (!GlutenConfig.getConf.veloxOrcScanEnabled) {
ValidationResult.failed(s"Velox ORC scan is turned off.")
@@ -171,11 +123,7 @@ object VeloxBackendSettings extends BackendSettingsApi {
CharVarcharUtils.getRawTypeString(metadata) + " not support"
case StructField(_, TimestampType, _, _) => "TimestampType not support"
}
-    if (!GlutenConfig.getConf.forceComplexTypeScanFallbackEnabled) {
-      validateTypes(typeValidator)
-    } else {
-      validateTypes(orcTypeValidatorWithComplexTypeFallback)
-    }
+    validateTypes(typeValidator)
}
case _ => ValidationResult.failed(s"Unsupported file format for $format.")
}
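A note on the pattern above: each validator is a plain `PartialFunction[StructField, String]`, where being defined at a field means "fall back, for this reason". A minimal, self-contained sketch of how such a validator can be applied (the `validateTypes` helper and the schema here are hypothetical stand-ins, not Gluten's implementation):

```scala
import org.apache.spark.sql.types._

object ValidatorSketch extends App {
  // Hypothetical stand-in for validateTypes: collect a failure reason for
  // every field the partial function is defined at; an empty result means
  // the scan is considered valid.
  def validateTypes(
      schema: StructType,
      validator: PartialFunction[StructField, String]): Seq[String] =
    schema.fields.toSeq.collect(validator)

  // Mirrors the shape of the ORC validator above.
  val validator: PartialFunction[StructField, String] = {
    case StructField(_, TimestampType, _, _) => "TimestampType not support"
  }

  val schema = StructType(Seq(
    StructField("id", IntegerType),
    StructField("ts", TimestampType)))

  println(validateTypes(schema, validator)) // List(TimestampType not support)
}
```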
@@ -724,7 +724,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
childTransformer: ExpressionTransformer,
ordinal: Int,
original: GetStructField): ExpressionTransformer = {
-    VeloxGetStructFieldTransformer(substraitExprName, childTransformer, original)
+    VeloxGetStructFieldTransformer(substraitExprName, childTransformer, ordinal, original)
}

/**
@@ -54,19 +54,20 @@ case class VeloxNamedStructTransformer(
case class VeloxGetStructFieldTransformer(
substraitExprName: String,
child: ExpressionTransformer,
+    ordinal: Int,
original: GetStructField)
extends UnaryExpressionTransformer {
override def doTransform(args: Object): ExpressionNode = {
val childNode = child.doTransform(args)
childNode match {
case node: StructLiteralNode =>
-        node.getFieldLiteral(original.ordinal)
+        node.getFieldLiteral(ordinal)
case node: SelectionNode =>
        // Append the nested index to the selection node.
-        node.addNestedChildIdx(JInteger.valueOf(original.ordinal))
+        node.addNestedChildIdx(JInteger.valueOf(ordinal))
case node: NullLiteralNode =>
val nodeType =
-          node.getTypeNode.asInstanceOf[StructNode].getFieldTypes.get(original.ordinal)
+          node.getTypeNode.asInstanceOf[StructNode].getFieldTypes.get(ordinal)
ExpressionBuilder.makeNullLiteral(nodeType)
case other =>
throw new GlutenNotSupportException(s"$other is not supported.")
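Why `ordinal` is now threaded through explicitly instead of reading `original.ordinal`: after Spark's nested schema pruning, the ordinal recorded in the analyzed `GetStructField` can still refer to the unpruned struct layout, while the Substrait `SelectionNode` needs the index within the struct the scan actually produces. A small illustration of that shift, using only the plain Spark type API on assumed schemas:

```scala
import org.apache.spark.sql.types._

object OrdinalShiftSketch extends App {
  // Assumed file schema of column s: struct<a:int, b:int, c:int>; the query
  // reads only s.c, so the analyzed GetStructField carries ordinal 2.
  val fullStruct = StructType(Seq(
    StructField("a", IntegerType),
    StructField("b", IntegerType),
    StructField("c", IntegerType)))

  // After nested schema pruning the scan exposes struct<c:int>, where c sits
  // at index 0, which is the ordinal the Substrait selection must carry.
  val prunedStruct = StructType(Seq(StructField("c", IntegerType)))

  println(fullStruct.fieldIndex("c"))   // 2 (pre-pruning ordinal)
  println(prunedStruct.fieldIndex("c")) // 0 (post-pruning ordinal)
}
```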
@@ -1713,7 +1713,7 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper {

sql("CREATE TABLE t2(id INT, l ARRAY<STRUCT<k: INT, v: INT>>) USING PARQUET")
sql("INSERT INTO t2 VALUES(1, ARRAY(STRUCT(1, 100))), (2, ARRAY(STRUCT(2, 200)))")
runQueryAndCompare("SELECT first(l) FROM t2")(df => checkFallbackOperators(df, 1))
runQueryAndCompare("SELECT first(l) FROM t2")(df => checkFallbackOperators(df, 0))
}
}
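With the forced complex-type fallback removed, the expected fallback-operator count for this query drops from 1 to 0: the Parquet scan of `ARRAY<STRUCT<k: INT, v: INT>>` is now expected to stay native. Inside the same suite one could equally assert the positive form, reusing the `getExecutedPlan` and `BatchScanExecTransformer` helpers that appear in the test file below (a sketch, not part of this commit):

```scala
// Sketch only: zero fallback operators implies the scan stays in the native
// BatchScanExecTransformer instead of falling back to a vanilla Spark scan.
val df = spark.sql("SELECT first(l) FROM t2")
assert(getExecutedPlan(df).exists(_.isInstanceOf[BatchScanExecTransformer]))
```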

@@ -427,22 +427,6 @@ class VeloxParquetDataTypeValidationSuite extends VeloxWholeStageTransformerSuite {
}
}

test("Force complex type scan fallback") {
withSQLConf(("spark.gluten.sql.complexType.scan.fallback.enabled", "true")) {
val df = spark.sql("select struct from type1")
val executedPlan = getExecutedPlan(df)
assert(!executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
}
}

test("Force timestamp type scan fallback") {
withSQLConf(("spark.gluten.sql.parquet.timestampType.scan.fallback.enabled", "true")) {
val df = spark.sql("select timestamp from type1")
val executedPlan = getExecutedPlan(df)
assert(!executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
}
}

test("Decimal type") {
// Validation: BatchScan Project Aggregate Expand Sort Limit
runQueryAndCompare(
@@ -33,6 +33,8 @@ import org.apache.spark.sql.hive.HiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

+import scala.collection.mutable.ArrayBuffer
+
trait Transformable {
def getTransformer(childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer
}
@@ -349,15 +351,26 @@ object ExpressionConverter extends SQLConfHelper with Logging {
expr => replaceWithExpressionTransformerInternal(expr, attributeSeq, expressionsMap)),
m)
case getStructField: GetStructField =>
-      // Different backends may have different results.
-      BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer(
-        substraitExprName,
-        replaceWithExpressionTransformerInternal(
-          getStructField.child,
-          attributeSeq,
-          expressionsMap),
-        getStructField.ordinal,
-        getStructField)
+      try {
+        val bindRef =
+          bindGetStructField(getStructField, attributeSeq)
+        // Different backends may have different results.
+        BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer(
+          substraitExprName,
+          replaceWithExpressionTransformerInternal(
+            getStructField.child,
+            attributeSeq,
+            expressionsMap),
+          bindRef.ordinal,
+          getStructField)
+      } catch {
+        case e: IllegalStateException =>
+          // This situation may need developers to fix, but for now we throw the
+          // exception below so that the corresponding operator falls back.
+          throw new UnsupportedOperationException(
+            s"Failed to bind reference for $getStructField: ${e.getMessage}")
+      }
+
case getArrayStructFields: GetArrayStructFields =>
GenericExpressionTransformer(
substraitExprName,
@@ -729,4 +742,49 @@
}
substraitExprName
}

+  private def bindGetStructField(
+      structField: GetStructField,
+      input: AttributeSeq): BoundReference = {
+    // Compute the field's new ordinal based on the input attributes.
+    var newOrdinal: Int = -1
+    val names = new ArrayBuffer[String]
+    var root: Expression = structField
+    while (root.isInstanceOf[GetStructField]) {
+      val curField = root.asInstanceOf[GetStructField]
+      val name = curField.childSchema.fields(curField.ordinal).name
+      names += name
+      root = root.asInstanceOf[GetStructField].child
+    }
+    // For map/array types, the reference is correct whether
+    // NESTED_SCHEMA_PRUNING_ENABLED is set or not.
+    if (!root.isInstanceOf[AttributeReference]) {
+      return BoundReference(structField.ordinal, structField.dataType, structField.nullable)
+    }
+    names += root.asInstanceOf[AttributeReference].name
+    input.attrs.foreach(
+      attribute => {
+        var level = names.size - 1
+        if (names(level) == attribute.name) {
+          var candidateFields: Array[StructField] = null
+          var dtType = attribute.dataType
+          while (dtType.isInstanceOf[StructType] && level >= 1) {
+            candidateFields = dtType.asInstanceOf[StructType].fields
+            level -= 1
+            val curName = names(level)
+            for (i <- 0 until candidateFields.length) {
+              if (candidateFields(i).name == curName) {
+                dtType = candidateFields(i).dataType
+                newOrdinal = i
+              }
+            }
+          }
+        }
+      })
+    if (newOrdinal == -1) {
+      throw new IllegalStateException(
+        s"Couldn't find $structField in ${input.attrs.mkString("[", ",", "]")}")
+    } else {
+      BoundReference(newOrdinal, structField.dataType, structField.nullable)
+    }
+  }
}
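To make the rebinding concrete: the `while` loop collects the field-name chain bottom-up (`s.inner.c` yields `c`, `inner`, then the attribute name `s`), and the `foreach` walks the possibly pruned attribute schema top-down, recomputing the ordinal level by level. A standalone sketch of that walk under assumed schemas (illustration only, not the Gluten code path):

```scala
import org.apache.spark.sql.types._

object RebindSketch extends App {
  // Name chain for s.inner.c, already reversed to top-down order.
  val names = Seq("s", "inner", "c")

  // Assumed pruned scan output: every struct level keeps only the referenced
  // fields, so c lands at ordinal 0 regardless of its position in the full
  // file schema.
  val scanSchema = StructType(Seq(
    StructField("s", StructType(Seq(
      StructField("inner", StructType(Seq(
        StructField("c", IntegerType)))))))))

  // Walk the struct levels top-down, tracking the last matched ordinal,
  // mirroring the nested loop in bindGetStructField above.
  var dt: DataType = scanSchema
  var ordinal = -1
  names.foreach { n =>
    val fields = dt.asInstanceOf[StructType].fields
    ordinal = fields.indexWhere(_.name == n)
    dt = fields(ordinal).dataType
  }
  println(ordinal) // 0: the ordinal the rebound BoundReference must use
}
```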
