fix: Fallback to Spark for unsupported partitioning
viirya committed Aug 1, 2024
1 parent e52cfb4 commit e32decb
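In short: both Comet shuffle paths now check the exchange's output partitioning, not just the data types of the child output, and any partitioning they cannot handle is left to Spark's own shuffle. The sketch below is simplified pseudo-logic rather than the actual Comet code: columnarShuffleSupports stands in for QueryPlanSerde.supportPartitioningTypes (JVM/columnar shuffle) and nativeShuffleSupports for QueryPlanSerde.supportPartitioning (native shuffle), and the real checks additionally require every partitioning expression to serialize to Comet's protobuf form and to have a supported data type.

import org.apache.spark.sql.catalyst.plans.physical._

// Illustrative names only; the real checks live in QueryPlanSerde (see the diff below).
object PartitioningSupportSketch {
  // JVM (columnar) shuffle: hash, range, round-robin and single partitioning.
  def columnarShuffleSupports(p: Partitioning): Boolean = p match {
    case _: HashPartitioning | _: RangePartitioning |
        _: RoundRobinPartitioning | SinglePartition => true
    case _ => false // e.g. KeyGroupedPartitioning falls back to Spark
  }

  // Native shuffle: only hash partitioning and SinglePartition.
  def nativeShuffleSupports(p: Partitioning): Boolean = p match {
    case _: HashPartitioning | SinglePartition => true
    case _ => false // range, round-robin, etc. go to columnar shuffle or Spark
  }
}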
Showing 3 changed files with 159 additions and 23 deletions.
@@ -212,7 +212,7 @@ class CometSparkSessionExtensions
case s: ShuffleExchangeExec
if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode(
conf) &&
QueryPlanSerde.supportPartitioningTypes(s.child.output)._1 &&
QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 &&
!isShuffleOperator(s.child) =>
logInfo("Comet extension enabled for JVM Columnar Shuffle")
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
@@ -769,7 +769,7 @@ class CometSparkSessionExtensions
// convert it to CometColumnarShuffle,
case s: ShuffleExchangeExec
if isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) &&
QueryPlanSerde.supportPartitioningTypes(s.child.output)._1 &&
QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 &&
!isShuffleOperator(s.child) =>
logInfo("Comet extension enabled for JVM Columnar Shuffle")

@@ -789,6 +789,7 @@ class CometSparkSessionExtensions

case s: ShuffleExchangeExec =>
val isShuffleEnabled = isCometShuffleEnabled(conf)
val outputPartitioning = s.outputPartitioning
val reason = getCometShuffleNotEnabledReason(conf).getOrElse("no reason available")
val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason")
val columnarShuffleEnabled = isCometJVMShuffleMode(conf)
@@ -797,12 +798,13 @@
.supportPartitioning(s.child.output, s.outputPartitioning)
._1,
"Native shuffle: " +
s"${QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._2}")
s"${QueryPlanSerde.supportPartitioning(s.child.output, outputPartitioning)._2}")
val msg3 = createMessage(
isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde
.supportPartitioningTypes(s.child.output)
.supportPartitioningTypes(s.child.output, s.outputPartitioning)
._1,
s"JVM shuffle: ${QueryPlanSerde.supportPartitioningTypes(s.child.output)._2}")
"JVM shuffle: " +
s"${QueryPlanSerde.supportPartitioningTypes(s.child.output, outputPartitioning)._2}")
withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(","))
s

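When neither shuffle implementation can take over, the plan keeps Spark's ShuffleExchangeExec and the reasons are attached via withInfo, as in the hunk above. Below is a self-contained sketch of how those reason strings are combined; createMessage mirrors the helper used in the diff, while ShuffleFallbackSketch, shuffleFallbackInfo and their parameters are invented names standing in for the surrounding planner code, with the conditions slightly simplified.

object ShuffleFallbackSketch {
  // Produce a message only when a fallback condition holds.
  def createMessage(condition: Boolean, message: => String): Option[String] =
    if (condition) Some(message) else None

  // nativeSupport and jvmSupport stand in for the (Boolean, String) results of
  // QueryPlanSerde.supportPartitioning and QueryPlanSerde.supportPartitioningTypes.
  def shuffleFallbackInfo(
      shuffleEnabled: Boolean,
      columnarShuffleEnabled: Boolean,
      nativeSupport: (Boolean, String),
      jvmSupport: (Boolean, String)): String = {
    val msg1 = createMessage(!shuffleEnabled, "Comet shuffle is not enabled")
    val msg2 = createMessage(
      shuffleEnabled && !nativeSupport._1,
      "Native shuffle: " + nativeSupport._2)
    val msg3 = createMessage(
      shuffleEnabled && columnarShuffleEnabled && !jvmSupport._1,
      "JVM shuffle: " + jvmSupport._2)
    Seq(msg1, msg2, msg3).flatten.mkString(",")
  }
}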
65 changes: 47 additions & 18 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometRowToColumnarExec, CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
@@ -2880,7 +2880,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
* Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
* which supports struct/array.
*/
def supportPartitioningTypes(inputs: Seq[Attribute]): (Boolean, String) = {
def supportPartitioningTypes(
inputs: Seq[Attribute],
partitioning: Partitioning): (Boolean, String) = {
def supportedDataType(dt: DataType): Boolean = dt match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType |
@@ -2904,14 +2906,37 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
false
}

// Check if the datatypes of shuffle input are supported.
var msg = ""
val supported = inputs.forall(attr => supportedDataType(attr.dataType))
val supported = partitioning match {
case HashPartitioning(expressions, _) =>
val supported =
expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
expressions.forall(e => supportedDataType(e.dataType))
if (!supported) {
msg = s"unsupported Spark partitioning expressions: $expressions"
}
supported
case SinglePartition => true
case RoundRobinPartitioning(_) => true
case RangePartitioning(orderings, _) =>
val supported =
orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
orderings.forall(e => supportedDataType(e.dataType))
if (!supported) {
msg = s"unsupported Spark partitioning expressions: $orderings"
}
supported
case _ =>
msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
false
}

if (!supported) {
msg = s"unsupported Spark partitioning: ${inputs.map(_.dataType)}"
emitWarning(msg)
(false, msg)
} else {
(true, null)
}
(supported, msg)
}

/**
@@ -2930,23 +2955,27 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
false
}

// Check if the datatypes of shuffle input are supported.
val supported = inputs.forall(attr => supportedDataType(attr.dataType))
var msg = ""
val supported = partitioning match {
case HashPartitioning(expressions, _) =>
val supported =
expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
expressions.forall(e => supportedDataType(e.dataType))
if (!supported) {
msg = s"unsupported Spark partitioning expressions: $expressions"
}
supported
case SinglePartition => true
case _ =>
msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
false
}

if (!supported) {
val msg = s"unsupported Spark partitioning: ${inputs.map(_.dataType)}"
emitWarning(msg)
(false, msg)
} else {
partitioning match {
case HashPartitioning(expressions, _) =>
(expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined), null)
case SinglePartition => (true, null)
case other =>
val msg = s"unsupported Spark partitioning: ${other.getClass.getName}"
emitWarning(msg)
(false, msg)
}
(true, null)
}
}

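After this change the two helpers differ mainly in which partitionings they accept: supportPartitioningTypes (columnar/JVM shuffle) handles hash, single, round-robin and range partitioning, while supportPartitioning (native shuffle) handles only hash partitioning and SinglePartition. The snippet below is a hypothetical usage sketch with a single LongType attribute; the comments describe which branch of the match each call hits, not the exact return values.

object PartitioningSupportExample extends App {
  import org.apache.comet.serde.QueryPlanSerde
  import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
  import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning}
  import org.apache.spark.sql.types.LongType

  val id = AttributeReference("id", LongType)()
  val output = Seq(id)

  // Columnar/JVM shuffle: all three hit a supported branch of the match
  // (assuming the partitioning expressions serialize to protobuf).
  println(QueryPlanSerde.supportPartitioningTypes(output, HashPartitioning(Seq(id), 10)))
  println(QueryPlanSerde.supportPartitioningTypes(output, RoundRobinPartitioning(10)))
  println(
    QueryPlanSerde.supportPartitioningTypes(
      output,
      RangePartitioning(Seq(SortOrder(id, Ascending)), 10)))

  // Native shuffle: RangePartitioning falls into the default case, so the
  // Boolean is false and the String names the unsupported partitioning.
  println(
    QueryPlanSerde.supportPartitioning(
      output,
      RangePartitioning(Seq(SortOrder(id, Ascending)), 10)))
}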
@@ -19,13 +19,19 @@

package org.apache.comet.exec

import java.util.Collections

import org.scalactic.source.Position
import org.scalatest.Tag

import org.apache.hadoop.fs.Path
import org.apache.spark.{Partitioner, SparkConf}
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager}
import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryCatalog, InMemoryTableCatalog}
import org.apache.spark.sql.connector.distributions.Distributions
import org.apache.spark.sql.connector.expressions.Expressions._
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
@@ -34,6 +40,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus

abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper {
protected val adaptiveExecutionEnabled: Boolean
@@ -47,6 +54,16 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar

protected val asyncShuffleEnable: Boolean

override def beforeAll(): Unit = {
super.beforeAll()
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
}

override def afterAll(): Unit = {
spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat")
super.afterAll()
}

override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
pos: Position): Unit = {
super.test(testName, testTags: _*) {
@@ -85,6 +102,94 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
}
}

private val emptyProps: java.util.Map[String, String] = {
Collections.emptyMap[String, String]
}
private val items: String = "items"
private val itemsColumns: Array[Column] = Array(
Column.create("id", LongType),
Column.create("name", StringType),
Column.create("price", FloatType),
Column.create("arrive_time", TimestampType))

private val purchases: String = "purchases"
private val purchasesColumns: Array[Column] = Array(
Column.create("item_id", LongType),
Column.create("price", FloatType),
Column.create("time", TimestampType))

protected def catalog: InMemoryCatalog = {
val catalog = spark.sessionState.catalogManager.catalog("testcat")
catalog.asInstanceOf[InMemoryCatalog]
}

private def createTable(
table: String,
columns: Array[Column],
partitions: Array[Transform],
catalog: InMemoryTableCatalog = catalog): Unit = {
catalog.createTable(
Identifier.of(Array("ns"), table),
columns,
partitions,
emptyProps,
Distributions.unspecified(),
Array.empty,
None,
None,
numRowsPerSplit = 1)
}

private def selectWithMergeJoinHint(t1: String, t2: String): String = {
s"SELECT /*+ MERGE($t1, $t2) */ "
}

private def createJoinTestDF(
keys: Seq[(String, String)],
extraColumns: Seq[String] = Nil,
joinType: String = ""): DataFrame = {
val extraColList = if (extraColumns.isEmpty) "" else extraColumns.mkString(", ", ", ", "")
sql(s"""
|${selectWithMergeJoinHint("i", "p")}
|id, name, i.price as purchase_price, p.price as sale_price $extraColList
|FROM testcat.ns.$items i $joinType JOIN testcat.ns.$purchases p
|ON ${keys.map(k => s"i.${k._1} = p.${k._2}").mkString(" AND ")}
|ORDER BY id, purchase_price, sale_price $extraColList
|""".stripMargin)
}

test("Fallback to Spark for unsupported partitioning") {
assume(isSpark40Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")

val items_partitions = Array(identity("id"))
createTable(items, itemsColumns, items_partitions)

sql(
s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchasesColumns, Array.empty)
sql(
s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp)), " +
"(5, 26.0, cast('2023-01-01' as timestamp)), " +
"(6, 50.0, cast('2023-02-01' as timestamp))")

Seq(true, false).foreach { shuffle =>
withSQLConf(
SQLConf.V2_BUCKETING_ENABLED.key -> "true",
"spark.sql.sources.v2.bucketing.shuffle.enabled" -> shuffle.toString,
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") {
val df = createJoinTestDF(Seq("id" -> "item_id"))
checkSparkAnswer(df)
}
}
}

test("columnar shuffle on nested struct including nulls") {
Seq(10, 201).foreach { numPartitions =>
Seq("1.0", "10.0").foreach { ratio =>
