diff --git a/gluten-delta/src/main/scala/io/glutenproject/extension/DeltaRewriteTransformerRules.scala b/gluten-delta/src/main/scala/io/glutenproject/extension/DeltaRewriteTransformerRules.scala
index 8e20e5343cba..3e0a07f59608 100644
--- a/gluten-delta/src/main/scala/io/glutenproject/extension/DeltaRewriteTransformerRules.scala
+++ b/gluten-delta/src/main/scala/io/glutenproject/extension/DeltaRewriteTransformerRules.scala
@@ -24,10 +24,9 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.delta.{DeltaParquetFileFormat, NoMapping}
+import org.apache.spark.sql.delta.{DeltaColumnMapping, DeltaParquetFileFormat, NoMapping}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.FileFormat
-import org.apache.spark.sql.types.{StructField, StructType}
 
 import scala.collection._
 
@@ -73,50 +72,44 @@ object DeltaRewriteTransformerRules {
   private def transformColumnMappingPlan(plan: SparkPlan): SparkPlan = plan match {
     case plan: DeltaScanTransformer =>
       val fmt = plan.relation.fileFormat.asInstanceOf[DeltaParquetFileFormat]
-      // a mapping between the table schemas name to parquet schemas.
-      val columnNameMapping = mutable.Map.empty[String, String]
-      fmt.referenceSchema.foreach {
-        f =>
-          val pName = f.metadata.getString("delta.columnMapping.physicalName")
-          val lName = f.name
-          columnNameMapping += (lName -> pName)
-      }
       // transform HadoopFsRelation
       val relation = plan.relation
-      val newDataFields = relation.dataSchema.map(e => e.copy(columnNameMapping(e.name)))
-      val newPartitionFields = relation.partitionSchema.map {
-        e => e.copy(columnNameMapping(e.name))
-      }
       val newFsRelation = relation.copy(
-        partitionSchema = StructType(newPartitionFields),
-        dataSchema = StructType(newDataFields)
+        partitionSchema = DeltaColumnMapping.createPhysicalSchema(
+          relation.partitionSchema,
+          fmt.referenceSchema,
+          fmt.columnMappingMode),
+        dataSchema = DeltaColumnMapping.createPhysicalSchema(
+          relation.dataSchema,
+          fmt.referenceSchema,
+          fmt.columnMappingMode)
       )(SparkSession.active)
-
       // transform output's name into physical name so Reader can read data correctly
       // should keep the columns order the same as the origin output
       val originColumnNames = mutable.ListBuffer.empty[String]
       val transformedAttrs = mutable.ListBuffer.empty[Attribute]
-      val newOutput = plan.output.map {
-        o =>
-          val newAttr = o.withName(columnNameMapping(o.name))
-          if (!originColumnNames.contains(o.name)) {
-            transformedAttrs += newAttr
-            originColumnNames += o.name
-          }
-          newAttr
+      def mapAttribute(attr: Attribute) = {
+        val newAttr = if (!plan.isMetadataColumn(attr)) {
+          DeltaColumnMapping
+            .createPhysicalAttributes(Seq(attr), fmt.referenceSchema, fmt.columnMappingMode)
+            .head
+        } else {
+          attr
+        }
+        if (!originColumnNames.contains(attr.name)) {
+          transformedAttrs += newAttr
+          originColumnNames += attr.name
+        }
+        newAttr
       }
+      val newOutput = plan.output.map(o => mapAttribute(o))
 
       // transform dataFilters
       val newDataFilters = plan.dataFilters.map {
         e =>
           e.transformDown {
             case attr: AttributeReference =>
-              val newAttr = attr.withName(columnNameMapping(attr.name)).toAttribute
-              if (!originColumnNames.contains(attr.name)) {
-                transformedAttrs += newAttr
-                originColumnNames += attr.name
-              }
-              newAttr
+              mapAttribute(attr)
           }
       }
       // transform partitionFilters
@@ -124,22 +117,17 @@ object DeltaRewriteTransformerRules {
         e =>
           e.transformDown {
             case attr: AttributeReference =>
-              val newAttr = attr.withName(columnNameMapping(attr.name)).toAttribute
-              if (!originColumnNames.contains(attr.name)) {
-                transformedAttrs += newAttr
-                originColumnNames += attr.name
-              }
-              newAttr
+              mapAttribute(attr)
           }
       }
       // replace tableName in schema with physicalName
-      val newRequiredFields = plan.requiredSchema.map {
-        e => StructField(columnNameMapping(e.name), e.dataType, e.nullable, e.metadata)
-      }
       val scanExecTransformer = new DeltaScanTransformer(
         newFsRelation,
         newOutput,
-        StructType(newRequiredFields),
+        DeltaColumnMapping.createPhysicalSchema(
+          plan.requiredSchema,
+          fmt.referenceSchema,
+          fmt.columnMappingMode),
         newPartitionFilters,
         plan.optionalBucketSet,
         plan.optionalNumCoalescedBuckets,
diff --git a/gluten-delta/src/test/scala/io/glutenproject/execution/VeloxDeltaSuite.scala b/gluten-delta/src/test/scala/io/glutenproject/execution/VeloxDeltaSuite.scala
index 199b05657eb5..780b14b813c0 100644
--- a/gluten-delta/src/test/scala/io/glutenproject/execution/VeloxDeltaSuite.scala
+++ b/gluten-delta/src/test/scala/io/glutenproject/execution/VeloxDeltaSuite.scala
@@ -18,6 +18,9 @@ package io.glutenproject.execution
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, StringType, StructType}
+
+import scala.collection.JavaConverters._
 
 class VeloxDeltaSuite extends WholeStageTransformerSuite {
 
@@ -130,4 +133,54 @@ class VeloxDeltaSuite extends WholeStageTransformerSuite {
       checkAnswer(df2, Row("v2") :: Nil)
     }
   }
+
+  testWithSpecifiedSparkVersion("column mapping with complex type") {
+    withTable("t1") {
+      val simpleNestedSchema = new StructType()
+        .add("a", StringType, true)
+        .add("b", new StructType().add("c", StringType, true).add("d", IntegerType, true))
+        .add("map", MapType(StringType, StringType), true)
+        .add("arr", ArrayType(IntegerType), true)
+
+      val simpleNestedData = spark.createDataFrame(
+        Seq(
+          Row("str1", Row("str1.1", 1), Map("k1" -> "v1"), Array(1, 11)),
+          Row("str2", Row("str1.2", 2), Map("k2" -> "v2"), Array(2, 22))).asJava,
+        simpleNestedSchema)
+
+      spark.sql(
+        """CREATE TABLE t1
+          | (a STRING, b STRUCT<c: STRING, d: INT>, map MAP<STRING, STRING>, arr ARRAY<INT>)
+          | USING DELTA
+          | PARTITIONED BY (`a`)
+          | TBLPROPERTIES ('delta.columnMapping.mode' = 'name')""".stripMargin)
+
+      simpleNestedData.write.format("delta").mode("append").saveAsTable("t1")
+
+      val df1 = runQueryAndCompare("select * from t1") { _ => }
+      checkAnswer(
+        df1,
+        Seq(
+          Row("str1", Row("str1.1", 1), Map("k1" -> "v1"), Array(1, 11)),
+          Row("str2", Row("str1.2", 2), Map("k2" -> "v2"), Array(2, 22))))
+      spark.sql(s"Alter table t1 RENAME COLUMN b to b1")
+      spark.sql(
+        "insert into t1 " +
+          "values ('str3', struct('str1.3', 3), map('k3', 'v3'), array(3, 33))")
+
+      val df2 = runQueryAndCompare("select b1 from t1") { _ => }
+      checkAnswer(df2, Seq(Row(Row("str1.1", 1)), Row(Row("str1.2", 2)), Row(Row("str1.3", 3))))
+
+      spark.sql(s"Alter table t1 RENAME COLUMN b1.c to c1")
+      val df3 = runQueryAndCompare("select * from t1") { _ => }
+      checkAnswer(
+        df3,
+        Seq(
+          Row("str1", Row("str1.1", 1), Map("k1" -> "v1"), Array(1, 11)),
+          Row("str2", Row("str1.2", 2), Map("k2" -> "v2"), Array(2, 22)),
+          Row("str3", Row("str1.3", 3), Map("k3" -> "v3"), Array(3, 33))
+        )
+      )
+    }
+  }
 }
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
index d380b0bd3e5c..e8e0e9a1c947 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
@@ -65,6 +65,8 @@ class FileSourceScanExecShim(
 
   def hasMetadataColumns: Boolean = false
 
+  def isMetadataColumn(attr: Attribute): Boolean = false
+
   def hasFieldIds: Boolean = false
 
   // The codes below are copied from FileSourceScanExec in Spark,
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
index b8a14701bb90..8134a7b35e9f 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
@@ -66,6 +66,8 @@ class FileSourceScanExecShim(
 
   def hasMetadataColumns: Boolean = metadataColumns.nonEmpty
 
+  def isMetadataColumn(attr: Attribute): Boolean = metadataColumns.contains(attr)
+
   def hasFieldIds: Boolean = ParquetUtils.hasFieldIds(requiredSchema)
 
   // The codes below are copied from FileSourceScanExec in Spark,
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
index 03a1b447559b..2687eb01d5a9 100644
--- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
+++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/FileSourceScanExecShim.scala
@@ -67,6 +67,8 @@ class FileSourceScanExecShim(
 
   def hasMetadataColumns: Boolean = metadataColumns.nonEmpty
 
+  def isMetadataColumn(attr: Attribute): Boolean = metadataColumns.contains(attr)
+
   def hasFieldIds: Boolean = ParquetUtils.hasFieldIds(requiredSchema)
 
   private def isDynamicPruningFilter(e: Expression): Boolean =
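Note (not part of the patch): a minimal, hypothetical sketch of the logical-to-physical renaming that the patch delegates to DeltaColumnMapping.createPhysicalSchema, and that the removed top-level columnNameMapping could not do for nested fields. The names PhysicalNameSketch and renamePhysical are illustrative only; the real Delta helper also handles other column-mapping metadata (such as field IDs) that is omitted here. The metadata key comes from the removed code above.

// Hypothetical illustration only -- not the Delta implementation. It shows why the
// renaming must recurse: with column mapping, every StructField (including fields
// nested inside structs, arrays, and maps) carries its physical name in metadata
// under "delta.columnMapping.physicalName".
import org.apache.spark.sql.types._

object PhysicalNameSketch {
  private val PhysicalNameKey = "delta.columnMapping.physicalName"

  // Rename every field of the schema to its physical name, recursively.
  def renamePhysical(schema: StructType): StructType =
    StructType(schema.fields.map(renameField))

  private def renameField(field: StructField): StructField = {
    val physicalName =
      if (field.metadata.contains(PhysicalNameKey)) field.metadata.getString(PhysicalNameKey)
      else field.name
    field.copy(name = physicalName, dataType = renameType(field.dataType))
  }

  private def renameType(dataType: DataType): DataType = dataType match {
    case s: StructType => renamePhysical(s)
    case a: ArrayType => a.copy(elementType = renameType(a.elementType))
    case m: MapType =>
      m.copy(keyType = renameType(m.keyType), valueType = renameType(m.valueType))
    case other => other
  }
}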