From 119062f4be0c4f423399829fea9fbd062f4e80a7 Mon Sep 17 00:00:00 2001
From: zhli
Date: Mon, 4 Dec 2023 23:25:43 +0800
Subject: [PATCH] [VL] Fall back bloom_filter_agg when might_contain falls back
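
Velox's BloomFilter serialization is not byte-compatible with vanilla
Spark's, so a might_contain that falls back to vanilla Spark cannot probe
a filter built by the offloaded bloom_filter_agg. This change adds a shim
hook, handleBloomFilterFallback, that walks the bloom-filter scalar
subquery of a not-transformable BloomFilterMightContain and tags the
ObjectHashAggregateExec evaluating BloomFilterAggregate as not
transformable too, so both sides of the filter stay on vanilla Spark.

For illustration, the affected plan shape comes from queries of roughly
the following form (a sketch only: `spark`, `t1`, and `t2` are assumed to
exist, and bloom_filter_agg / might_contain are internal expressions that
must first be registered on the session, the way
GlutenBloomFilterAggregateQuerySuite does for its tests):

    // Hand-written bloom-filter probe; runtime filters injected by
    // Spark's InjectRuntimeFilter produce the same
    // might_contain-over-scalar-subquery shape.
    val probed = spark.sql(
      """
        |SELECT might_contain(
        |  (SELECT bloom_filter_agg(xxhash64(a)) FROM t1),
        |  xxhash64(b)) AS hit
        |FROM t2
        |""".stripMargin)
    // The scalar subquery plans ObjectHashAggregateExec nodes evaluating
    // BloomFilterAggregate; if the outer might_contain is not offloaded,
    // the aggregate must emit Spark's filter format, hence the coupled
    // fallback.
    probed.explain()

The tagging recurses through AdaptiveSparkPlanExec, since AQE may already
wrap the subquery plan, and down agg.child so the partial aggregate below
the final one is tagged as well.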
---
 .../columnar/TransformHintRule.scala               |  7 +++++
 .../utils/velox/VeloxTestSettings.scala            |  2 --
 .../utils/velox/VeloxTestSettings.scala            |  2 --
 .../glutenproject/sql/shims/SparkShims.scala       |  4 ++-
 .../sql/shims/spark32/Spark32Shims.scala           |  4 ++-
 .../sql/shims/spark33/Spark33Shims.scala           | 27 +++++++++++++++++++
 .../sql/shims/spark34/Spark34Shims.scala           | 27 +++++++++++++++++++
 7 files changed, 67 insertions(+), 6 deletions(-)

diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
index 1431af2dda34..ca86e40490c3 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala
@@ -714,6 +714,13 @@ case class AddTransformHintRule() extends Rule[SparkPlan] {
             s"${e.getMessage}, original sparkplan is " +
               s"${plan.getClass}(${plan.children.toList.map(_.getClass)})")
     }
+
+    if (TransformHints.isAlreadyTagged(plan) && TransformHints.isNotTransformable(plan)) {
+      // Velox's BloomFilter implementation is not the same as Spark's,
+      // so if might_contain falls back, the related bloom filter agg must fall back too.
+      SparkShimLoader.getSparkShims.handleBloomFilterFallback(plan)(
+        p => TransformHints.tagNotTransformable(p, "the related BloomFilterMightContain falls back"))
+    }
   }
 
   implicit class EncodeTransformableTagImplicits(validationResult: ValidationResult) {
diff --git a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
index f28a9c31a40a..756707679822 100644
--- a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
@@ -51,8 +51,6 @@ class VeloxTestSettings extends BackendTestSettings {
     .exclude("string split function with positive limit")
     .exclude("string split function with negative limit")
   enableSuite[GlutenBloomFilterAggregateQuerySuite]
-    // fallback might_contain, the input argument binary is not same with vanilla spark
-    .exclude("Test NULL inputs for might_contain")
   enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
   enableSuite[GlutenDataSourceV2DataFrameSuite]
   enableSuite[GlutenDataSourceV2FunctionSuite]
diff --git a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
index dc0762a79d20..1010414ceef4 100644
--- a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
@@ -51,8 +51,6 @@ class VeloxTestSettings extends BackendTestSettings {
     .exclude("string split function with positive limit")
     .exclude("string split function with negative limit")
   enableSuite[GlutenBloomFilterAggregateQuerySuite]
-    // fallback might_contain, the input argument binary is not same with vanilla spark
-    .exclude("Test NULL inputs for might_contain")
   enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
   enableSuite[GlutenDataSourceV2DataFrameSuite]
   enableSuite[GlutenDataSourceV2FunctionSuite]
diff --git a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
index de99e7efb44c..1a1ec92a8052 100644
--- a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}
 import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.text.TextScan
@@ -81,4 +81,6 @@ trait SparkShims {
       start: Long,
       length: Long,
       @transient locations: Array[String] = Array.empty): PartitionedFile
+
+  def handleBloomFilterFallback(plan: SparkPlan)(fun: SparkPlan => Unit): Unit
 }
diff --git a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
index 580cc93bdf75..d63898a51487 100644
--- a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
+++ b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
+import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan}
 import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
 import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -101,4 +101,6 @@ class Spark32Shims extends SparkShims {
       length: Long,
       @transient locations: Array[String] = Array.empty): PartitionedFile =
     PartitionedFile(partitionValues, filePath, start, length, locations)
+
+  override def handleBloomFilterFallback(plan: SparkPlan)(fun: SparkPlan => Unit): Unit = {}
 }
diff --git a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
index 50e536610e7b..4fa3415a3715 100644
--- a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
+++ b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
@@ -127,6 +127,33 @@ class Spark33Shims extends SparkShims {
       @transient locations: Array[String] = Array.empty): PartitionedFile =
     PartitionedFile(partitionValues, filePath, start, length, locations)
 
+  override def handleBloomFilterFallback(plan: SparkPlan)(fun: SparkPlan => Unit): Unit = {
+    def tagNotTransformableRecursive(p: SparkPlan): Unit = {
+      p match {
+        case agg: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
+            if agg.aggregateExpressions.exists(
+              expr => expr.aggregateFunction.isInstanceOf[BloomFilterAggregate]) =>
+          fun(agg)
+          tagNotTransformableRecursive(agg.child)
+        case a: org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec =>
+          tagNotTransformableRecursive(a.executedPlan)
+        case _ =>
+          p.children.foreach(tagNotTransformableRecursive)
+      }
+    }
+
+    plan.transformAllExpressions {
+      case mc @ BloomFilterMightContain(sub: org.apache.spark.sql.execution.ScalarSubquery, _) =>
+        tagNotTransformableRecursive(sub.plan)
+        mc
+      case mc @ BloomFilterMightContain(
+            GetStructField(sub: org.apache.spark.sql.execution.ScalarSubquery, _, _),
+            _) =>
+        tagNotTransformableRecursive(sub.plan)
+        mc
+    }
+  }
+
   private def invalidBucketFile(path: String): Throwable = {
     new SparkException(
       errorClass = "INVALID_BUCKET_FILE",
diff --git a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
index cdc42f3b43fd..5b4c0b70c5a5 100644
--- a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
@@ -130,6 +130,33 @@ class Spark34Shims extends SparkShims {
       @transient locations: Array[String] = Array.empty): PartitionedFile =
     PartitionedFile(partitionValues, SparkPath.fromPathString(filePath), start, length, locations)
 
+  override def handleBloomFilterFallback(plan: SparkPlan)(fun: SparkPlan => Unit): Unit = {
+    def tagNotTransformableRecursive(p: SparkPlan): Unit = {
+      p match {
+        case agg: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
+            if agg.aggregateExpressions.exists(
+              expr => expr.aggregateFunction.isInstanceOf[BloomFilterAggregate]) =>
+          fun(agg)
+          tagNotTransformableRecursive(agg.child)
+        case a: org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec =>
+          tagNotTransformableRecursive(a.executedPlan)
+        case _ =>
+          p.children.foreach(tagNotTransformableRecursive)
+      }
+    }
+
+    plan.transformAllExpressions {
+      case mc @ BloomFilterMightContain(sub: org.apache.spark.sql.execution.ScalarSubquery, _) =>
+        tagNotTransformableRecursive(sub.plan)
+        mc
+      case mc @ BloomFilterMightContain(
+            GetStructField(sub: org.apache.spark.sql.execution.ScalarSubquery, _, _),
+            _) =>
+        tagNotTransformableRecursive(sub.plan)
+        mc
+    }
+  }
+
   private def invalidBucketFile(path: String): Throwable = {
     new SparkException(
       errorClass = "INVALID_BUCKET_FILE",