Commit 9b35e1a: allow hash on map for round-robin partitioning

marin-ma committed May 24, 2024
1 parent 891ab83 commit 9b35e1a
Showing 4 changed files with 49 additions and 20 deletions.
VeloxSparkPlanExecApi.scala

@@ -315,6 +315,16 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
   override def genColumnarShuffleExchange(
       shuffle: ShuffleExchangeExec,
       newChild: SparkPlan): SparkPlan = {
+    def allowHashOnMap[T](f: => T): T = {
+      val originalAllowHash = SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE)
+      try {
+        SQLConf.get.setConf(SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE, true)
+        f
+      } finally {
+        SQLConf.get.setConf(SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE, originalAllowHash)
+      }
+    }
+
     shuffle.outputPartitioning match {
       case HashPartitioning(exprs, _) =>
         val hashExpr = new Murmur3Hash(exprs)
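The allowHashOnMap helper above flips SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE (spark.sql.legacy.allowHashOnMapType) for the duration of a block, and the try/finally restores the prior value even if the block throws, so the flag never leaks into later queries on the same session. A minimal sketch of what that flag controls in plain Spark, outside Gluten; the session setup and column names are illustrative, not taken from this commit:

import scala.util.Try
import org.apache.spark.sql.SparkSession

object HashOnMapDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("hash-on-map").getOrCreate()

    // A single-row frame with one map-typed column.
    val df = spark.range(1).selectExpr("map(id, id + 1) AS m")

    // Default since Spark 3.0: hashing a map column is rejected at analysis time.
    spark.conf.set("spark.sql.legacy.allowHashOnMapType", "false")
    assert(Try(df.selectExpr("hash(m)").collect()).isFailure)

    // With the legacy flag on, the same hash expression resolves and evaluates.
    spark.conf.set("spark.sql.legacy.allowHashOnMapType", "true")
    df.selectExpr("hash(m)").show()

    spark.stop()
  }
}

The diff then wraps the whole round-robin branch in this helper: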
@@ -331,21 +341,30 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
           shuffle.withNewChildren(newChild :: Nil)
         }
       case RoundRobinPartitioning(num) if SQLConf.get.sortBeforeRepartition && num > 1 =>
-        val hashExpr = new Murmur3Hash(newChild.output)
-        val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ newChild.output
-        val projectTransformer = ProjectExecTransformer(projectList, newChild)
-        val sortOrder = SortOrder(projectTransformer.output.head, Ascending)
-        val sortByHashCode = SortExecTransformer(Seq(sortOrder), global = false, projectTransformer)
-        val dropSortColumnTransformer = ProjectExecTransformer(projectList.drop(1), sortByHashCode)
-        val validationResult = dropSortColumnTransformer.doValidate()
-        if (validationResult.isValid) {
-          ColumnarShuffleExchangeExec(
-            shuffle,
-            dropSortColumnTransformer,
-            dropSortColumnTransformer.output)
-        } else {
-          TransformHints.tagNotTransformable(shuffle, validationResult)
-          shuffle.withNewChildren(newChild :: Nil)
-        }
+        // scalastyle:off line.size.limit
+        // Temporarily allow hash on map if it's disabled, otherwise HashExpression will fail to get
+        // resolved if its child contains map type.
+        // See https://github.com/apache/spark/blob/609bd4839e5d504917de74ed1cb9c23645fba51f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala#L279-L283
+        // scalastyle:on line.size.limit
+        allowHashOnMap {
+          val hashExpr = new Murmur3Hash(newChild.output)
+          val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ newChild.output
+          val projectTransformer = ProjectExecTransformer(projectList, newChild)
+          val sortOrder = SortOrder(projectTransformer.output.head, Ascending)
+          val sortByHashCode =
+            SortExecTransformer(Seq(sortOrder), global = false, projectTransformer)
+          val dropSortColumnTransformer =
+            ProjectExecTransformer(projectList.drop(1), sortByHashCode)
+          val validationResult = dropSortColumnTransformer.doValidate()
+          if (validationResult.isValid) {
+            ColumnarShuffleExchangeExec(
+              shuffle,
+              dropSortColumnTransformer,
+              dropSortColumnTransformer.output)
+          } else {
+            TransformHints.tagNotTransformable(shuffle, validationResult)
+            shuffle.withNewChildren(newChild :: Nil)
+          }
+        }
       case _ =>
         ColumnarShuffleExchangeExec(shuffle, newChild, null)
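For readers outside Gluten's transformer API, the RoundRobinPartitioning branch above is roughly the following DataFrame pipeline: project a murmur3 hash over every output column, sort by it within each partition (matching global = false), drop the helper column, then shuffle round-robin. A rough logical sketch with a hypothetical helper name, not the code path the commit actually takes:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, hash}

// Sort rows by a hash of all columns before the round-robin shuffle, so the
// row-to-partition assignment is deterministic under task retries.
def sortThenRoundRobin(df: DataFrame, numPartitions: Int): DataFrame =
  df.withColumn("hash_partition_key", hash(df.columns.map(col): _*))
    .sortWithinPartitions("hash_partition_key")
    .drop("hash_partition_key")
    .repartition(numPartitions)

This parallels vanilla Spark's sort-before-repartition determinism fix for round-robin (SPARK-23207); Gluten rebuilds the sort from its own transformer nodes, keying on a murmur3 hash of all output columns, so the work stays columnar.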
TestOperator.scala

@@ -22,7 +22,7 @@ import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec
 import org.apache.gluten.sql.shims.SparkShimLoader

 import org.apache.spark.SparkConf
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.functions._
@@ -1354,7 +1354,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
     }
   }

-  test("test roundrobine with sort") {
+  test("test RoundRobin repartition with sort") {
+    def checkRoundRobinOperators(df: DataFrame): Unit = {
+      checkGlutenOperatorMatch[SortExecTransformer](df)
+      checkGlutenOperatorMatch[ColumnarShuffleExchangeExec](df)
+    }
+
     // scalastyle:off
     runQueryAndCompare("SELECT /*+ REPARTITION(3) */ l_orderkey, l_partkey FROM lineitem") {
       /*
@@ -1364,7 +1369,7 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
         +- ^(2) ProjectExecTransformer [hash(l_orderkey#16L, l_partkey#17L) AS hash_partition_key#302, l_orderkey#16L, l_partkey#17L]
           +- ^(2) BatchScanExecTransformer[l_orderkey#16L, l_partkey#17L] ParquetScan DataFilters: [], Format: parquet, Location: InMemoryFileIndex(1 paths)[..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<l_orderkey:bigint,l_partkey:bigint>, PushedFilters: [] RuntimeFilters: []
        */
-      checkGlutenOperatorMatch[SortExecTransformer]
+      checkRoundRobinOperators
     }
     // scalastyle:on

@@ -1377,6 +1382,11 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
         }
       }
     }
+
+    // Gluten-5206: test repartition on map type
+    runQueryAndCompare(
+      "SELECT /*+ REPARTITION(3) */ l_orderkey, map(l_orderkey, l_partkey) FROM lineitem")(
+      checkRoundRobinOperators)
   }

   test("Support Map type signature") {
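The added case is the regression fix itself (Gluten-5206, per the comment): with sortBeforeRepartition on, the planner injects a Murmur3Hash over all output columns, and that expression fails to resolve when one of them is a map unless the legacy flag is enabled. A standalone repro sketch, assuming a SparkSession named spark with Gluten enabled and a lineitem table already registered (both assumptions, not shown in the diff):

// Round-robin repartition over output that includes a map column; before this
// commit the injected hash over all columns failed to resolve (per the code
// comment in VeloxSparkPlanExecApi), with it the query plans and runs.
val repartitioned = spark.sql(
  "SELECT /*+ REPARTITION(3) */ l_orderkey, map(l_orderkey, l_partkey) FROM lineitem")
repartitioned.collect()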
FuzzerTestBase.scala

@@ -35,7 +35,7 @@ abstract class FuzzerTestBase extends VeloxWholeStageTransformerSuite {
       .set("spark.plugins", "org.apache.gluten.GlutenPlugin")
       .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
       .set("spark.memory.offHeap.enabled", "true")
-      .set("spark.memory.offHeap.size", "512MB")
+      .set("spark.memory.offHeap.size", "4g")
       .set("spark.driver.memory", "4g")
       .set("spark.driver.maxResultSize", "4g")
   }
ShuffleWriterFuzzerTest.scala

@@ -68,7 +68,7 @@ class ShuffleWriterFuzzerTest extends FuzzerTestBase {
       logWarning(
         s"==============================> " +
           s"Started reproduction (seed: ${dataGenerator.getSeed})")
-      val result = defaultRunner(testShuffle(sql))
+      val result = defaultRunner(testShuffle(sql))()
       assert(result.isInstanceOf[Successful], s"Failed to run 'reproduce' with seed: $seed")
     }
   }
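The one-character change here implies defaultRunner(testShuffle(sql)) returns a deferred runner rather than executing the body itself, so without the trailing () the reproduction never ran and the assert compared against an unapplied function. A toy sketch of that pattern with hypothetical types, not Gluten's real fuzzer API:

sealed trait TestResult
case class Successful() extends TestResult

// Hypothetical: the runner wraps the test body in a thunk instead of running it.
def defaultRunner(body: => Unit): () => TestResult =
  () => { body; Successful() }

val unapplied = defaultRunner(println("reproducing"))  // prints nothing: thunk not applied
val result = defaultRunner(println("reproducing"))()   // runs the body, returns Successful()
assert(result.isInstanceOf[Successful])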
