Skip to content

Commit

Permalink
Resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Aug 14, 2024
1 parent 6fa72ee commit 0bc774f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ private[gluten] class GlutenDriverPlugin extends DriverPlugin with Logging {
BackendsApiManager.initialize()
BackendsApiManager.getListenerApiInstance.onDriverStart(sc, pluginContext)
GlutenListenerFactory.addToSparkListenerBus(sc)
ExpressionMappings.expressionExtensionTransformer =
ExpressionMappings.setExpressionExtensionTransformer(
ExpressionUtil.extendedExpressionTransformer(
conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "")
)
))
Collections.emptyMap()
}

Expand Down Expand Up @@ -273,10 +273,10 @@ private[gluten] class GlutenExecutorPlugin extends ExecutorPlugin {
BackendsApiManager.initialize()
BackendsApiManager.getListenerApiInstance.onExecutorStart(ctx)

ExpressionMappings.expressionExtensionTransformer =
ExpressionMappings.setExpressionExtensionTransformer(
ExpressionUtil.extendedExpressionTransformer(
conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "")
)
))
}

/** Clean up and terminate this plugin. For example: close the native engine. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ object AggregateFunctionsBuilder {
// First handle the custom aggregate functions
val (substraitAggFuncName, inputTypes) =
if (
ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
ExpressionMappings.getExpressionExtensionTransformer.extensionExpressionsMapping.contains(
aggregateFunc.getClass)
) {
val (substraitAggFuncName, inputTypes) =
ExpressionMappings.expressionExtensionTransformer.buildCustomAggregateFunction(
ExpressionMappings.getExpressionExtensionTransformer.buildCustomAggregateFunction(
aggregateFunc)
assert(substraitAggFuncName.isDefined)
(substraitAggFuncName.get, inputTypes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ object ExpressionConverter extends SQLConfHelper with Logging {

expr match {
case extendedExpr
if ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
extendedExpr.getClass) =>
if ExpressionMappings.getExpressionExtensionTransformer.extensionExpressionsMapping
.contains(extendedExpr.getClass) =>
// Use extended expression transformer to replace custom expression first
ExpressionMappings.expressionExtensionTransformer
ExpressionMappings.getExpressionExtensionTransformer
.replaceWithExtensionExpressionTransformer(substraitExprName, extendedExpr, attributeSeq)
case c: CreateArray =>
val children =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.gluten.expression
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.ExpressionNames._
import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer, ExpressionExtensionTrait}
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -339,7 +339,7 @@ object ExpressionMappings {
def expressionsMap: Map[Class[_], String] = {
val blacklist = GlutenConfig.getConf.expressionBlacklist
val supportedExprs = defaultExpressionsMap ++
expressionExtensionTransformer.extensionExpressionsMapping
expressionExtensionTransformer.get.extensionExpressionsMapping
if (blacklist.isEmpty) {
supportedExprs
} else {
Expand All @@ -354,5 +354,16 @@ object ExpressionMappings {
.toMap[Class[_], String]
}

var expressionExtensionTransformer: ExpressionExtensionTrait = _
private var expressionExtensionTransformer: Option[ExpressionExtensionTrait] = None

def getExpressionExtensionTransformer: ExpressionExtensionTrait = {
expressionExtensionTransformer.getOrElse(new DefaultExpressionExtensionTransformer)
}

def setExpressionExtensionTransformer(value: ExpressionExtensionTrait): Unit = {
if (!expressionExtensionTransformer.isDefined) {
expressionExtensionTransformer = Some(value)
}
}

}

0 comments on commit 0bc774f

Please sign in to comment.