From 0a40b2433f8314115fd3c713cbb044e684b57b6b Mon Sep 17 00:00:00 2001
From: Zouxxyy
Date: Tue, 7 May 2024 10:26:56 +0800
Subject: [PATCH] [CORE] Only return partition columns that need to be read
 for Iceberg (#5624)

---
 .../execution/IcebergScanTransformer.scala    |  4 +--
 .../source/GlutenIcebergSourceUtil.scala      | 30 ++++++++++++++-----
 2 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala b/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
index 9bb33678a9df..6e079bf7e10a 100644
--- a/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
+++ b/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
@@ -48,7 +48,7 @@ case class IcebergScanTransformer(
 
   override def filterExprs(): Seq[Expression] = pushdownFilters.getOrElse(Seq.empty)
 
-  override def getPartitionSchema: StructType = GlutenIcebergSourceUtil.getPartitionSchema(scan)
+  override def getPartitionSchema: StructType = GlutenIcebergSourceUtil.getReadPartitionSchema(scan)
 
   override def getDataSchema: StructType = new StructType()
 
@@ -63,7 +63,7 @@ case class IcebergScanTransformer(
       filteredPartitions,
       outputPartitioning)
     groupedPartitions.zipWithIndex.map {
-      case (p, index) => GlutenIcebergSourceUtil.genSplitInfo(p, index)
+      case (p, index) => GlutenIcebergSourceUtil.genSplitInfo(p, index, getPartitionSchema)
     }
   }
 
diff --git a/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala b/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala
index 2b4f54aef141..6b67e763648b 100644
--- a/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala
+++ b/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala
@@ -34,7 +34,10 @@ import scala.collection.JavaConverters._
 
 object GlutenIcebergSourceUtil {
 
-  def genSplitInfo(inputPartition: InputPartition, index: Int): SplitInfo = inputPartition match {
+  def genSplitInfo(
+      inputPartition: InputPartition,
+      index: Int,
+      readPartitionSchema: StructType): SplitInfo = inputPartition match {
     case partition: SparkInputPartition =>
       val paths = new JArrayList[String]()
       val starts = new JArrayList[JLong]()
@@ -50,8 +53,8 @@ object GlutenIcebergSourceUtil {
           paths.add(filePath)
           starts.add(task.start())
           lengths.add(task.length())
-          partitionColumns.add(getPartitionColumns(task))
-          deleteFilesList.add(task.deletes());
+          partitionColumns.add(getPartitionColumns(task, readPartitionSchema))
+          deleteFilesList.add(task.deletes())
           val currentFileFormat = convertFileFormat(task.file().format())
           if (fileFormat == ReadFileFormat.UnknownFormat) {
             fileFormat = currentFileFormat
@@ -94,7 +97,7 @@ object GlutenIcebergSourceUtil {
       throw new GlutenNotSupportException("Only support iceberg SparkBatchQueryScan.")
   }
 
-  def getPartitionSchema(sparkScan: Scan): StructType = sparkScan match {
+  def getReadPartitionSchema(sparkScan: Scan): StructType = sparkScan match {
     case scan: SparkBatchQueryScan =>
       val tasks = scan.tasks().asScala
       asFileScanTask(tasks.toList).foreach {
@@ -102,7 +105,16 @@ object GlutenIcebergSourceUtil {
           val spec = task.spec()
           if (spec.isPartitioned) {
             var partitionSchema = new StructType()
-            val partitionFields = spec.partitionType().fields().asScala
+            val readFields = scan.readSchema().fields.map(_.name).toSet
+            // Iceberg generates some non-table fields as partition fields, such as x_bucket.
+            // These never appear in readFields, so they must not be filtered out.
+            val tableFields = spec.schema().columns().asScala.map(_.name()).toSet
+            val partitionFields =
+              spec
+                .partitionType()
+                .fields()
+                .asScala
+                .filter(f => !tableFields.contains(f.name()) || readFields.contains(f.name()))
             partitionFields.foreach {
               field =>
                 TypeUtil.validatePartitionColumnType(field.`type`().typeId())
@@ -130,12 +142,16 @@ object GlutenIcebergSourceUtil {
     }
   }
 
-  private def getPartitionColumns(task: FileScanTask): JHashMap[String, String] = {
+  private def getPartitionColumns(
+      task: FileScanTask,
+      readPartitionSchema: StructType): JHashMap[String, String] = {
     val partitionColumns = new JHashMap[String, String]()
+    val readPartitionFields = readPartitionSchema.fields.map(_.name).toSet
     val spec = task.spec()
     val partition = task.partition()
     if (spec.isPartitioned) {
-      val partitionFields = spec.partitionType().fields().asScala
+      val partitionFields =
+        spec.partitionType().fields().asScala.filter(f => readPartitionFields.contains(f.name()))
       partitionFields.zipWithIndex.foreach {
        case (field, index) =>
          val partitionValue = partition.get(index, field.`type`().typeId().javaClass())
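
Note on the core rule above: getReadPartitionSchema keeps a partition field either when it is a synthetic hidden-partition field (e.g. x_bucket produced by a bucket(4, x) transform), which is not a table column and therefore can never appear in the read schema, or when the scan's read schema actually contains it. Below is a minimal standalone Scala sketch of that rule, restated over plain collections instead of Iceberg's PartitionSpec API; the object name, field names, and sample data are purely illustrative.

// Sketch only, not part of the patch: the partition-field filtering rule,
// expressed with plain Scala collections. All names and data are hypothetical.
object PartitionFieldFilterSketch {
  // Partition fields declared by the spec. A hidden-partition transform such
  // as bucket(4, x) surfaces as a synthetic field name like "x_bucket".
  val partitionFields: Seq[String] = Seq("dt", "x_bucket", "region")

  // Columns of the table schema backing the spec.
  val tableFields: Set[String] = Set("dt", "region", "x", "value")

  // Columns the current scan actually reads.
  val readFields: Set[String] = Set("dt", "x", "value")

  // Keep a partition field if it is synthetic (not a table column, so it can
  // never appear in readFields) or if the scan reads it; drop the rest.
  val kept: Seq[String] =
    partitionFields.filter(f => !tableFields.contains(f) || readFields.contains(f))

  def main(args: Array[String]): Unit =
    println(kept) // List(dt, x_bucket) -- "region" is dropped because it is not read
}

getPartitionColumns then applies the same membership test per FileScanTask, so only the values of the kept partition fields are handed to the native reader.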