[VL] fix delta column mapping for struct type columns (#4530)
[VL] Fix Delta column mapping for struct type columns: resolve physical Parquet names through DeltaColumnMapping.createPhysicalSchema / createPhysicalAttributes so that fields nested inside struct columns are mapped as well, instead of renaming only top-level columns.
zhli1142015 authored Jan 26, 2024
1 parent fb66d64 commit 15f51e8
Showing 5 changed files with 88 additions and 41 deletions.
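
Background (a sketch for context, not part of the commit): with Delta column mapping enabled, every field, including fields nested inside struct columns, is stored in Parquet under a physical name recorded in the field metadata key delta.columnMapping.physicalName. Before this fix the scan rewrite only renamed top-level columns, so nested struct fields kept their logical names and the native reader could not resolve them. A minimal repro under those assumptions (hypothetical table name t, mirroring the new test below):

// Hypothetical repro; assumes a Spark session with Delta and the Velox backend enabled.
spark.sql(
  """CREATE TABLE t (a STRING, b STRUCT<c: STRING, d: INT>) USING DELTA
    |TBLPROPERTIES ('delta.columnMapping.mode' = 'name')""".stripMargin)
spark.sql("ALTER TABLE t RENAME COLUMN b TO b1")    // logical name changes, physical Parquet name stays
spark.sql("ALTER TABLE t RENAME COLUMN b1.c TO c1") // the nested field also keeps its physical name
// Reading b1.c1 requires mapping both the outer column and the nested field back to
// their physical Parquet names; mapping only top-level names is not enough.
spark.sql("SELECT b1.c1 FROM t").show()
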
@@ -24,10 +24,9 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.delta.{DeltaParquetFileFormat, NoMapping}
+import org.apache.spark.sql.delta.{DeltaColumnMapping, DeltaParquetFileFormat, NoMapping}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.FileFormat
-import org.apache.spark.sql.types.{StructField, StructType}
 
 import scala.collection._
 
@@ -73,73 +72,62 @@ object DeltaRewriteTransformerRules {
   private def transformColumnMappingPlan(plan: SparkPlan): SparkPlan = plan match {
     case plan: DeltaScanTransformer =>
       val fmt = plan.relation.fileFormat.asInstanceOf[DeltaParquetFileFormat]
-      // a mapping between the table schemas name to parquet schemas.
-      val columnNameMapping = mutable.Map.empty[String, String]
-      fmt.referenceSchema.foreach {
-        f =>
-          val pName = f.metadata.getString("delta.columnMapping.physicalName")
-          val lName = f.name
-          columnNameMapping += (lName -> pName)
-      }
 
       // transform HadoopFsRelation
       val relation = plan.relation
-      val newDataFields = relation.dataSchema.map(e => e.copy(columnNameMapping(e.name)))
-      val newPartitionFields = relation.partitionSchema.map {
-        e => e.copy(columnNameMapping(e.name))
-      }
       val newFsRelation = relation.copy(
-        partitionSchema = StructType(newPartitionFields),
-        dataSchema = StructType(newDataFields)
+        partitionSchema = DeltaColumnMapping.createPhysicalSchema(
+          relation.partitionSchema,
+          fmt.referenceSchema,
+          fmt.columnMappingMode),
+        dataSchema = DeltaColumnMapping.createPhysicalSchema(
+          relation.dataSchema,
+          fmt.referenceSchema,
+          fmt.columnMappingMode)
       )(SparkSession.active)
 
       // transform output's name into physical name so Reader can read data correctly
       // should keep the columns order the same as the origin output
       val originColumnNames = mutable.ListBuffer.empty[String]
       val transformedAttrs = mutable.ListBuffer.empty[Attribute]
-      val newOutput = plan.output.map {
-        o =>
-          val newAttr = o.withName(columnNameMapping(o.name))
-          if (!originColumnNames.contains(o.name)) {
-            transformedAttrs += newAttr
-            originColumnNames += o.name
-          }
-          newAttr
+      def mapAttribute(attr: Attribute) = {
+        val newAttr = if (!plan.isMetadataColumn(attr)) {
+          DeltaColumnMapping
+            .createPhysicalAttributes(Seq(attr), fmt.referenceSchema, fmt.columnMappingMode)
+            .head
+        } else {
+          attr
+        }
+        if (!originColumnNames.contains(attr.name)) {
+          transformedAttrs += newAttr
+          originColumnNames += attr.name
+        }
+        newAttr
       }
+      val newOutput = plan.output.map(o => mapAttribute(o))
       // transform dataFilters
       val newDataFilters = plan.dataFilters.map {
        e =>
          e.transformDown {
            case attr: AttributeReference =>
-              val newAttr = attr.withName(columnNameMapping(attr.name)).toAttribute
-              if (!originColumnNames.contains(attr.name)) {
-                transformedAttrs += newAttr
-                originColumnNames += attr.name
-              }
-              newAttr
+              mapAttribute(attr)
          }
       }
       // transform partitionFilters
       val newPartitionFilters = plan.partitionFilters.map {
        e =>
          e.transformDown {
            case attr: AttributeReference =>
-              val newAttr = attr.withName(columnNameMapping(attr.name)).toAttribute
-              if (!originColumnNames.contains(attr.name)) {
-                transformedAttrs += newAttr
-                originColumnNames += attr.name
-              }
-              newAttr
+              mapAttribute(attr)
          }
       }
-      // replace tableName in schema with physicalName
-      val newRequiredFields = plan.requiredSchema.map {
-        e => StructField(columnNameMapping(e.name), e.dataType, e.nullable, e.metadata)
-      }
       val scanExecTransformer = new DeltaScanTransformer(
         newFsRelation,
         newOutput,
-        StructType(newRequiredFields),
+        DeltaColumnMapping.createPhysicalSchema(
+          plan.requiredSchema,
+          fmt.referenceSchema,
+          fmt.columnMappingMode),
         newPartitionFilters,
         plan.optionalBucketSet,
         plan.optionalNumCoalescedBuckets,
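
The removed hand-rolled mapping read delta.columnMapping.physicalName only from the top-level fields of the reference schema, so fields nested inside a struct were left under their logical names. The DeltaColumnMapping helpers used above rewrite the whole schema tree. A rough illustration of the difference (a simplified sketch over a plain StructType, not the helpers' actual implementation):

import org.apache.spark.sql.types.{StructField, StructType}

// Sketch only: rename every field, including fields of nested struct types, to the
// physical name recorded in its metadata when one is present. The real logic lives in
// DeltaColumnMapping.createPhysicalSchema, which also resolves each field against the
// Delta reference schema and honors the configured column mapping mode.
def toPhysicalNames(schema: StructType): StructType = StructType(schema.map { f =>
  val physicalName =
    if (f.metadata.contains("delta.columnMapping.physicalName")) {
      f.metadata.getString("delta.columnMapping.physicalName")
    } else {
      f.name
    }
  val physicalType = f.dataType match {
    case s: StructType => toPhysicalNames(s) // recurse into struct columns
    case other => other // map and array element types are omitted in this sketch
  }
  f.copy(name = physicalName, dataType = physicalType)
})
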
@@ -18,6 +18,9 @@ package io.glutenproject.execution
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, StringType, StructType}
 
+import scala.collection.JavaConverters._
+
 class VeloxDeltaSuite extends WholeStageTransformerSuite {
 
@@ -130,4 +133,54 @@ class VeloxDeltaSuite extends WholeStageTransformerSuite {
       checkAnswer(df2, Row("v2") :: Nil)
     }
   }
+
+  testWithSpecifiedSparkVersion("column mapping with complex type") {
+    withTable("t1") {
+      val simpleNestedSchema = new StructType()
+        .add("a", StringType, true)
+        .add("b", new StructType().add("c", StringType, true).add("d", IntegerType, true))
+        .add("map", MapType(StringType, StringType), true)
+        .add("arr", ArrayType(IntegerType), true)
+
+      val simpleNestedData = spark.createDataFrame(
+        Seq(
+          Row("str1", Row("str1.1", 1), Map("k1" -> "v1"), Array(1, 11)),
+          Row("str2", Row("str1.2", 2), Map("k2" -> "v2"), Array(2, 22))).asJava,
+        simpleNestedSchema)
+
+      spark.sql(
+        """CREATE TABLE t1
+          | (a STRING,b STRUCT<c: STRING NOT NULL, d: INT>,map MAP<STRING, STRING>,arr ARRAY<INT>)
+          | USING DELTA
+          | PARTITIONED BY (`a`)
+          | TBLPROPERTIES ('delta.columnMapping.mode' = 'name')""".stripMargin)
+
+      simpleNestedData.write.format("delta").mode("append").saveAsTable("t1")
+
+      val df1 = runQueryAndCompare("select * from t1") { _ => }
+      checkAnswer(
+        df1,
+        Seq(
+          Row("str1", Row("str1.1", 1), Map("k1" -> "v1"), Array(1, 11)),
+          Row("str2", Row("str1.2", 2), Map("k2" -> "v2"), Array(2, 22))))
+      spark.sql(s"Alter table t1 RENAME COLUMN b to b1")
+      spark.sql(
+        "insert into t1 " +
+          "values ('str3', struct('str1.3', 3), map('k3', 'v3'), array(3, 33))")
+
+      val df2 = runQueryAndCompare("select b1 from t1") { _ => }
+      checkAnswer(df2, Seq(Row(Row("str1.1", 1)), Row(Row("str1.2", 2)), Row(Row("str1.3", 3))))
+
+      spark.sql(s"Alter table t1 RENAME COLUMN b1.c to c1")
+      val df3 = runQueryAndCompare("select * from t1") { _ => }
+      checkAnswer(
+        df3,
+        Seq(
+          Row("str1", Row("str1.1", 1), Map("k1" -> "v1"), Array(1, 11)),
+          Row("str2", Row("str1.2", 2), Map("k2" -> "v2"), Array(2, 22)),
+          Row("str3", Row("str1.3", 3), Map("k3" -> "v3"), Array(3, 33))
+        )
+      )
+    }
+  }
 }
@@ -65,6 +65,8 @@ class FileSourceScanExecShim(
 
   def hasMetadataColumns: Boolean = false
 
+  def isMetadataColumn(attr: Attribute): Boolean = false
+
   def hasFieldIds: Boolean = false
 
   // The codes below are copied from FileSourceScanExec in Spark,
@@ -66,6 +66,8 @@ class FileSourceScanExecShim(
 
   def hasMetadataColumns: Boolean = metadataColumns.nonEmpty
 
+  def isMetadataColumn(attr: Attribute): Boolean = metadataColumns.contains(attr)
+
   def hasFieldIds: Boolean = ParquetUtils.hasFieldIds(requiredSchema)
 
   // The codes below are copied from FileSourceScanExec in Spark,
@@ -67,6 +67,8 @@ class FileSourceScanExecShim(
 
   def hasMetadataColumns: Boolean = metadataColumns.nonEmpty
 
+  def isMetadataColumn(attr: Attribute): Boolean = metadataColumns.contains(attr)
+
   def hasFieldIds: Boolean = ParquetUtils.hasFieldIds(requiredSchema)
 
   private def isDynamicPruningFilter(e: Expression): Boolean =
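
The isMetadataColumn additions to the three FileSourceScanExecShim variants let the rewrite rule above skip metadata columns: such attributes are produced by the scan itself (for example Spark's _metadata column), so they have no physical-name entry in the Delta reference schema and must be passed through unchanged. A tiny self-contained sketch of that guard (hypothetical helper, not the shim's code):

import org.apache.spark.sql.catalyst.expressions.Attribute

// Sketch only: map an output attribute to its physical form unless it is a metadata
// column, in which case it is kept as-is, mirroring the rule's mapAttribute logic.
def mapOutputAttribute(
    attr: Attribute,
    isMetadataColumn: Attribute => Boolean,
    toPhysical: Attribute => Attribute): Attribute =
  if (isMetadataColumn(attr)) attr else toPhysical(attr)
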
