diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index d44635afc3479..66ca8660a50c2 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -331,12 +331,16 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { } } - val child = if (GlutenConfig.getConf.veloxCoalesceBatchesBeforeShuffle) { - VeloxAppendBatchesExec(shuffle.child, GlutenConfig.getConf.veloxMinBatchSizeForShuffle) - } else { - shuffle.child + def maybeAddAppendBatchesExec(plan: SparkPlan): SparkPlan = { + if (GlutenConfig.getConf.veloxCoalesceBatchesBeforeShuffle) { + VeloxAppendBatchesExec(plan, GlutenConfig.getConf.veloxMinBatchSizeForShuffle) + } else { + plan + } } + val child = shuffle.child + shuffle.outputPartitioning match { case HashPartitioning(exprs, _) => val hashExpr = new Murmur3Hash(exprs) @@ -344,10 +348,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { val projectTransformer = ProjectExecTransformer(projectList, child) val validationResult = projectTransformer.doValidate() if (validationResult.isValid) { - ColumnarShuffleExchangeExec( - shuffle, - projectTransformer, - projectTransformer.output.drop(1)) + val newChild = maybeAddAppendBatchesExec(projectTransformer) + ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output.drop(1)) } else { TransformHints.tagNotTransformable(shuffle, validationResult) shuffle.withNewChildren(child :: Nil) @@ -363,7 +365,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { // null type since the value always be null. val columnsForHash = child.output.filterNot(_.dataType == NullType) if (columnsForHash.isEmpty) { - ColumnarShuffleExchangeExec(shuffle, child, child.output) + val newChild = maybeAddAppendBatchesExec(child) + ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output) } else { val hashExpr = new Murmur3Hash(columnsForHash) val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ child.output @@ -384,10 +387,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { ProjectExecTransformer(projectList.drop(1), sortByHashCode) val validationResult = dropSortColumnTransformer.doValidate() if (validationResult.isValid) { - ColumnarShuffleExchangeExec( - shuffle, - dropSortColumnTransformer, - dropSortColumnTransformer.output) + val newChild = maybeAddAppendBatchesExec(dropSortColumnTransformer) + ColumnarShuffleExchangeExec(shuffle, newChild, newChild.output) } else { TransformHints.tagNotTransformable(shuffle, validationResult) shuffle.withNewChildren(child :: Nil) @@ -395,7 +396,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { } } case _ => - ColumnarShuffleExchangeExec(shuffle, child, null) + val newChild = maybeAddAppendBatchesExec(child) + ColumnarShuffleExchangeExec(shuffle, newChild, null) } }