[CORE][VL] Fix BatchScanExec filter pushdown logic #4132

Merged (1 commit, Jan 4, 2024)
@@ -27,8 +27,8 @@ case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
with PredicateHelper {

override protected def doValidateInternal(): ValidationResult = {
-val leftCondition = getLeftCondition
-if (leftCondition == null) {
+val remainingCondition = getRemainingCondition
+if (remainingCondition == null) {
// All the filters can be pushed down and the computing of this Filter
// is not needed.
return ValidationResult.ok
@@ -37,16 +37,22 @@ case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
val operatorId = substraitContext.nextOperatorId(this.nodeName)
// Firstly, need to check if the Substrait plan for this operator can be successfully generated.
val relNode =
-getRelNode(substraitContext, leftCondition, child.output, operatorId, null, validation = true)
+getRelNode(
+substraitContext,
+remainingCondition,
+child.output,
+operatorId,
+null,
+validation = true)
doNativeValidation(substraitContext, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
-val leftCondition = getLeftCondition
+val remainingCondition = getRemainingCondition

val operatorId = context.nextOperatorId(this.nodeName)
-if (leftCondition == null) {
+if (remainingCondition == null) {
// The computing for this filter is not needed.
context.registerEmptyRelToOperator(operatorId)
// Since some columns' nullability will be removed after this filter, we need to update the
@@ -56,7 +62,7 @@ case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)

val currRel = getRelNode(
context,
-leftCondition,
+remainingCondition,
child.output,
operatorId,
childCtx.root,
@@ -65,7 +71,7 @@ case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
TransformContext(childCtx.outputAttributes, output, currRel)
}

-private def getLeftCondition: Expression = {
+private def getRemainingCondition: Expression = {
val scanFilters = child match {
// Get the filters including the manually pushed down ones.
case basicScanTransformer: BasicScanExecTransformer =>
@@ -77,9 +83,9 @@ case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
if (scanFilters.isEmpty) {
condition
} else {
-val leftFilters =
-FilterHandler.getLeftFilters(scanFilters, splitConjunctivePredicates(condition))
-leftFilters.reduceLeftOption(And).orNull
+val remainingFilters =
+FilterHandler.getRemainingFilters(scanFilters, splitConjunctivePredicates(condition))
+remainingFilters.reduceLeftOption(And).orNull
}
}
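
A minimal sketch (not part of the patch; the object and method names are illustrative) of the null convention used by getRemainingCondition: the conjuncts that could not be pushed down are folded back into a single And tree, and an empty remainder becomes null, which doValidateInternal and doTransform read as "the Filter can be elided".

import org.apache.spark.sql.catalyst.expressions.{And, Expression}

object FoldRemainingSketch {
  // Seq(f1, f2, f3) folds to And(And(f1, f2), f3); an empty Seq yields null,
  // meaning the scan absorbed every conjunct and the Filter node is redundant.
  def foldRemaining(remainingFilters: Seq[Expression]): Expression =
    remainingFilters.reduceLeftOption(And).orNull
}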

@@ -27,8 +27,8 @@ case class FilterExecTransformer(condition: Expression, child: SparkPlan)
with PredicateHelper {

override protected def doValidateInternal(): ValidationResult = {
-val leftCondition = getLeftCondition
-if (leftCondition == null) {
+val remainingCondition = getRemainingCondition
+if (remainingCondition == null) {
// All the filters can be pushed down and the computing of this Filter
// is not needed.
return ValidationResult.ok
@@ -37,25 +37,31 @@ case class FilterExecTransformer(condition: Expression, child: SparkPlan)
val operatorId = substraitContext.nextOperatorId(this.nodeName)
// Firstly, need to check if the Substrait plan for this operator can be successfully generated.
val relNode =
-getRelNode(substraitContext, leftCondition, child.output, operatorId, null, validation = true)
+getRelNode(
+substraitContext,
+remainingCondition,
+child.output,
+operatorId,
+null,
+validation = true)
// Then, validate the generated plan in native engine.
doNativeValidation(substraitContext, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
-val leftCondition = getLeftCondition
+val remainingCondition = getRemainingCondition

val operatorId = context.nextOperatorId(this.nodeName)
-if (leftCondition == null) {
+if (remainingCondition == null) {
// The computing for this filter is not needed.
context.registerEmptyRelToOperator(operatorId)
return childCtx
}

val currRel = getRelNode(
context,
-leftCondition,
+remainingCondition,
child.output,
operatorId,
childCtx.root,
@@ -64,7 +70,7 @@ case class FilterExecTransformer(condition: Expression, child: SparkPlan)
TransformContext(childCtx.outputAttributes, output, currRel)
}

-private def getLeftCondition: Expression = {
+private def getRemainingCondition: Expression = {
val scanFilters = child match {
// Get the filters including the manually pushed down ones.
case basicScanExecTransformer: BasicScanExecTransformer =>
@@ -76,9 +82,9 @@ case class FilterExecTransformer(condition: Expression, child: SparkPlan)
if (scanFilters.isEmpty) {
condition
} else {
-val leftFilters =
-FilterHandler.getLeftFilters(scanFilters, splitConjunctivePredicates(condition))
-leftFilters.reduceLeftOption(And).orNull
+val remainingFilters =
+FilterHandler.getRemainingFilters(scanFilters, splitConjunctivePredicates(condition))
+remainingFilters.reduceLeftOption(And).orNull
}
}
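
As a side note, here is a small sketch of the conjunct splitting both transformers rely on; splitConjunctivePredicates comes from Spark's PredicateHelper, and the attribute names below are made up for illustration.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

object SplitConjunctsSketch extends PredicateHelper {
  private val a = AttributeReference("a", IntegerType)()
  private val b = AttributeReference("b", IntegerType)()

  // a > 1 AND b < 2 is split into its conjuncts so that each one can be
  // compared against the filters the scan already pushed down.
  val condition: Expression = And(GreaterThan(a, Literal(1)), LessThan(b, Literal(2)))
  val conjuncts: Seq[Expression] = splitConjunctivePredicates(condition)
  // conjuncts == Seq(GreaterThan(a, Literal(1)), LessThan(b, Literal(2)))
}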

@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.execution

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.GreaterThan
import org.apache.spark.sql.execution.ScalarSubquery

class VeloxScanSuite extends VeloxWholeStageTransformerSuite {
protected val rootPath: String = getClass.getResource("/").getPath
override protected val backend: String = "velox"
override protected val resourcePath: String = "/tpch-data-parquet-velox"
override protected val fileFormat: String = "parquet"

protected val veloxTPCHQueries: String = rootPath + "/tpch-queries-velox"
protected val queriesResults: String = rootPath + "queries-output"

override protected def sparkConf: SparkConf = super.sparkConf
.set("spark.sql.adaptive.enabled", "false")

override def beforeAll(): Unit = {
super.beforeAll()
}

test("tpch q22 subquery filter pushdown - v1") {
createTPCHNotNullTables()
runTPCHQuery(22, veloxTPCHQueries, queriesResults, compareResult = false, noFallBack = false) {
df =>
val plan = df.queryExecution.executedPlan
val exist = plan.collect { case scan: FileSourceScanExecTransformer => scan }.exists {
scan =>
scan.filterExprs().exists {
case _ @GreaterThan(_, _: ScalarSubquery) => true
case _ => false
}
}
assert(exist)
}
}

test("tpch q22 subquery filter pushdown - v2") {
withSQLConf("spark.sql.sources.useV1SourceList" -> "") {
// Tables must be created here, otherwise v2 scan will not be used.
createTPCHNotNullTables()
runTPCHQuery(
22,
veloxTPCHQueries,
queriesResults,
compareResult = false,
noFallBack = false) {
df =>
val plan = df.queryExecution.executedPlan
val exist = plan.collect { case scan: BatchScanExecTransformer => scan }.exists {
scan =>
scan.filterExprs().exists {
case _ @GreaterThan(_, _: ScalarSubquery) => true
case _ => false
}
}
assert(exist)
}
}
}
}
@@ -378,20 +378,34 @@ object FilterHandler extends PredicateHelper {
* @return
* the filter conditions not pushed down into Scan.
*/
-def getLeftFilters(scanFilters: Seq[Expression], filters: Seq[Expression]): Seq[Expression] =
+def getRemainingFilters(scanFilters: Seq[Expression], filters: Seq[Expression]): Seq[Expression] =
(ExpressionSet(filters) -- ExpressionSet(scanFilters)).toSeq

// Separate and compare the filter conditions in Scan and Filter.
-// Push down the left conditions in Filter into Scan.
-def applyFilterPushdownToScan(filter: FilterExec, reuseSubquery: Boolean): GlutenPlan =
+// Push down the remaining conditions in Filter into Scan.
+def applyFilterPushdownToScan(filter: FilterExec, reuseSubquery: Boolean): SparkPlan =
filter.child match {
case fileSourceScan: FileSourceScanExec =>
-val leftFilters =
-getLeftFilters(fileSourceScan.dataFilters, splitConjunctivePredicates(filter.condition))
+val remainingFilters =
+getRemainingFilters(
+fileSourceScan.dataFilters,
+splitConjunctivePredicates(filter.condition))
ScanTransformerFactory.createFileSourceScanTransformer(
fileSourceScan,
reuseSubquery,
-extraFilters = leftFilters)
+extraFilters = remainingFilters)
case batchScan: BatchScanExec =>
val remainingFilters = batchScan.scan match {
case fileScan: FileScan =>
getRemainingFilters(fileScan.dataFilters, splitConjunctivePredicates(filter.condition))
case _ =>
// TODO: For data lake format use pushedFilters in SupportsPushDownFilters
splitConjunctivePredicates(filter.condition)
}
ScanTransformerFactory.createBatchScanTransformer(
batchScan,
reuseSubquery,
pushdownFilters = remainingFilters)
case other =>
throw new UnsupportedOperationException(s"${other.getClass.toString} is not supported.")
}
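
A minimal sketch of the set difference inside getRemainingFilters above (attribute and object names are illustrative): ExpressionSet keys expressions by their canonicalized form, so a conjunct the scan already holds is dropped even when it is a distinct object, and only the genuinely remaining filters survive.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

object RemainingFiltersSketch {
  private val a = AttributeReference("a", IntegerType)()
  private val b = AttributeReference("b", IntegerType)()

  // Conjuncts sitting in the Filter node...
  private val filterConjuncts: Seq[Expression] =
    Seq(GreaterThan(a, Literal(1)), LessThan(b, Literal(2)))
  // ...and the filters the scan has already pushed down.
  private val scanFilters: Seq[Expression] = Seq(LessThan(b, Literal(2)))

  // Only a > 1 remains and must still be evaluated by the Filter operator.
  val remaining: Seq[Expression] =
    (ExpressionSet(filterConjuncts) -- ExpressionSet(scanFilters)).toSeq
}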
@@ -34,6 +34,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch

import java.util.Objects

import scala.collection.mutable.ListBuffer

/**
* Columnar Based BatchScanExec. Although keyGroupedPartitioning is not used, it cannot be deleted,
* it can make BatchScanExecTransformer contain a constructor with the same parameters as
@@ -57,9 +59,18 @@ class BatchScanExecTransformer(
@transient override lazy val metrics: Map[String, SQLMetric] =
BackendsApiManager.getMetricsApiInstance.genBatchScanTransformerMetrics(sparkContext)

// Similar to the problem encountered in https://github.com/oap-project/gluten/pull/3184,
// we cannot add member variables to BatchScanExecTransformer, which inherits from case
// class. Otherwise, we will encounter an issue where makeCopy cannot find a constructor
// with the corresponding number of parameters.
// The workaround is to add a mutable list to pass in pushdownFilters.
val pushdownFilters: ListBuffer[Expression] = ListBuffer.empty

def addPushdownFilters(filters: Seq[Expression]): Unit = pushdownFilters ++= filters

override def filterExprs(): Seq[Expression] = scan match {
case fileScan: FileScan =>
-fileScan.dataFilters
+fileScan.dataFilters ++ pushdownFilters
case _ =>
throw new UnsupportedOperationException(s"${scan.getClass.toString} is not supported")
}
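
To illustrate the comment above, a toy sketch (not Gluten code) of why the extra state is a mutable field rather than a constructor parameter: TreeNode.makeCopy re-invokes a constructor reflectively with the original number of arguments, so adding a parameter changes the arity it looks for, while a field leaves the constructor signature untouched.

import scala.collection.mutable.ListBuffer

// Toy stand-in for a case-class plan node; not the real BatchScanExecTransformer.
case class ToyScan(output: Seq[String], dataFilters: Seq[String]) {
  // Extra pushdown state lives in a mutable field, so the reflection-visible
  // constructor keeps its original two parameters.
  val pushdownFilters: ListBuffer[String] = ListBuffer.empty
  def addPushdownFilters(filters: Seq[String]): Unit = pushdownFilters ++= filters

  // Roughly what a reflective copy does: look up the constructor and call it
  // with the same number of arguments as before.
  def reflectiveCopy(args: Array[AnyRef]): ToyScan =
    getClass.getConstructors.head.newInstance(args: _*).asInstanceOf[ToyScan]
}

object ToyScanDemo extends App {
  val scan = ToyScan(Seq("a"), Seq("a > 1"))
  scan.addPushdownFilters(Seq("b < 2"))
  // The copy succeeds because the constructor arity is unchanged; note that the
  // mutable field does not travel with the copy in this toy example.
  val copied = scan.reflectiveCopy(Array[AnyRef](Seq("a"), Seq("a > 1")))
  println(copied.pushdownFilters.isEmpty) // true
}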
@@ -17,11 +17,12 @@
package io.glutenproject.execution

import io.glutenproject.expression.ExpressionConverter
import io.glutenproject.extension.columnar.TransformHints
import io.glutenproject.sql.shims.SparkShimLoader

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.connector.read.Scan
-import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}

import java.util.ServiceLoader
@@ -67,15 +68,9 @@
}
}

-def createBatchScanTransformer(
+private def lookupBatchScanTransformer(
batchScanExec: BatchScanExec,
-reuseSubquery: Boolean,
-validation: Boolean = false): BatchScanExecTransformer = {
-val newPartitionFilters = if (validation) {
-batchScanExec.runtimeFilters
-} else {
-ExpressionConverter.transformDynamicPruningExpr(batchScanExec.runtimeFilters, reuseSubquery)
-}
+newPartitionFilters: Seq[Expression]): BatchScanExecTransformer = {
val scan = batchScanExec.scan
lookupDataSourceScanTransformer(scan.getClass.getName) match {
case Some(clz) =>
@@ -85,11 +80,58 @@
.asInstanceOf[DataSourceScanTransformerRegister]
.createDataSourceV2Transformer(batchScanExec, newPartitionFilters)
case _ =>
-new BatchScanExecTransformer(
-batchScanExec.output,
-batchScanExec.scan,
-newPartitionFilters,
-table = SparkShimLoader.getSparkShims.getBatchScanExecTable(batchScanExec))
+scan match {
+case _: FileScan =>
+new BatchScanExecTransformer(
+batchScanExec.output,
+batchScanExec.scan,
+newPartitionFilters,
+table = SparkShimLoader.getSparkShims.getBatchScanExecTable(batchScanExec)
+)
+case _ =>
+throw new UnsupportedOperationException(s"Unsupported scan $scan")
+}
}
}

def createBatchScanTransformer(
batchScan: BatchScanExec,
reuseSubquery: Boolean,
pushdownFilters: Seq[Expression] = Seq.empty,
validation: Boolean = false): SparkPlan = {
if (supportedBatchScan(batchScan.scan)) {
val newPartitionFilters = if (validation) {
// No transformation is needed for DynamicPruningExpressions
// during the validation process.
batchScan.runtimeFilters
} else {
ExpressionConverter.transformDynamicPruningExpr(batchScan.runtimeFilters, reuseSubquery)
}
val transformer = lookupBatchScanTransformer(batchScan, newPartitionFilters)
if (!validation && pushdownFilters.nonEmpty) {
transformer.addPushdownFilters(pushdownFilters)
// Validate again if pushdownFilters is not empty.
val validationResult = transformer.doValidate()
if (validationResult.isValid) {
transformer
} else {
val newSource = batchScan.copy(runtimeFilters = transformer.runtimeFilters)
TransformHints.tagNotTransformable(newSource, validationResult.reason.get)
newSource
}
} else {
transformer
}
} else {
if (validation) {
throw new UnsupportedOperationException(s"Unsupported scan ${batchScan.scan}")
}
// If filter expressions aren't empty, we need to transform the inner operators,
// and fallback the BatchScanExec itself.
val newSource = batchScan.copy(runtimeFilters = ExpressionConverter
.transformDynamicPruningExpr(batchScan.runtimeFilters, reuseSubquery))
TransformHints.tagNotTransformable(newSource, "The scan in BatchScanExec is not supported.")
newSource
}
}
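
A hedged caller-side sketch (not part of this PR; the wrapper object and the reuseSubquery value are illustrative) of what the SparkPlan return type implies: the caller receives either an offloaded transformer or the original BatchScanExec already tagged as not transformable, and can splice either one into the plan.

import io.glutenproject.execution.{BatchScanExecTransformer, ScanTransformerFactory}

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec

object ScanReplacementSketch {
  def replaceScan(batchScan: BatchScanExec, remainingFilters: Seq[Expression]): SparkPlan =
    ScanTransformerFactory.createBatchScanTransformer(
      batchScan,
      reuseSubquery = false,
      pushdownFilters = remainingFilters) match {
      case transformer: BatchScanExecTransformer =>
        // Offloaded: remainingFilters were attached via addPushdownFilters and
        // re-validated inside the factory.
        transformer
      case fallback =>
        // Vanilla BatchScanExec: TransformHints has already tagged it as not
        // transformable, so later rules keep it on the vanilla Spark path.
        fallback
    }
}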
