From fa86e7682636a4a03a0ea7b60275c41b8567d4f0 Mon Sep 17 00:00:00 2001 From: Terry Wang Date: Mon, 18 Mar 2024 09:07:27 +0800 Subject: [PATCH] [GLUTEN-4875][VL]Support spark sql conf sortBeforeRepartition to avoid stage partial retry causing result mismatch (#4872) --- .../velox/SparkPlanExecApiImpl.scala | 22 +++++++++++++--- .../execution/TestOperator.scala | 25 +++++++++++++++++++ .../spark/sql/GlutenImplicitsTest.scala | 4 +-- .../GlutenReplaceHashWithSortAggSuite.scala | 3 +++ .../GlutenReplaceHashWithSortAggSuite.scala | 3 +++ 5 files changed, 52 insertions(+), 5 deletions(-) diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala index 60aff4b293fa..b8da55ea95e4 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala @@ -36,12 +36,12 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, StringSplit, StringTrim} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, SortOrder, StringSplit, StringTrim} import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HLLAdapter} import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{BroadcastUtils, ColumnarBuildSideRelation, ColumnarShuffleExchangeExec, SparkPlan, VeloxColumnarWriteFilesExec} import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec} @@ -232,7 +232,23 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi { TransformHints.tagNotTransformable(shuffle, validationResult) shuffle.withNewChildren(newChild :: Nil) } - + case RoundRobinPartitioning(num) if SQLConf.get.sortBeforeRepartition && num > 1 => + val hashExpr = new Murmur3Hash(newChild.output) + val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ newChild.output + val projectTransformer = ProjectExecTransformer(projectList, newChild) + val sortOrder = SortOrder(projectTransformer.output.head, Ascending) + val sortByHashCode = SortExecTransformer(Seq(sortOrder), global = false, projectTransformer) + val dropSortColumnTransformer = ProjectExecTransformer(projectList.drop(1), sortByHashCode) + val validationResult = dropSortColumnTransformer.doValidate() + if (validationResult.isValid) { + ColumnarShuffleExchangeExec( + shuffle, + dropSortColumnTransformer, + dropSortColumnTransformer.output) + } else { + TransformHints.tagNotTransformable(shuffle, validationResult) + shuffle.withNewChildren(newChild :: Nil) + } case _ => ColumnarShuffleExchangeExec(shuffle, newChild, null) } diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala 
b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala index 5dbad565ca73..239bec57a7d7 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala @@ -1233,4 +1233,29 @@ class TestOperator extends VeloxWholeStageTransformerSuite { checkOperatorMatch[HashAggregateExecTransformer] } } + + test("test roundrobine with sort") { + // scalastyle:off + runQueryAndCompare("SELECT /*+ REPARTITION(3) */ l_orderkey, l_partkey FROM lineitem") { + /* + ColumnarExchange RoundRobinPartitioning(3), REPARTITION_BY_NUM, [l_orderkey#16L, l_partkey#17L) + +- ^(2) ProjectExecTransformer [l_orderkey#16L, l_partkey#17L] + +- ^(2) SortExecTransformer [hash_partition_key#302 ASC NULLS FIRST], false, 0 + +- ^(2) ProjectExecTransformer [hash(l_orderkey#16L, l_partkey#17L) AS hash_partition_key#302, l_orderkey#16L, l_partkey#17L] + +- ^(2) BatchScanExecTransformer[l_orderkey#16L, l_partkey#17L] ParquetScan DataFilters: [], Format: parquet, Location: InMemoryFileIndex(1 paths)[..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct, PushedFilters: [] RuntimeFilters: [] + */ + checkOperatorMatch[SortExecTransformer] + } + // scalastyle:on + + withSQLConf("spark.sql.execution.sortBeforeRepartition" -> "false") { + runQueryAndCompare("""SELECT /*+ REPARTITION(3) */ + | l_orderkey, l_partkey FROM lineitem""".stripMargin) { + df => + { + assert(getExecutedPlan(df).count(_.isInstanceOf[SortExecTransformer]) == 0) + } + } + } + } } diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenImplicitsTest.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenImplicitsTest.scala index 52a4db9e682a..e4356cec8ff1 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenImplicitsTest.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenImplicitsTest.scala @@ -119,10 +119,10 @@ class GlutenImplicitsTest 
extends GlutenSQLTestsBaseTrait { testGluten("fallbackSummary with cached data and shuffle") { withAQEEnabledAndDisabled { val df = spark.sql("select * from t1").filter(_.getLong(0) > 0).cache.repartition() - assert(df.fallbackSummary().numGlutenNodes == 3, df.fallbackSummary()) + assert(df.fallbackSummary().numGlutenNodes == 6, df.fallbackSummary()) assert(df.fallbackSummary().numFallbackNodes == 1, df.fallbackSummary()) df.collect() - assert(df.fallbackSummary().numGlutenNodes == 3, df.fallbackSummary()) + assert(df.fallbackSummary().numGlutenNodes == 6, df.fallbackSummary()) assert(df.fallbackSummary().numFallbackNodes == 1, df.fallbackSummary()) } } diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala index bbc267ec2b9d..c83912d7feb5 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala @@ -61,6 +61,8 @@ class GlutenReplaceHashWithSortAggSuite Seq("FIRST", "COLLECT_LIST").foreach { aggExpr => + // The repartition change alters the row order, so the results would not match; + // we add ORDER BY key before comparing the result. 
val query = s""" |SELECT key, $aggExpr(key) @@ -72,6 +74,7 @@ class GlutenReplaceHashWithSortAggSuite | SORT BY key |) |GROUP BY key + |ORDER BY key """.stripMargin checkAggs(query, 2, 0, 2, 0) } diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala index f86509d44636..f0898beb3ae3 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala @@ -60,6 +60,8 @@ class GlutenReplaceHashWithSortAggSuite Seq("FIRST", "COLLECT_LIST").foreach { aggExpr => + // The repartition change alters the row order, so the results would not match; + // we add ORDER BY key before comparing the result. val query = s""" |SELECT key, $aggExpr(key) @@ -71,6 +73,7 @@ class GlutenReplaceHashWithSortAggSuite | SORT BY key |) |GROUP BY key + |ORDER BY key """.stripMargin checkAggs(query, 2, 0, 2, 0) }