[GLUTEN-6950][CORE] Move specific rules into backend modules (apache#…
zhztheplayer authored and shamirchen committed Oct 14, 2024
1 parent 88ec5fd commit d7488e5
Showing 17 changed files with 215 additions and 161 deletions.
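For orientation: this commit removes backend-capability switches from the shared BackendSettingsApi trait and instead has each backend's RuleApi object inject exactly the rules it needs (see the CHRuleApi and VeloxRuleApi hunks below). The following is a minimal, self-contained sketch of that injection pattern; MiniInjector, MiniBackend and the string-based rules are hypothetical stand-ins, not Gluten's actual classes.

object RuleInjectionSketch {
  // Stand-in for Rule[SparkPlan]: a plan-to-plan transformation.
  type MiniRule = String => String

  // Stand-in for Gluten's rule injector: collects rules and applies them in order.
  final class MiniInjector {
    private val rules = scala.collection.mutable.Buffer.empty[MiniRule]
    def injectTransform(rule: MiniRule): Unit = rules += rule
    def run(plan: String): String = rules.foldLeft(plan)((p, r) => r(p))
  }

  // Each backend contributes only the rules it needs, instead of gating shared rules
  // behind BackendSettingsApi flags in the core module.
  trait MiniBackend { def injectRules(injector: MiniInjector): Unit }

  object VeloxLike extends MiniBackend {
    def injectRules(injector: MiniInjector): Unit =
      injector.injectTransform(p => s"FallbackEmptySchemaRelation($p)")
  }

  object ClickHouseLike extends MiniBackend {
    def injectRules(injector: MiniInjector): Unit =
      injector.injectTransform(p => s"MergeTwoPhasesHashBaseAggregate($p)")
  }

  def main(args: Array[String]): Unit = {
    val injector = new MiniInjector
    VeloxLike.injectRules(injector)
    println(injector.run("Scan(t)")) // FallbackEmptySchemaRelation(Scan(t))
  }
}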
@@ -380,8 +380,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
}

override def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = true

override def supportCartesianProductExec(): Boolean = true

override def supportHashBuildJoinTypeOnLeft: JoinType => Boolean = {
@@ -64,10 +64,8 @@ private object CHRuleApi {
injector.injectTransform(_ => RemoveTransitions)
injector.injectTransform(c => FallbackOnANSIMode.apply(c.session))
injector.injectTransform(c => FallbackMultiCodegens.apply(c.session))
injector.injectTransform(c => PlanOneRowRelation.apply(c.session))
injector.injectTransform(_ => RewriteSubqueryBroadcast())
injector.injectTransform(c => FallbackBroadcastHashJoin.apply(c.session))
injector.injectTransform(_ => FallbackEmptySchemaRelation())
injector.injectTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.session))
injector.injectTransform(_ => RewriteSparkPlanRulesManager())
injector.injectTransform(_ => AddFallbackTagRule())
@@ -21,8 +21,7 @@ import org.apache.gluten.GlutenConfig
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
@@ -14,10 +14,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.columnar
package org.apache.gluten.extension

import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial}
@@ -39,7 +38,7 @@ case class MergeTwoPhasesHashBaseAggregate(session: SparkSession) extends Rule[S
val columnarConf: GlutenConfig = GlutenConfig.getConf
val scanOnly: Boolean = columnarConf.enableScanOnly
val enableColumnarHashAgg: Boolean = !scanOnly && columnarConf.enableColumnarHashAgg
val replaceSortAggWithHashAgg = BackendsApiManager.getSettings.replaceSortAggWithHashAgg
val replaceSortAggWithHashAgg: Boolean = GlutenConfig.getConf.forceToUseHashAgg

private def isPartialAgg(partialAgg: BaseAggregateExec, finalAgg: BaseAggregateExec): Boolean = {
// TODO: now it can not support to merge agg which there are the filters in the aggregate exprs.
@@ -57,10 +56,7 @@ case class MergeTwoPhasesHashBaseAggregate(session: SparkSession) extends Rule[S
}

override def apply(plan: SparkPlan): SparkPlan = {
if (
!enableColumnarHashAgg || !BackendsApiManager.getSettings
.mergeTwoPhasesHashBaseAggregateIfNeed()
) {
if (!enableColumnarHashAgg) {
plan
} else {
plan.transformDown {
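The file above keeps MergeTwoPhasesHashBaseAggregate as a ClickHouse-only rule: it collapses an adjacent partial/final hash-aggregate pair into a single complete-mode aggregate. Below is a rough, self-contained illustration of that shape using simplified stand-in node types (not Spark's BaseAggregateExec; the filter and grouping checks of the real rule are omitted).

sealed trait AggMode
case object PartialMode extends AggMode
case object FinalMode extends AggMode
case object CompleteMode extends AggMode

sealed trait MiniNode
case class MiniAgg(mode: AggMode, groupingKeys: Seq[String], child: MiniNode) extends MiniNode
case class MiniScan(table: String) extends MiniNode

object MergeTwoPhasesSketch {
  // Merge Final(Partial(child)) into Complete(child) when both phases share grouping keys,
  // mirroring the intent of MergeTwoPhasesHashBaseAggregate.
  def merge(plan: MiniNode): MiniNode = plan match {
    case MiniAgg(FinalMode, keys, MiniAgg(PartialMode, innerKeys, child)) if keys == innerKeys =>
      MiniAgg(CompleteMode, keys, merge(child))
    case MiniAgg(mode, keys, child) => MiniAgg(mode, keys, merge(child))
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val twoPhase = MiniAgg(FinalMode, Seq("k"), MiniAgg(PartialMode, Seq("k"), MiniScan("t")))
    println(merge(twoPhase)) // MiniAgg(CompleteMode,List(k),MiniScan(t))
  }
}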
@@ -27,8 +27,6 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import java.lang.IllegalArgumentException

// For readable, people usually convert a unix timestamp into date, and compare it with another
// date. For example
// select * from table where '2023-11-02' >= from_unixtime(unix_timestamp, 'yyyy-MM-dd')
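The comment above concerns predicates such as '2023-11-02' >= from_unixtime(unix_timestamp, 'yyyy-MM-dd'). The sketch below only illustrates the algebraic equivalence such a rewrite can rely on, namely that the formatted-string comparison matches a raw range predicate on the epoch column; it is not the actual output of this ClickHouse-backend rule, and it assumes only a plain Spark dependency on the classpath.

import org.apache.spark.sql.SparkSession

object DateComparisonEquivalenceDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("rewrite-demo").getOrCreate()
    // Treat id + 1698796800 as an epoch-second column (1698796800 ~ 2023-11-01 00:00 UTC).
    spark.range(0, 200000, 1000).createOrReplaceTempView("t")

    // Original form: format every row's timestamp, then compare strings.
    val original = spark.sql(
      "SELECT id FROM t WHERE '2023-11-02' >= from_unixtime(id + 1698796800, 'yyyy-MM-dd')")

    // Equivalent form: compare the raw epoch value against one boundary computed once.
    val rewritten = spark.sql(
      "SELECT id FROM t WHERE id + 1698796800 < unix_timestamp('2023-11-03', 'yyyy-MM-dd')")

    assert(original.collect().toSet == rewritten.collect().toSet)
    spark.stop()
  }
}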
@@ -28,12 +28,10 @@ import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFo
import org.apache.gluten.utils._

import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, EulerNumber, Expression, Lag, Lead, Literal, MakeYMInterval, NamedExpression, NthValue, NTile, PercentRank, Pi, Rand, RangeFrame, Rank, RowNumber, SortOrder, SparkPartitionID, SparkVersion, SpecialFrameBoundary, SpecifiedWindowFrame, Uuid}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ApproximatePercentile, Count, Sum}
import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Lag, Lead, NamedExpression, NthValue, NTile, PercentRank, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ApproximatePercentile}
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
import org.apache.spark.sql.execution.datasources.{FileFormat, InsertIntoHadoopFsRelationCommand}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -443,49 +441,6 @@ object VeloxBackendSettings extends BackendSettingsApi {
}
}

/**
* Check whether a plan needs to be offloaded even though they have empty input schema, e.g,
* Sum(1), Count(1), rand(), etc.
* @param plan:
* The Spark plan to check.
*/
private def mayNeedOffload(plan: SparkPlan): Boolean = {
def checkExpr(expr: Expression): Boolean = {
expr match {
// Block directly falling back the below functions by FallbackEmptySchemaRelation.
case alias: Alias => checkExpr(alias.child)
case _: Rand | _: Uuid | _: MakeYMInterval | _: SparkPartitionID | _: EulerNumber | _: Pi |
_: SparkVersion =>
true
case _ => false
}
}

plan match {
case exec: HashAggregateExec if exec.aggregateExpressions.nonEmpty =>
// Check Sum(Literal) or Count(Literal).
exec.aggregateExpressions.forall(
expression => {
val aggFunction = expression.aggregateFunction
aggFunction match {
case Sum(Literal(_, _), _) => true
case Count(Seq(Literal(_, _))) => true
case _ => false
}
})
case p: ProjectExec if p.projectList.nonEmpty =>
p.projectList.forall(checkExpr(_))
case _ =>
false
}
}

override def fallbackOnEmptySchema(plan: SparkPlan): Boolean = {
// Count(1) and Sum(1) are special cases that Velox backend can handle.
// Do not fallback it and its children in the first place.
!mayNeedOffload(plan)
}

override def fallbackAggregateWithEmptyOutputChild(): Boolean = true

override def recreateJoinExecOnFallback(): Boolean = true
@@ -18,7 +18,8 @@ package org.apache.gluten.backendsapi.velox

import org.apache.gluten.backendsapi.RuleApi
import org.apache.gluten.datasource.ArrowConvertorRule
import org.apache.gluten.extension._
import org.apache.gluten.extension.{ArrowScanReplaceRule, BloomFilterMightContainJointRewriteRule, CollectRewriteRule, FlushableHashAggregateRule, HLLRewriteRule}
import org.apache.gluten.extension.EmptySchemaWorkaround.{FallbackEmptySchemaRelation, PlanOneRowRelation}
import org.apache.gluten.extension.columnar._
import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides}
import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform
@@ -61,7 +62,6 @@ private object VeloxRuleApi {
injector.injectTransform(c => BloomFilterMightContainJointRewriteRule.apply(c.session))
injector.injectTransform(c => ArrowScanReplaceRule.apply(c.session))
injector.injectTransform(_ => FallbackEmptySchemaRelation())
injector.injectTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.session))
injector.injectTransform(_ => RewriteSparkPlanRulesManager())
injector.injectTransform(_ => AddFallbackTagRule())
injector.injectTransform(_ => TransformPreOverrides())
@@ -103,7 +103,6 @@ private object VeloxRuleApi {
injector.inject(_ => RewriteSubqueryBroadcast())
injector.inject(c => BloomFilterMightContainJointRewriteRule.apply(c.session))
injector.inject(c => ArrowScanReplaceRule.apply(c.session))
injector.inject(c => MergeTwoPhasesHashBaseAggregate.apply(c.session))

// Gluten RAS: The RAS rule.
injector.inject(c => EnumeratedTransform(c.session, c.outputsColumnar))
@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension

import org.apache.gluten.GlutenConfig
import org.apache.gluten.extension.columnar.FallbackTags

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, EulerNumber, Expression, Literal, MakeYMInterval, Pi, Rand, SparkPartitionID, SparkVersion, Uuid}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Sum}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ProjectExec, RDDScanExec, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources.WriteFilesExec
import org.apache.spark.sql.types.StringType

/** Rules to make Velox backend work correctly with query plans that have empty output schemas. */
object EmptySchemaWorkaround {

/**
* This rule plans [[RDDScanExec]] with a fake schema to make gluten work, because gluten does not
* support empty output relation, see [[FallbackEmptySchemaRelation]].
*/
case class PlanOneRowRelation(spark: SparkSession) extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
if (!GlutenConfig.getConf.enableOneRowRelationColumnar) {
return plan
}

plan.transform {
// We should make sure the output does not change, e.g.
// Window
// OneRowRelation
case u: UnaryExecNode
if u.child.isInstanceOf[RDDScanExec] &&
u.child.asInstanceOf[RDDScanExec].name == "OneRowRelation" &&
u.outputSet != u.child.outputSet =>
val rdd = spark.sparkContext.parallelize(InternalRow(null) :: Nil, 1)
val attr = AttributeReference("fake_column", StringType)()
u.withNewChildren(RDDScanExec(attr :: Nil, rdd, "OneRowRelation") :: Nil)
}
}
}

/**
* FIXME To be removed: Since Velox backend is the only one to use the strategy, and we already
* support offloading zero-column batch in ColumnarBatchInIterator via PR #3309.
*
* We'd make sure all Velox operators be able to handle zero-column input correctly then remove
* the rule together with [[PlanOneRowRelation]].
*/
case class FallbackEmptySchemaRelation() extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
case p =>
if (fallbackOnEmptySchema(p)) {
if (p.children.exists(_.output.isEmpty)) {
// Some backends are not eligible to offload plan with zero-column input.
// If any child have empty output, mark the plan and that child as UNSUPPORTED.
FallbackTags.add(p, "at least one of its children has empty output")
p.children.foreach {
child =>
if (child.output.isEmpty && !child.isInstanceOf[WriteFilesExec]) {
FallbackTags.add(child, "at least one of its children has empty output")
}
}
}
}
p
}

private def fallbackOnEmptySchema(plan: SparkPlan): Boolean = {
// Count(1) and Sum(1) are special cases that Velox backend can handle.
// Do not fallback it and its children in the first place.
!mayNeedOffload(plan)
}

/**
* Check whether a plan needs to be offloaded even though they have empty input schema, e.g,
* Sum(1), Count(1), rand(), etc.
* @param plan:
* The Spark plan to check.
*
* Since https://github.com/apache/incubator-gluten/pull/2749.
*/
private def mayNeedOffload(plan: SparkPlan): Boolean = {
def checkExpr(expr: Expression): Boolean = {
expr match {
// Block directly falling back the below functions by FallbackEmptySchemaRelation.
case alias: Alias => checkExpr(alias.child)
case _: Rand | _: Uuid | _: MakeYMInterval | _: SparkPartitionID | _: EulerNumber |
_: Pi | _: SparkVersion =>
true
case _ => false
}
}

plan match {
case exec: HashAggregateExec if exec.aggregateExpressions.nonEmpty =>
// Check Sum(Literal) or Count(Literal).
exec.aggregateExpressions.forall(
expression => {
val aggFunction = expression.aggregateFunction
aggFunction match {
case Sum(Literal(_, _), _) => true
case Count(Seq(Literal(_, _))) => true
case _ => false
}
})
case p: ProjectExec if p.projectList.nonEmpty =>
p.projectList.forall(checkExpr(_))
case _ =>
false
}
}
}
}
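As a usage-level illustration of what this new file targets, the sketch below (plain Spark, no Gluten required to run it; it only prints physical plans) shows the two plan shapes involved: an aggregate over literals whose child needs no columns, and a projection of non-deterministic or constant expressions over OneRowRelation, which PlanOneRowRelation would re-plan with a fake one-column RDD scan under Gluten.

import org.apache.spark.sql.SparkSession

object EmptySchemaExamples {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("empty-schema").getOrCreate()
    spark.range(0, 100).createOrReplaceTempView("t")

    // HashAggregateExec with Count(Literal(1)) and Sum(Literal(1)): the aggregate reads zero
    // columns from t, which mayNeedOffload treats as offloadable instead of forcing a fallback.
    spark.sql("SELECT count(1), sum(1) FROM t").explain()

    // ProjectExec over OneRowRelation: rand()/uuid() need no input columns either.
    spark.sql("SELECT rand(), uuid()").explain()

    spark.stop()
  }
}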
@@ -69,7 +69,6 @@ trait BackendSettingsApi {
case _ => false
}
def supportStructType(): Boolean = false
def fallbackOnEmptySchema(plan: SparkPlan): Boolean = false

// Whether to fallback aggregate at the same time if its empty-output child is fallen back.
def fallbackAggregateWithEmptyOutputChild(): Boolean = false
@@ -90,12 +89,6 @@
def excludeScanExecFromCollapsedStage(): Boolean = false
def rescaleDecimalArithmetic: Boolean = false

/**
* Whether to replace sort agg with hash agg., e.g., sort agg will be used in spark's planning for
* string type input.
*/
def replaceSortAggWithHashAgg: Boolean = GlutenConfig.getConf.forceToUseHashAgg

/** Get the config prefix for each backend */
def getBackendConfigPrefix: String

@@ -147,9 +140,6 @@ trait BackendSettingsApi {

def supportSampleExec(): Boolean = false

/** Merge two phases hash based aggregate if need */
def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false

def supportColumnarArrowUdf(): Boolean = false

def generateHdfsConfForLibhdfs(): Boolean = false
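For comparison, here is a hypothetical miniature of the BackendSettingsApi pattern that remains after this commit: the trait ships conservative defaults and each backend overrides only the capabilities it actually supports (the names and config prefix below are illustrative, not Gluten's exact values).

trait MiniBackendSettings {
  // Conservative defaults: a backend must opt in to each capability explicitly.
  def supportCartesianProductExec(): Boolean = false
  def supportSampleExec(): Boolean = false
  def getBackendConfigPrefix: String
}

object MiniClickHouseSettings extends MiniBackendSettings {
  override def supportCartesianProductExec(): Boolean = true
  def getBackendConfigPrefix: String = "spark.gluten.sql.columnar.backend.ch" // illustrative value
}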