From e04899112e32c5ad8b5fd6df0f011304432bfc17 Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Fri, 28 Jun 2024 10:06:37 +0800
Subject: [PATCH] fixup

---
 .../velox/VeloxSparkPlanExecApi.scala         | 59 +++++++------------
 .../execution/ColumnarBuildSideRelation.scala | 36 +++++------
 2 files changed, 39 insertions(+), 56 deletions(-)

diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index b48da15683e85..988451a699a58 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -16,28 +16,26 @@
  */
 package org.apache.gluten.backendsapi.velox
 
+import org.apache.commons.lang3.ClassUtils
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.SparkPlanExecApi
 import org.apache.gluten.datasource.ArrowConvertorRule
 import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution._
-import org.apache.gluten.expression._
 import org.apache.gluten.expression.ExpressionNames.{TRANSFORM_KEYS, TRANSFORM_VALUES}
+import org.apache.gluten.expression._
 import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
 import org.apache.gluten.extension._
 import org.apache.gluten.extension.columnar.TransformHints
 import org.apache.gluten.extension.columnar.transition.Convention
 import org.apache.gluten.extension.columnar.transition.ConventionFunc.BatchOverride
 import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSerializeResult}
-
-import org.apache.spark.{ShuffleDependency, SparkException}
+import org.apache.gluten.vectorized.{ColumnarBatchSerializeResult, ColumnarBatchSerializer}
 import org.apache.spark.api.python.{ColumnarArrowEvalPythonExec, PullOutArrowEvalPythonPreProjectHelper}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper}
 import org.apache.spark.shuffle.utils.ShuffleUtil
-import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -61,11 +59,10 @@ import org.apache.spark.sql.expression.{UDFExpression, UDFResolver, UserDefinedA
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
-
-import org.apache.commons.lang3.ClassUtils
+import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.{ShuffleDependency, SparkException}
 
 import javax.ws.rs.core.UriBuilder
-
 import scala.collection.mutable.ListBuffer
 
 class VeloxSparkPlanExecApi extends SparkPlanExecApi {
@@ -76,9 +73,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
   }
 
   /**
-   * Overrides [[org.apache.gluten.extension.columnar.transition.ConventionFunc]] Gluten is using to
-   * determine the convention (its row-based processing / columnar-batch processing support) of a
-   * plan with a user-defined function that accepts a plan then returns batch type it outputs.
+   * Overrides [[org.apache.gluten.extension.columnar.transition.ConventionFunc]] Gluten is using
+   * to determine the convention (its row-based processing / columnar-batch processing support) of
+   * a plan with a user-defined function that accepts a plan then returns batch type it outputs.
    */
   override def batchTypeFunc(): BatchOverride = {
     case i: InMemoryTableScanExec
@@ -110,8 +107,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       GenericExpressionTransformer(condFuncName, Seq(left), condExpr),
       right,
       left,
-      newExpr
-    )
+      newExpr)
   }
 
   /** Transform Uuid to Substrait. */
@@ -159,7 +155,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     original.dataType match {
       case LongType | IntegerType | ShortType | ByteType =>
       case _ =>
-        throw new GlutenNotSupportException(s"$substraitExprName with try mode is not supported")
+        throw new GlutenNotSupportException(
+          s"$substraitExprName with try mode is not supported")
     }
     // Offload to velox for only IntegralTypes.
     GenericExpressionTransformer(
@@ -488,8 +485,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       left,
       right,
       isSkewJoin,
-      projectList
-    )
+      projectList)
   }
   override def genCartesianProductExecTransformer(
       left: SparkPlan,
@@ -498,8 +494,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     CartesianProductExecTransformer(
       ColumnarCartesianProductBridge(left),
       ColumnarCartesianProductBridge(right),
-      condition
-    )
+      condition)
   }
 
   override def genBroadcastNestedLoopJoinExecTransformer(
@@ -508,13 +503,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       buildSide: BuildSide,
       joinType: JoinType,
       condition: Option[Expression]): BroadcastNestedLoopJoinExecTransformer =
-    VeloxBroadcastNestedLoopJoinExecTransformer(
-      left,
-      right,
-      buildSide,
-      joinType,
-      condition
-    )
+    VeloxBroadcastNestedLoopJoinExecTransformer(left, right, buildSide, joinType, condition)
 
   override def genHashExpressionTransformer(
       substraitExprName: String,
@@ -694,10 +683,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       substraitExprName: String,
       children: Seq[ExpressionTransformer],
       expr: Expression): ExpressionTransformer = {
-    if (
-      SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY)
-        != SQLConf.MapKeyDedupPolicy.EXCEPTION.toString
-    ) {
+    if (SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY)
+        != SQLConf.MapKeyDedupPolicy.EXCEPTION.toString) {
       throw new GlutenNotSupportException("Only EXCEPTION policy is supported!")
     }
     GenericExpressionTransformer(substraitExprName, children, expr)
@@ -795,10 +782,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
    *
    * @return
    */
-  override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = List(
-    CollectRewriteRule.apply,
-    HLLRewriteRule.apply
-  )
+  override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] =
+    List(CollectRewriteRule.apply, HLLRewriteRule.apply)
 
   /**
    * Generate extended columnar pre-rules, in the validation phase.
@@ -854,8 +839,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       Sig[TransformKeys](TRANSFORM_KEYS),
       Sig[TransformValues](TRANSFORM_VALUES),
       // For test purpose.
-      Sig[VeloxDummyExpression](VeloxDummyExpression.VELOX_DUMMY_EXPRESSION)
-    )
+      Sig[VeloxDummyExpression](VeloxDummyExpression.VELOX_DUMMY_EXPRESSION))
   }
 
   override def genInjectedFunctions()
@@ -884,8 +868,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       requiredChildOutput: Seq[Attribute],
       outer: Boolean,
       generatorOutput: Seq[Attribute],
-      child: SparkPlan
-  ): GenerateExecTransformerBase = {
+      child: SparkPlan): GenerateExecTransformerBase = {
     GenerateExecTransformer(generator, requiredChildOutput, outer, generatorOutput, child)
   }
 
diff --git a/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index 8403aaebf92c3..74c010cfe26db 100644
--- a/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++ b/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -37,10 +37,10 @@ import org.apache.arrow.c.ArrowSchema
 import scala.collection.JavaConverters.asScalaIteratorConverter
 
 case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Array[Byte]])
-    extends BuildSideRelation {
-  private val runtime = Runtimes.contextInstance("BuildSideRelation#BatchSerializer")
+  extends BuildSideRelation {
 
   override def deserialized: Iterator[ColumnarBatch] = {
+    val runtime = Runtimes.contextInstance("BuildSideRelation#deserialized")
     val jniWrapper = ColumnarBatchSerializerJniWrapper.create(runtime)
     val serializeHandle: Long = {
       val allocator = ArrowBufferAllocators.contextInstance()
@@ -82,10 +82,11 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
   override def asReadOnlyCopy(): ColumnarBuildSideRelation = this
 
   /**
-   * Transform columnar broadcast value to Array[InternalRow] by key and distinct. NOTE: This method
-   * was called in Spark Driver, should manage resources carefully.
+   * Transform columnar broadcast value to Array[InternalRow] by key and distinct. NOTE: This
+   * method was called in Spark Driver, should manage resources carefully.
    */
   override def transform(key: Expression): Array[InternalRow] = TaskResources.runUnsafe {
+    val runtime = Runtimes.contextInstance("BuildSideRelation#transform")
     // This transformation happens in Spark driver, thus resources can not be managed automatically.
     val serializerJniWrapper = ColumnarBatchSerializerJniWrapper.create(runtime)
     val serializeHandle = {
@@ -147,20 +148,20 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
       throw new IllegalArgumentException(s"Key column not found in expression: $key")
     }
     if (columnNames.size != 1) {
-      throw new IllegalArgumentException(s"Multiple key columns found in expression: $key")
+      throw new IllegalArgumentException(
+        s"Multiple key columns found in expression: $key")
     }
     val columnExpr = columnNames.head
     val oneColumnWithSameName = output.count(_.name == columnExpr.name) == 1
-    val columnInOutput = output.zipWithIndex.filter {
-      p: (Attribute, Int) =>
-        if (oneColumnWithSameName) {
-          // The comparison of exprId can be ignored when
-          // only one attribute name match is found.
-          p._1.name == columnExpr.name
-        } else {
-          // A case where output has multiple columns with same name
-          p._1.name == columnExpr.name && p._1.exprId == columnExpr.exprId
-        }
+    val columnInOutput = output.zipWithIndex.filter { p: (Attribute, Int) =>
+      if (oneColumnWithSameName) {
+        // The comparison of exprId can be ignored when
+        // only one attribute name match is found.
+        p._1.name == columnExpr.name
+      } else {
+        // A case where output has multiple columns with same name
+        p._1.name == columnExpr.name && p._1.exprId == columnExpr.exprId
+      }
     }
     if (columnInOutput.isEmpty) {
       throw new IllegalStateException(
@@ -173,9 +174,8 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
     val replacement =
       BoundReference(columnInOutput.head._2, columnExpr.dataType, columnExpr.nullable)
 
-    val projExpr = key.transformDown {
-      case _: AttributeReference =>
-        replacement
+    val projExpr = key.transformDown { case _: AttributeReference =>
+      replacement
     }
 
     val proj = UnsafeProjection.create(projExpr)
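Note on the ColumnarBuildSideRelation hunks: the patch drops the class-level `private val runtime` and instead calls Runtimes.contextInstance(...) inside each method (`deserialized`, `transform`, each with its own label), so the runtime is obtained in whichever context, executor task or driver, actually executes the work, rather than being captured once when the relation is constructed and broadcast from the driver. Below is a minimal, self-contained Scala sketch of that pattern; ContextRuntime, acquire, and RelationLike are hypothetical stand-ins for illustration, not Gluten APIs.

// Hypothetical stand-in for Runtimes.contextInstance: returns a resource tied
// to the calling thread's context, tagged with the call-site label.
object ContextRuntime {
  def acquire(label: String): String =
    s"runtime[$label] on ${Thread.currentThread().getName}"
}

// A broadcast-style relation: constructed on the driver, but its methods may
// later run inside executor tasks after deserialization.
case class RelationLike(batches: Array[Array[Byte]]) {
  // Anti-pattern (what the patch removes): a field such as
  //   private val runtime = ContextRuntime.acquire("RelationLike#shared")
  // would bind the runtime to the construction context once, on the driver.

  def deserialized: Iterator[Array[Byte]] = {
    // Pattern the patch adopts: acquire the runtime per call, in the context
    // that actually runs the work, with a distinct label per method.
    val runtime = ContextRuntime.acquire("RelationLike#deserialized")
    println(s"deserializing with $runtime")
    batches.iterator
  }
}

object Demo extends App {
  // Prints the runtime label bound to the current (main) thread.
  RelationLike(Array(Array[Byte](1, 2, 3))).deserialized.foreach(_ => ())
}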