Commit 811919f

fixup
zhztheplayer committed Aug 1, 2024
1 parent 8eb8a90 · commit 811919f
Showing 2 changed files with 32 additions and 3 deletions.
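
The settings exercised in the diff below are plain Spark SQL conf keys, so they can also be set on a session rather than through withSQLConf. A minimal usage sketch, assuming the Gluten Velox backend is already on the classpath; only the two config keys come from the diff, while the app name and the 15~100 range value are illustrative:

import org.apache.spark.sql.SparkSession

// Hedged sketch: enable shuffle-input batch resizing and pin the row-count
// range, using the config keys that appear in the diff below.
val spark = SparkSession
  .builder()
  .appName("resize-batches-example") // illustrative name
  .config("spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput", "true")
  .config("spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput.range", "15~100")
  .getOrCreate()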
@@ -833,10 +833,11 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
 
   test("combine small batches before shuffle") {
     val minBatchSize = 15
+    val maxBatchSize = 100
     withSQLConf(
       "spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput" -> "true",
-      "spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput.range" -> "2~20000",
-      "spark.gluten.sql.columnar.backend.velox.minBatchSizeForShuffle" -> s"$minBatchSize"
+      "spark.gluten.sql.columnar.maxBatchSize" -> "2",
+      "spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput.range" -> s"$minBatchSize~$maxBatchSize"
     ) {
       val df = runQueryAndCompare(
         "select l_orderkey, sum(l_partkey) as sum from lineitem " +
@@ -854,6 +855,30 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
     }
   }
 
+  test("split small batches before shuffle") {
+    val minBatchSize = 1
+    val maxBatchSize = 4
+    withSQLConf(
+      "spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput" -> "true",
+      "spark.gluten.sql.columnar.maxBatchSize" -> "100",
+      "spark.gluten.sql.columnar.backend.velox.resizeBatches.shuffleInput.range" -> s"$minBatchSize~$maxBatchSize"
+    ) {
+      val df = runQueryAndCompare(
+        "select l_orderkey, sum(l_partkey) as sum from lineitem " +
+          "where l_orderkey < 100 group by l_orderkey") { _ => }
+      checkLengthAndPlan(df, 27)
+      val ops = collect(df.queryExecution.executedPlan) { case p: VeloxResizeBatchesExec => p }
+      assert(ops.size == 1)
+      val op = ops.head
+      assert(op.minOutputBatchSize == minBatchSize)
+      val metrics = op.metrics
+      assert(metrics("numInputRows").value == 27)
+      assert(metrics("numInputBatches").value == 1)
+      assert(metrics("numOutputRows").value == 27)
+      assert(metrics("numOutputBatches").value == 7)
+    }
+  }
+
   test("test OneRowRelation") {
     val df = sql("SELECT 1")
     checkAnswer(df, Row(1))
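
Taken together, the two tests above pin down both directions of VeloxResizeBatchesExec's behavior on shuffle input: the first one combines many tiny batches (maxBatchSize = 2) up to the configured minimum of 15 rows, and the new one splits a single 27-row batch down to the configured maximum of 4 rows, which is why it expects ceil(27 / 4) = 7 output batches. Below is a minimal sketch of such a resize rule over batch row counts, assuming simple buffer-then-slice semantics; it is not the operator's actual implementation.

// Hedged sketch, not VeloxResizeBatchesExec itself: resize a stream of batch
// row counts so every emitted batch holds between `min` and `max` rows
// (except possibly the last one, which may fall short of `min`).
def resizeSizes(sizes: Iterator[Int], min: Int, max: Int): Iterator[Int] =
  new Iterator[Int] {
    private var buffered = 0
    override def hasNext: Boolean = buffered > 0 || sizes.hasNext
    override def next(): Int = {
      // Combine: pull input batches until at least `min` rows are buffered.
      while (buffered < min && sizes.hasNext) {
        buffered += sizes.next()
      }
      // Split: emit at most `max` rows per output batch.
      val out = math.min(buffered, max)
      buffered -= out
      out
    }
  }

// Split case from the new test: one 27-row batch with range 1~4 yields
// ceil(27 / 4) = 7 output batches.
assert(resizeSizes(Iterator(27), min = 1, max = 4).toSeq == Seq(4, 4, 4, 4, 4, 4, 3))

// Combine case: under this simplified rule, 27 rows arriving as 2-row batches
// with range 15~100 merge into two output batches (16 rows, then 11).
assert(resizeSizes(Iterator.fill(13)(2) ++ Iterator(1), min = 15, max = 100).toSeq == Seq(16, 11))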
@@ -305,7 +305,11 @@ class GlutenConfig(conf: SQLConf) extends Logging {
 
   def veloxBloomFilterMaxNumBits: Long = conf.getConf(COLUMNAR_VELOX_BLOOM_FILTER_MAX_NUM_BITS)
 
-  case class ResizeRange(min: Int, max: Int)
+  case class ResizeRange(min: Int, max: Int) {
+    assert(max >= min)
+    assert(min > 0, "Min batch size should be larger than 0")
+    assert(max > 0, "Max batch size should be larger than 0")
+  }
 
   private object ResizeRange {
     def parse(pattern: String): ResizeRange = {
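
Because the new assertions run in the case class constructor, an invalid range such as "20~10" or "0~5" now fails fast as soon as the config value is turned into a ResizeRange. The body of ResizeRange.parse is collapsed in the diff above; the following is a hypothetical sketch of a parser for the "min~max" pattern used by the tests, where only the "~"-separated format comes from the config values in the diff and the function name, splitting logic, and error message are assumptions.

// Hypothetical sketch, not the actual ResizeRange.parse: split a "min~max"
// string and let the ResizeRange assertions reject invalid ranges.
def parseRange(pattern: String): ResizeRange = {
  pattern.split("~").map(_.trim) match {
    case Array(min, max) => ResizeRange(min.toInt, max.toInt)
    case _ => throw new IllegalArgumentException(s"Invalid resize range: $pattern")
  }
}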
