Commit

[VL] Rework the implementation of spark.gluten.enabled (#7672)
zhztheplayer authored Oct 25, 2024
1 parent c1ab7b3 commit 619624a
Showing 33 changed files with 423 additions and 305 deletions.
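For context, a minimal sketch of the two switches this commit touches: the `spark.gluten.enabled` SQL conf named in the title, and the per-thread local property that the diffs below read and write via `GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY`. The demo object is hypothetical; it assumes a Spark build with the Gluten extensions installed, and the precedence between the two switches is not shown in the hunks on this page.

```scala
import org.apache.gluten.extension.GlutenSessionExtensions
import org.apache.spark.sql.SparkSession

object GlutenEnabledDemo {
  def main(args: Array[String]): Unit = {
    // Assumes Gluten is on the classpath and its session extensions are registered.
    val spark = SparkSession.builder().appName("gluten-enabled-demo").getOrCreate()

    // Session-level switch: the SQL conf named in the commit title.
    spark.conf.set("spark.gluten.enabled", "false") // queries fall back to vanilla Spark
    spark.conf.set("spark.gluten.enabled", "true")

    // Thread-level switch: the local property the VACUUM code below saves, forces to
    // "false" for the duration of the command, and then restores.
    spark.sparkContext.setLocalProperty(
      GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY,
      "false")
  }
}
```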
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.delta.commands

- import org.apache.gluten.utils.QueryPlanSelector
+ import org.apache.gluten.extension.GlutenSessionExtensions

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.delta._
@@ -145,9 +145,11 @@ object VacuumCommand extends VacuumCommandImpl with Serializable {

// --- modified start
val originalEnabledGluten =
- spark.sparkContext.getLocalProperty(QueryPlanSelector.GLUTEN_ENABLE_FOR_THREAD_KEY)
+ spark.sparkContext.getLocalProperty(GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY)
// gluten can not support vacuum command
- spark.sparkContext.setLocalProperty(QueryPlanSelector.GLUTEN_ENABLE_FOR_THREAD_KEY, "false")
+ spark.sparkContext.setLocalProperty(
+ GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY,
+ "false")
// --- modified end

val validFiles = snapshot.stateDS
@@ -284,31 +286,37 @@ object VacuumCommand extends VacuumCommandImpl with Serializable {
} else {
allFilesAndDirs
.where('modificationTime < deleteBeforeTimestamp || 'isDir)
- .mapPartitions { fileStatusIterator =>
- val reservoirBase = new Path(basePath)
- val fs = reservoirBase.getFileSystem(hadoopConf.value.value)
- fileStatusIterator.flatMap { fileStatus =>
- if (fileStatus.isDir) {
- Iterator.single(relativize(fileStatus.getPath, fs, reservoirBase, isDir = true))
- } else {
- val dirs = getAllSubdirs(basePath, fileStatus.path, fs)
- val dirsWithSlash = dirs.map { p =>
- relativize(new Path(p), fs, reservoirBase, isDir = true)
- }
- dirsWithSlash ++ Iterator(
- relativize(new Path(fileStatus.path), fs, reservoirBase, isDir = false))
+ .mapPartitions {
+ fileStatusIterator =>
+ val reservoirBase = new Path(basePath)
+ val fs = reservoirBase.getFileSystem(hadoopConf.value.value)
+ fileStatusIterator.flatMap {
+ fileStatus =>
+ if (fileStatus.isDir) {
+ Iterator.single(
+ relativize(fileStatus.getPath, fs, reservoirBase, isDir = true))
+ } else {
+ val dirs = getAllSubdirs(basePath, fileStatus.path, fs)
+ val dirsWithSlash = dirs.map {
+ p => relativize(new Path(p), fs, reservoirBase, isDir = true)
+ }
+ dirsWithSlash ++ Iterator(
+ relativize(new Path(fileStatus.path), fs, reservoirBase, isDir = false))
}
}
}
- }.groupBy($"value" as 'path)
}
+ .groupBy($"value".as('path))
.count()
.join(validFiles, Seq("path"), "leftanti")
.where('count === 1)
.select('path)
.as[String]
- .map { relativePath =>
- assert(!stringToPath(relativePath).isAbsolute,
- "Shouldn't have any absolute paths for deletion here.")
- pathToString(DeltaFileOperations.absolutePath(basePath, relativePath))
+ .map {
+ relativePath =>
+ assert(
+ !stringToPath(relativePath).isAbsolute,
+ "Shouldn't have any absolute paths for deletion here.")
+ pathToString(DeltaFileOperations.absolutePath(basePath, relativePath))
}
}
// --- modified end
@@ -371,10 +379,12 @@ object VacuumCommand extends VacuumCommandImpl with Serializable {
// --- modified start
if (originalEnabledGluten != null) {
spark.sparkContext.setLocalProperty(
- QueryPlanSelector.GLUTEN_ENABLE_FOR_THREAD_KEY, originalEnabledGluten)
+ GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY,
+ originalEnabledGluten)
} else {
spark.sparkContext.setLocalProperty(
- QueryPlanSelector.GLUTEN_ENABLE_FOR_THREAD_KEY, "true")
+ GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY,
+ "true")
}
// --- modified end
}
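Each VacuumCommand variant in this commit repeats the same dance: save the thread-local flag, force it to "false" while the vacuum runs, then restore it. A loan-style helper makes the shape easier to see. This is an illustrative sketch, not code from the commit; it adds a try/finally the diff does not use, while mirroring the diff's restore rule (put the saved value back, defaulting to "true" when nothing was set).

```scala
import org.apache.gluten.extension.GlutenSessionExtensions
import org.apache.spark.sql.SparkSession

// Hypothetical helper, for illustration only.
object GlutenThreadSwitch {
  /** Runs `body` with Gluten disabled for the calling thread, then restores the flag. */
  def withGlutenDisabled[T](spark: SparkSession)(body: => T): T = {
    val sc = spark.sparkContext
    val key = GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY
    // Remember whatever the caller had set for this thread (may be null).
    val original = sc.getLocalProperty(key)
    sc.setLocalProperty(key, "false")
    try body
    finally {
      // Same restore rule as the diff: saved value if present, otherwise "true".
      sc.setLocalProperty(key, if (original != null) original else "true")
    }
  }
}
```

A caller would then wrap the unsupported command as `GlutenThreadSwitch.withGlutenDisabled(spark) { /* vacuum body */ }`.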
@@ -21,19 +21,16 @@ package org.apache.spark.sql.delta.commands
import java.net.URI
import java.util.Date
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._

import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.actions.{FileAction, RemoveFile}
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.util.DeltaFileOperations
import org.apache.spark.sql.delta.util.DeltaFileOperations.tryDeleteNonRecursive
import com.fasterxml.jackson.databind.annotation.JsonDeserialize
+ import org.apache.gluten.extension.GlutenSessionExtensions
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

- import org.apache.gluten.utils.QueryPlanSelector
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
@@ -161,9 +158,9 @@ object VacuumCommand extends VacuumCommandImpl with Serializable {

// --- modified start
val originalEnabledGluten =
- spark.sparkContext.getLocalProperty(QueryPlanSelector.GLUTEN_ENABLE_FOR_THREAD_KEY)
+ spark.sparkContext.getLocalProperty(GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY)
// gluten can not support vacuum command
- spark.sparkContext.setLocalProperty(QueryPlanSelector.GLUTEN_ENABLE_FOR_THREAD_KEY, "false")
+ spark.sparkContext.setLocalProperty(GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY, "false")
// --- modified end

val validFiles = snapshot.stateDS
@@ -362,10 +359,10 @@ object VacuumCommand extends VacuumCommandImpl with Serializable {
// --- modified start
if (originalEnabledGluten != null) {
spark.sparkContext.setLocalProperty(
- QueryPlanSelector.GLUTEN_ENABLE_FOR_THREAD_KEY, originalEnabledGluten)
+ GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY, originalEnabledGluten)
} else {
spark.sparkContext.setLocalProperty(
- QueryPlanSelector.GLUTEN_ENABLE_FOR_THREAD_KEY, "true")
+ GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY, "true")
}
// --- modified end
}
@@ -28,8 +28,7 @@ import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.util.DeltaFileOperations
import org.apache.spark.sql.delta.util.DeltaFileOperations.tryDeleteNonRecursive
import com.fasterxml.jackson.databind.annotation.JsonDeserialize

- import org.apache.gluten.utils.QueryPlanSelector
+ import org.apache.gluten.extension.GlutenSessionExtensions
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.broadcast.Broadcast
@@ -254,9 +253,9 @@ object VacuumCommand extends VacuumCommandImpl with Serializable {

// --- modified start
val originalEnabledGluten =
- spark.sparkContext.getLocalProperty(QueryPlanSelector.GLUTEN_ENABLE_FOR_THREAD_KEY)
+ spark.sparkContext.getLocalProperty(GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY)
// gluten can not support vacuum command
- spark.sparkContext.setLocalProperty(QueryPlanSelector.GLUTEN_ENABLE_FOR_THREAD_KEY, "false")
+ spark.sparkContext.setLocalProperty(GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY, "false")
// --- modified end

val validFiles =
@@ -461,10 +460,10 @@ object VacuumCommand extends VacuumCommandImpl with Serializable {
// --- modified start
if (originalEnabledGluten != null) {
spark.sparkContext.setLocalProperty(
- QueryPlanSelector.GLUTEN_ENABLE_FOR_THREAD_KEY, originalEnabledGluten)
+ GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY, originalEnabledGluten)
} else {
spark.sparkContext.setLocalProperty(
- QueryPlanSelector.GLUTEN_ENABLE_FOR_THREAD_KEY, "true")
+ GlutenSessionExtensions.GLUTEN_ENABLE_FOR_THREAD_KEY, "true")
}
// --- modified end
}
@@ -26,7 +26,6 @@ import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector}
import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector, RasInjector}
import org.apache.gluten.parser.{GlutenCacheFilesSqlParser, GlutenClickhouseSqlParser}
import org.apache.gluten.sql.shims.SparkShimLoader
- import org.apache.gluten.utils.PhysicalPlanSelector

import org.apache.spark.sql.catalyst.{CHAggregateFunctionRewriteRule, EqualToRewrite}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -39,8 +38,6 @@ import org.apache.spark.util.SparkPlanRules
class CHRuleApi extends RuleApi {
import CHRuleApi._
override def injectRules(injector: RuleInjector): Unit = {
- injector.gluten.skipOn(PhysicalPlanSelector.skipCond)

injectSpark(injector.spark)
injectLegacy(injector.gluten.legacy)
injectRas(injector.gluten.ras)
@@ -17,7 +17,6 @@
package org.apache.gluten.extension

import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
- import org.apache.gluten.utils.PhysicalPlanSelector

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.LeftAnti
@@ -28,8 +27,7 @@ import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ClickHouseBuildSideRelation}

case class CHAQEPropagateEmptyRelation(session: SparkSession) extends Rule[SparkPlan] {

- def apply(plan: SparkPlan): SparkPlan = PhysicalPlanSelector.maybe(session, plan) {
+ def apply(plan: SparkPlan): SparkPlan = {
if (!(session.conf.get(CHBackendSettings.GLUTEN_AQE_PROPAGATEEMPTY, "true").toBoolean)) {
plan
} else {
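The rule-related files before and after this point make the same two changes: the injector no longer registers a skip condition (`injector.gluten.skipOn(PhysicalPlanSelector.skipCond)` is removed), and individual rules stop wrapping their bodies in `PhysicalPlanSelector.maybe` / `LogicalPlanSelector.maybe`. Below is a hypothetical rule, before and after, under the assumption that the enable/disable decision now happens once where the extensions inject the rules rather than inside each rule; the new central check itself is not part of the hunks shown on this page.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical rule, for illustration only; `rewrite` stands in for a real transformation.
case class ExampleGlutenRule(session: SparkSession) extends Rule[SparkPlan] {
  // Before this commit (sketch): every rule guarded itself, returning the plan
  // untouched whenever Gluten was disabled for the current session/thread.
  //   override def apply(plan: SparkPlan): SparkPlan =
  //     PhysicalPlanSelector.maybe(session, plan) { rewrite(plan) }

  // After this commit: the rule just rewrites; whether Gluten rules run at all
  // is decided once, outside the rule.
  override def apply(plan: SparkPlan): SparkPlan = rewrite(plan)

  private def rewrite(plan: SparkPlan): SparkPlan = plan // placeholder rewrite
}
```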
@@ -20,7 +20,6 @@ import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.columnar._
import org.apache.gluten.extension.columnar.FallbackTags.EncodeFallbackTagImplicits
- import org.apache.gluten.utils.PhysicalPlanSelector

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
@@ -38,7 +37,7 @@ import scala.util.control.Breaks.{break, breakable}
// to columnar while BHJ fallbacks, BroadcastExec need to be tagged not transformable when applying
// queryStagePrepRules.
case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extends Rule[SparkPlan] {
- override def apply(plan: SparkPlan): SparkPlan = PhysicalPlanSelector.maybe(session, plan) {
+ override def apply(plan: SparkPlan): SparkPlan = {
val columnarConf: GlutenConfig = GlutenConfig.getConf
plan.foreach {
case bhj: BroadcastHashJoinExec =>
@@ -27,16 +27,13 @@ import org.apache.gluten.extension.columnar.transition.{InsertTransitions, Remov
import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector}
import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector, RasInjector}
import org.apache.gluten.sql.shims.SparkShimLoader
- import org.apache.gluten.utils.PhysicalPlanSelector

import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter}

class VeloxRuleApi extends RuleApi {
import VeloxRuleApi._

override def injectRules(injector: RuleInjector): Unit = {
- injector.gluten.skipOn(PhysicalPlanSelector.skipCond)

injectSpark(injector.spark)
injectLegacy(injector.gluten.legacy)
injectRas(injector.gluten.ras)
@@ -19,7 +19,6 @@ package org.apache.gluten.datasource
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.datasource.v2.ArrowCSVTable
import org.apache.gluten.sql.shims.SparkShimLoader
- import org.apache.gluten.utils.LogicalPlanSelector

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.SparkSession
@@ -40,7 +39,7 @@ import scala.collection.convert.ImplicitConversions.`map AsScala`

@Experimental
case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = LogicalPlanSelector.maybe(session, plan) {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
if (!BackendsApiManager.getSettings.enableNativeArrowReadFiles()) {
return plan
}
@@ -19,15 +19,14 @@ package org.apache.gluten.extension
import org.apache.gluten.datasource.ArrowCSVFileFormat
import org.apache.gluten.datasource.v2.ArrowCSVScan
import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec
- import org.apache.gluten.utils.PhysicalPlanSelector

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ArrowFileSourceScanExec, FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec

case class ArrowScanReplaceRule(spark: SparkSession) extends Rule[SparkPlan] {
- override def apply(plan: SparkPlan): SparkPlan = PhysicalPlanSelector.maybe(spark, plan) {
+ override def apply(plan: SparkPlan): SparkPlan = {
plan.transformUp {
case plan: FileSourceScanExec if plan.relation.fileFormat.isInstanceOf[ArrowCSVFileFormat] =>
ArrowFileSourceScanExec(plan)
@@ -20,14 +20,13 @@ import org.apache.gluten.GlutenConfig
import org.apache.gluten.expression.VeloxBloomFilterMightContain
import org.apache.gluten.expression.aggregate.VeloxBloomFilterAggregate
import org.apache.gluten.sql.shims.SparkShimLoader
- import org.apache.gluten.utils.PhysicalPlanSelector

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

case class BloomFilterMightContainJointRewriteRule(spark: SparkSession) extends Rule[SparkPlan] {
- override def apply(plan: SparkPlan): SparkPlan = PhysicalPlanSelector.maybe(spark, plan) {
+ override def apply(plan: SparkPlan): SparkPlan = {
if (!GlutenConfig.getConf.enableNativeBloomFilter) {
return plan
}
@@ -18,7 +18,6 @@ package org.apache.gluten.extension

import org.apache.gluten.expression.ExpressionMappings
import org.apache.gluten.expression.aggregate.{VeloxCollectList, VeloxCollectSet}
- import org.apache.gluten.utils.LogicalPlanSelector

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Expression, WindowExpression}
@@ -36,7 +35,7 @@ import scala.reflect.{classTag, ClassTag}
*/
case class CollectRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
import CollectRewriteRule._
- override def apply(plan: LogicalPlan): LogicalPlan = LogicalPlanSelector.maybe(spark, plan) {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
if (!has[VeloxCollectSet] && !has[VeloxCollectList]) {
return plan
}
@@ -18,7 +18,6 @@ package org.apache.gluten.extension

import org.apache.gluten.GlutenConfig
import org.apache.gluten.expression.aggregate.HLLAdapter
- import org.apache.gluten.utils.LogicalPlanSelector

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Literal
@@ -29,7 +28,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, AGGREGATE_EXP
import org.apache.spark.sql.types._

case class HLLRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = LogicalPlanSelector.maybe(spark, plan) {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformUpWithPruning(_.containsPattern(AGGREGATE)) {
case a: Aggregate =>
a.transformExpressionsWithPruning(_.containsPattern(AGGREGATE_EXPRESSION)) {
(Diffs for the remaining changed files were not loaded on this page.)
