diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/utils/CHExecUtil.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/utils/CHExecUtil.scala index 55be8846a71b..f9e53d789e96 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/utils/CHExecUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/utils/CHExecUtil.scala @@ -27,7 +27,7 @@ import org.apache.spark.ShuffleDependency import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{ColumnarShuffleDependency, GlutenShuffleUtils} +import org.apache.spark.shuffle.{ColumnarShuffleDependency, GlutenShuffleUtils, HashPartitioningWrapper} import org.apache.spark.shuffle.utils.RangePartitionerBoundsGenerator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, BoundReference, UnsafeProjection, UnsafeRow} @@ -175,11 +175,17 @@ object CHExecUtil extends Logging { } private def buildHashPartitioning( - partitoining: HashPartitioning, + partitioning: HashPartitioning, childOutput: Seq[Attribute], output: Seq[Attribute]): NativePartitioning = { + val hashExpressions = partitioning match { + case partitioning: HashPartitioningWrapper => + partitioning.getNewExpr + case _ => + partitioning.expressions + } val hashFields = - partitoining.expressions.map( + hashExpressions.map( a => BindReferences .bindReference(ConverterUtils.getAttrFromExpr(a).toAttribute, childOutput) @@ -199,7 +205,7 @@ object CHExecUtil extends Logging { new NativePartitioning( GlutenShuffleUtils.HashPartitioningShortName, - partitoining.numPartitions, + partitioning.numPartitions, Array.empty[Byte], hashFields.mkString(",").getBytes(), outputFields.mkString(",").getBytes() diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala index 84a0dc58bc35..4ee22c0bc24c 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala @@ -26,6 +26,7 @@ import io.glutenproject.utils.{LogLevelUtil, PhysicalPlanSelector} import org.apache.spark.api.python.EvalPythonExecTransformer import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.HashPartitioningWrapper import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BindReferences, BoundReference, Expression, Murmur3Hash, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} @@ -204,7 +205,10 @@ case class TransformPreOverrides(isAdaptiveContext: Boolean) val pos = newExpressionsPosition(i) newExprs = newExprs :+ project.output(pos) } - (projectExpressions.size, HashPartitioning(newExprs, numPartitions), project) + ( + projectExpressions.size, + new HashPartitioningWrapper(exprs, newExprs, numPartitions), + project) case RangePartitioning(orderings, numPartitions) => val exprs = orderings.map(ordering => ordering.child) val (projectExpressions, newExpressionsPosition) = { diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/HashPartitioningWrapper.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/HashPartitioningWrapper.scala new file mode 100644 index 000000000000..06ce8fe0f508 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/HashPartitioningWrapper.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning + +// A wrapper for HashPartitioning to remain original hash expressions. +// Only used by CH backend when shuffle hash expressions contains non-field expression. +class HashPartitioningWrapper( + original: Seq[Expression], + newExpr: Seq[Expression], + override val numPartitions: Int) + extends HashPartitioning(original, numPartitions) { + + def getNewExpr: Seq[Expression] = newExpr +}