Skip to content

Commit

Permalink
[GLUTEN-7178][VL] Fix field not found error when struct field name co…
Browse files Browse the repository at this point in the history
…ntains upper case (#7304)
  • Loading branch information
zml1206 authored Sep 24, 2024
1 parent f00e453 commit f49fec7
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ object CHBackendSettings extends BackendSettingsApi with Logging {

override def supportStructType(): Boolean = true

// Returning false makes ConverterUtils.normalizeStructFieldName hand struct field names back
// unchanged for this backend instead of routing them through normalizeColName.
override def structFieldToLowerCase(): Boolean = false

override def supportExpandExec(): Boolean = true

override def excludeScanExecFromCollapsedStage(): Boolean =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2094,4 +2094,22 @@ class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa
runQueryAndCompare("select col0 / (col1 + 1E-8) from t") { _ => }
}
}

test("Fix struct field case error") {
  // Run the query with the two predicate push-down optimizer rules excluded.
  // NOTE(review): presumably this keeps the filter above the union so the struct field is
  // resolved there - confirm against the original issue.
  val excludedRules = Seq(
    "org.apache.spark.sql.catalyst.optimizer.PushDownPredicates",
    "org.apache.spark.sql.catalyst.optimizer.PushPredicateThroughNonJoin"
  ).mkString(",")
  withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules) {
    withTempPath {
      dir =>
        val location = dir.getAbsolutePath
        // Persist a struct column whose only field has an upper-case name ('A').
        sql("select named_struct('A', a) as c1 from values (1), (2) as data(a)")
          .write
          .parquet(location)
        // Read the same files twice and union them, then filter and project on the
        // upper-case struct field.
        val firstRead = spark.read.parquet(location)
        val secondRead = spark.read.parquet(location)
        val result = firstRead
          .union(secondRead)
          .filter("c1.A > 1")
          .select("c1.A")
        // Only the row with a = 2 passes the filter, once per side of the union.
        checkAnswer(result, Seq(Row(2), Row(2)))
    }
  }
}
}
6 changes: 6 additions & 0 deletions gluten-arrow/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.apache.gluten</groupId>
<artifactId>gluten-substrait</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>io.glutenproject</groupId>
<artifactId>protobuf-java</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.utils

import org.apache.gluten.expression.ConverterUtils

import org.apache.spark.sql.types._

import org.apache.arrow.vector.complex.MapVector
Expand Down Expand Up @@ -92,9 +94,16 @@ object SparkArrowUtil {
name,
fieldType,
fields
.map(field => toArrowField(field.name, field.dataType, field.nullable, timeZoneId))
.map(
field =>
toArrowField(
ConverterUtils.normalizeStructFieldName(field.name),
field.dataType,
field.nullable,
timeZoneId))
.toSeq
.asJava)
.asJava
)
case MapType(keyType, valueType, valueContainsNull) =>
val mapType = new FieldType(nullable, new ArrowType.Map(false), null)
// Note: Map Type struct can not be null, Struct Type key field can not be null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ trait BackendSettingsApi {
}
def supportStructType(): Boolean = false

// Consulted by ConverterUtils.normalizeStructFieldName: when true (the default here) struct
// field names are passed through normalizeColName; when false they are used as-is.
def structFieldToLowerCase(): Boolean = true

// Whether to fallback aggregate at the same time if its empty-output child is fallen back.
def fallbackAggregateWithEmptyOutputChild(): Boolean = false

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ object ConverterUtils extends Logging {
if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
}

/**
 * Normalizes a struct field name according to the active backend's settings.
 *
 * If the backend reports `structFieldToLowerCase()` as true, the name is delegated to
 * `normalizeColName`; otherwise it is returned untouched.
 *
 * @param name
 *   the struct field name as it appears in the Spark schema
 * @return
 *   the name to use for the struct field
 */
def normalizeStructFieldName(name: String): String = {
  val lowerCaseStructFields = BackendsApiManager.getSettings.structFieldToLowerCase()
  if (!lowerCaseStructFields) name else normalizeColName(name)
}

def getShortAttributeName(attr: Attribute): String = {
val name = normalizeColName(attr.name)
val subIndex = name.indexOf("(")
Expand Down Expand Up @@ -259,7 +267,7 @@ object ConverterUtils extends Logging {
val fieldNames = new JArrayList[String]
for (structField <- s.fields) {
fieldNodes.add(getTypeNode(structField.dataType, structField.nullable))
fieldNames.add(structField.name)
fieldNames.add(normalizeStructFieldName(structField.name))
}
TypeBuilder.makeStruct(nullable, fieldNodes, fieldNames)
case _: NullType =>
Expand Down

0 comments on commit f49fec7

Please sign in to comment.