
Commit

fixup
zhztheplayer committed Jun 28, 2024
1 parent e048991 commit b264284
Showing 2 changed files with 36 additions and 30 deletions.
VeloxSparkPlanExecApi.scala
@@ -16,26 +16,28 @@
  */
 package org.apache.gluten.backendsapi.velox
 
-import org.apache.commons.lang3.ClassUtils
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.SparkPlanExecApi
 import org.apache.gluten.datasource.ArrowConvertorRule
 import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution._
-import org.apache.gluten.expression.ExpressionNames.{TRANSFORM_KEYS, TRANSFORM_VALUES}
 import org.apache.gluten.expression._
+import org.apache.gluten.expression.ExpressionNames.{TRANSFORM_KEYS, TRANSFORM_VALUES}
 import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
 import org.apache.gluten.extension._
 import org.apache.gluten.extension.columnar.TransformHints
 import org.apache.gluten.extension.columnar.transition.Convention
 import org.apache.gluten.extension.columnar.transition.ConventionFunc.BatchOverride
 import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.vectorized.{ColumnarBatchSerializeResult, ColumnarBatchSerializer}
+import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSerializeResult}
+
+import org.apache.spark.{ShuffleDependency, SparkException}
 import org.apache.spark.api.python.{ColumnarArrowEvalPythonExec, PullOutArrowEvalPythonPreProjectHelper}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.utils.ShuffleUtil
 import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper}
+import org.apache.spark.shuffle.utils.ShuffleUtil
+import org.apache.spark.sql.{SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -59,10 +61,11 @@ import org.apache.spark.sql.expression.{UDFExpression, UDFResolver, UserDefinedA
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.spark.sql.{SparkSession, Strategy}
-import org.apache.spark.{ShuffleDependency, SparkException}
 
+import org.apache.commons.lang3.ClassUtils
+
 import javax.ws.rs.core.UriBuilder
 
 import scala.collection.mutable.ListBuffer
 
 class VeloxSparkPlanExecApi extends SparkPlanExecApi {
@@ -73,9 +76,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
   }
 
   /**
-   * Overrides [[org.apache.gluten.extension.columnar.transition.ConventionFunc]] Gluten is using
-   * to determine the convention (its row-based processing / columnar-batch processing support) of
-   * a plan with a user-defined function that accepts a plan then returns batch type it outputs.
+   * Overrides [[org.apache.gluten.extension.columnar.transition.ConventionFunc]] Gluten is using to
+   * determine the convention (its row-based processing / columnar-batch processing support) of a
+   * plan with a user-defined function that accepts a plan then returns batch type it outputs.
    */
   override def batchTypeFunc(): BatchOverride = {
     case i: InMemoryTableScanExec
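
Note: the BatchOverride above behaves as a partial function from a plan node to the batch type it emits; plans the function does not match fall back to Gluten's default convention resolution. A minimal, self-contained sketch of that pattern, using stand-in types rather than Gluten's real API:

    // Stand-ins for illustration only; not Gluten's Convention/SparkPlan types.
    sealed trait BatchType
    case object VeloxBatch extends BatchType
    case object DefaultBatch extends BatchType

    sealed trait PlanNode
    case class InMemoryScan(supportsColumnar: Boolean) extends PlanNode
    case class RowScan() extends PlanNode

    // The override decides the batch type only for plans it explicitly recognizes.
    val batchOverride: PartialFunction[PlanNode, BatchType] = {
      case InMemoryScan(true) => VeloxBatch
    }

    // Everything else falls through to the default convention.
    def batchTypeOf(plan: PlanNode): BatchType =
      batchOverride.applyOrElse(plan, (_: PlanNode) => DefaultBatch)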
@@ -155,8 +158,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     original.dataType match {
       case LongType | IntegerType | ShortType | ByteType =>
       case _ =>
-        throw new GlutenNotSupportException(
-          s"$substraitExprName with try mode is not supported")
+        throw new GlutenNotSupportException(s"$substraitExprName with try mode is not supported")
     }
     // Offload to velox for only IntegralTypes.
     GenericExpressionTransformer(
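
Note: the "try mode" rejected here for non-integral types appears to correspond to Spark's try-evaluation expressions (try_add and friends), which return NULL on overflow instead of raising an error; only IntegralTypes are offloaded. A plain-Spark illustration of those semantics, assuming an active SparkSession named spark:

    // try_add yields NULL on int overflow rather than failing the query.
    spark.sql("SELECT try_add(2147483647, 1)").show() // NULL
    // The plain operator under ANSI mode would instead raise an overflow error:
    // spark.conf.set("spark.sql.ansi.enabled", "true")
    // spark.sql("SELECT 2147483647 + 1").show()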
@@ -683,8 +685,10 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       substraitExprName: String,
       children: Seq[ExpressionTransformer],
       expr: Expression): ExpressionTransformer = {
-    if (SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY)
-      != SQLConf.MapKeyDedupPolicy.EXCEPTION.toString) {
+    if (
+      SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY)
+        != SQLConf.MapKeyDedupPolicy.EXCEPTION.toString
+    ) {
       throw new GlutenNotSupportException("Only EXCEPTION policy is supported!")
     }
     GenericExpressionTransformer(substraitExprName, children, expr)
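
Note: spark.sql.mapKeyDedupPolicy decides what Spark does with duplicate map keys — EXCEPTION (the default) fails the query, LAST_WIN silently keeps the last value. The guard above declines offload unless duplicates are a hard error, presumably so the native side never has to reproduce LAST_WIN semantics. A plain-Spark sketch, assuming an active SparkSession named spark:

    spark.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")
    spark.sql("SELECT map_concat(map(1, 'a'), map(1, 'b'))").show() // {1 -> "b"}

    spark.conf.set("spark.sql.mapKeyDedupPolicy", "EXCEPTION")
    // The same query now fails at runtime with a duplicate-map-key error:
    // spark.sql("SELECT map_concat(map(1, 'a'), map(1, 'b'))").show()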
@@ -839,7 +843,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       Sig[TransformKeys](TRANSFORM_KEYS),
       Sig[TransformValues](TRANSFORM_VALUES),
       // For test purpose.
-      Sig[VeloxDummyExpression](VeloxDummyExpression.VELOX_DUMMY_EXPRESSION))
+      Sig[VeloxDummyExpression](VeloxDummyExpression.VELOX_DUMMY_EXPRESSION)
+    )
   }
 
   override def genInjectedFunctions()
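
Note: each Sig entry above pairs an expression class with the name the backend registers it under, giving a lookup table from Catalyst expressions to native function names. A simplified sketch of that mechanism, with a stand-in ExprSig rather than Gluten's Sig:

    import scala.reflect.ClassTag

    final case class ExprSig(clazz: Class[_], name: String)
    object ExprSig {
      // Mirrors the Sig[TransformKeys](TRANSFORM_KEYS) call style above.
      def apply[T: ClassTag](name: String): ExprSig =
        ExprSig(implicitly[ClassTag[T]].runtimeClass, name)
    }

    final class TransformKeysLike // placeholder for a real expression class

    val sigs = Seq(ExprSig[TransformKeysLike]("transform_keys"))
    // Lookup table from expression class to native name:
    val nativeNameByClass: Map[Class[_], String] =
      sigs.map(s => s.clazz -> s.name).toMap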
ColumnarBuildSideRelation.scala
@@ -37,7 +37,7 @@ import org.apache.arrow.c.ArrowSchema
 import scala.collection.JavaConverters.asScalaIteratorConverter
 
 case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Array[Byte]])
-  extends BuildSideRelation {
+    extends BuildSideRelation {
 
   override def deserialized: Iterator[ColumnarBatch] = {
     val runtime = Runtimes.contextInstance("BuildSideRelation#deserialized")
@@ -82,8 +82,8 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
   override def asReadOnlyCopy(): ColumnarBuildSideRelation = this
 
   /**
-   * Transform columnar broadcast value to Array[InternalRow] by key and distinct. NOTE: This
-   * method was called in Spark Driver, should manage resources carefully.
+   * Transform columnar broadcast value to Array[InternalRow] by key and distinct. NOTE: This method
+   * was called in Spark Driver, should manage resources carefully.
    */
   override def transform(key: Expression): Array[InternalRow] = TaskResources.runUnsafe {
     val runtime = Runtimes.contextInstance("BuildSideRelation#transform")
@@ -148,20 +148,20 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
       throw new IllegalArgumentException(s"Key column not found in expression: $key")
     }
     if (columnNames.size != 1) {
-      throw new IllegalArgumentException(
-        s"Multiple key columns found in expression: $key")
+      throw new IllegalArgumentException(s"Multiple key columns found in expression: $key")
     }
     val columnExpr = columnNames.head
     val oneColumnWithSameName = output.count(_.name == columnExpr.name) == 1
-    val columnInOutput = output.zipWithIndex.filter { p: (Attribute, Int) =>
-      if (oneColumnWithSameName) {
-        // The comparison of exprId can be ignored when
-        // only one attribute name match is found.
-        p._1.name == columnExpr.name
-      } else {
-        // A case where output has multiple columns with same name
-        p._1.name == columnExpr.name && p._1.exprId == columnExpr.exprId
-      }
+    val columnInOutput = output.zipWithIndex.filter {
+      p: (Attribute, Int) =>
+        if (oneColumnWithSameName) {
+          // The comparison of exprId can be ignored when
+          // only one attribute name match is found.
+          p._1.name == columnExpr.name
+        } else {
+          // A case where output has multiple columns with same name
+          p._1.name == columnExpr.name && p._1.exprId == columnExpr.exprId
+        }
     }
     if (columnInOutput.isEmpty) {
       throw new IllegalStateException(
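
Note: the filter above resolves the key against the output schema by name alone when that name is unique, and only falls back to comparing exprId when several attributes share the name. A self-contained sketch of the rule, using a stand-in Attr in place of Spark's Attribute:

    case class Attr(name: String, exprId: Long)

    def locateKey(output: Seq[Attr], key: Attr): Seq[(Attr, Int)] = {
      val nameIsUnique = output.count(_.name == key.name) == 1
      output.zipWithIndex.filter {
        case (attr, _) =>
          if (nameIsUnique) attr.name == key.name // exprId can be ignored
          else attr.name == key.name && attr.exprId == key.exprId
      }
    }

    // locateKey(Seq(Attr("id", 1), Attr("id", 2)), Attr("id", 2))
    //   => Seq((Attr("id", 2), 1)), the ambiguous name resolved by exprId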
Expand All @@ -174,8 +174,9 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
val replacement =
BoundReference(columnInOutput.head._2, columnExpr.dataType, columnExpr.nullable)

val projExpr = key.transformDown { case _: AttributeReference =>
replacement
val projExpr = key.transformDown {
case _: AttributeReference =>
replacement
}

val proj = UnsafeProjection.create(projExpr)
Expand Down
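
Note: the replacement-plus-projection step is the usual Catalyst pattern of binding an expression to a column ordinal and then evaluating it row by row with an UnsafeProjection. A minimal sketch against Spark's internal API; the ordinal and row contents are made up for illustration:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
    import org.apache.spark.sql.types.IntegerType
    import org.apache.spark.unsafe.types.UTF8String

    // Bind "the key lives at ordinal 1", bypassing attribute resolution.
    val keyRef = BoundReference(1, IntegerType, nullable = false)
    val proj = UnsafeProjection.create(Seq(keyRef))

    val row = InternalRow(UTF8String.fromString("x"), 42)
    println(proj(row).getInt(0)) // 42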
