[GLUTEN-7261][CORE] Support offloading partial filters to native scan #8082

Merged (4 commits) on Dec 11, 2024
VeloxScanSuite.scala
@@ -24,6 +24,7 @@ import org.apache.gluten.utils.VeloxFileSystemValidationJniWrapper
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.GreaterThan
import org.apache.spark.sql.execution.ScalarSubquery
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

class VeloxScanSuite extends VeloxWholeStageTransformerSuite {
@@ -150,4 +151,40 @@ class VeloxScanSuite extends VeloxWholeStageTransformerSuite
}
}
}

test("push partial filters to offload scan when filter need fallback - v1") {
withSQLConf(GlutenConfig.EXPRESSION_BLACK_LIST.key -> "add") {
createTPCHNotNullTables()
val query = "select l_partkey from lineitem where l_partkey + 1 > 5 and l_partkey - 1 < 8"
runQueryAndCompare(query) {
df =>
{
val executedPlan = getExecutedPlan(df)
val scans = executedPlan.collect { case p: FileSourceScanExecTransformer => p }
assert(scans.size == 1)
// isnotnull(l_partkey) and l_partkey - 1 < 8
assert(scans.head.filterExprs().size == 2)
}
}
}
}

test("push partial filters to offload scan when filter need fallback - v2") {
withSQLConf(
GlutenConfig.EXPRESSION_BLACK_LIST.key -> "add",
SQLConf.USE_V1_SOURCE_LIST.key -> "") {
createTPCHNotNullTables()
val query = "select l_partkey from lineitem where l_partkey + 1 > 5 and l_partkey - 1 < 8"
runQueryAndCompare(query) {
df =>
{
val executedPlan = getExecutedPlan(df)
val scans = executedPlan.collect { case p: BatchScanExecTransformer => p }
assert(scans.size == 1)
// isnotnull(l_partkey) and l_partkey - 1 < 8
assert(scans.head.filterExprs().size == 2)
}
}
}
}
}
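
For context beyond the diff: with add on the expression blacklist, l_partkey + 1 > 5 cannot be converted and stays in a vanilla FilterExec, while isnotnull(l_partkey) and l_partkey - 1 < 8 are still offloaded to the native scan. A rough sketch of inspecting that split outside the test suite, assuming a Gluten-enabled SparkSession named spark, the TPC-H lineitem table, and the same blacklist configuration (illustrative only, not part of this PR):

import org.apache.gluten.execution.FileSourceScanExecTransformer

val df = spark.sql(
  "select l_partkey from lineitem where l_partkey + 1 > 5 and l_partkey - 1 < 8")
// Collect the native scan node(s) and print the filters that were offloaded.
// Depending on adaptive execution, the scan may sit below AQE wrapper nodes,
// which is why the test suite uses its own getExecutedPlan helper.
val offloaded = df.queryExecution.executedPlan.collect {
  case scan: FileSourceScanExecTransformer => scan.filterExprs()
}.flatten
offloaded.foreach(expr => println(expr.sql))
// Expected per the test above: isnotnull(l_partkey) and l_partkey - 1 < 8;
// l_partkey + 1 > 5 remains in a fallback FilterExec because "add" is blacklisted.
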
IcebergScanTransformer.scala
@@ -50,8 +50,6 @@ case class IcebergScanTransformer(
IcebergScanTransformer.supportsBatchScan(scan)
}

-override def filterExprs(): Seq[Expression] = pushdownFilters.getOrElse(Seq.empty)

override lazy val getPartitionSchema: StructType =
GlutenIcebergSourceUtil.getReadPartitionSchema(scan)

BasicScanExecTransformer.scala
@@ -30,7 +30,7 @@ import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
-import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{StringType, StructField, StructType}

import com.google.protobuf.StringValue
import io.substrait.proto.NamedStruct
@@ -131,11 +131,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource
}.asJava
// Will put all filter expressions into an AND expression
val transformer = filterExprs()
-.map {
-case ar: AttributeReference if ar.dataType == BooleanType =>
-EqualNullSafe(ar, Literal.TrueLiteral)
-case e => e
-}
+.map(ExpressionConverter.replaceAttributeReference)
.reduceLeftOption(And)
.map(ExpressionConverter.replaceWithExpressionTransformer(_, output))
val filterNodes = transformer.map(_.doTransform(context.registeredFunction))
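
For context beyond the diff: the boolean-attribute special case removed here is not lost; it now lives in ExpressionConverter.replaceAttributeReference (added later in this PR). A minimal sketch of what that rewrite does, using a hypothetical boolean column:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualNullSafe, Literal}
import org.apache.spark.sql.types.BooleanType

// A filter that is just a boolean column, e.g. WHERE is_valid (hypothetical column).
val isValid = AttributeReference("is_valid", BooleanType)()
// replaceAttributeReference rewrites it into an explicit null-safe comparison,
// equivalent to is_valid <=> true, before the filters are AND-reduced and converted
// as in the hunk above.
val rewritten = EqualNullSafe(isValid, Literal.TrueLiteral)
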
BatchScanExecTransformer.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.expression.ExpressionConverter
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.sql.shims.SparkShimLoader
@@ -101,18 +101,24 @@ abstract class BatchScanExecTransformerBase(
// class. Otherwise, we will encounter an issue where makeCopy cannot find a constructor
// with the corresponding number of parameters.
// The workaround is to add a mutable list to pass in pushdownFilters.
-protected var pushdownFilters: Option[Seq[Expression]] = None
+protected var pushdownFilters: Seq[Expression] = scan match {
+case fileScan: FileScan =>
+fileScan.dataFilters.filter {
+expr =>
+ExpressionConverter.canReplaceWithExpressionTransformer(
+ExpressionConverter.replaceAttributeReference(expr),
+output)
+}
+case _ =>
+logInfo(s"${scan.getClass.toString} does not support push down filters")
+Seq.empty
+}

def setPushDownFilters(filters: Seq[Expression]): Unit = {
-pushdownFilters = Some(filters)
+pushdownFilters = filters
}

-override def filterExprs(): Seq[Expression] = scan match {
-case fileScan: FileScan =>
-pushdownFilters.getOrElse(fileScan.dataFilters)
-case _ =>
-throw new GlutenNotSupportException(s"${scan.getClass.toString} is not supported")
-}
+override def filterExprs(): Seq[Expression] = pushdownFilters

override def getMetadataColumns(): Seq[AttributeReference] = Seq.empty

FileSourceScanExecTransformer.scala
@@ -17,6 +17,7 @@
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.expression.ExpressionConverter
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.sql.shims.SparkShimLoader
@@ -102,7 +103,12 @@ abstract class FileSourceScanExecTransformerBase(
.genFileSourceScanTransformerMetrics(sparkContext)
.filter(m => !driverMetricsAlias.contains(m._1)) ++ driverMetricsAlias

-override def filterExprs(): Seq[Expression] = dataFiltersInScan
+override def filterExprs(): Seq[Expression] = dataFiltersInScan.filter {
+expr =>
+ExpressionConverter.canReplaceWithExpressionTransformer(
+ExpressionConverter.replaceAttributeReference(expr),
+output)
+}

override def getMetadataColumns(): Seq[AttributeReference] = metadataColumns

ExpressionConverter.scala
@@ -55,6 +55,25 @@ object ExpressionConverter extends SQLConfHelper with Logging {
replaceWithExpressionTransformer0(expr, attributeSeq, expressionsMap)
}

+def canReplaceWithExpressionTransformer(
+expr: Expression,
+attributeSeq: Seq[Attribute]): Boolean = {
+try {
+replaceWithExpressionTransformer(expr, attributeSeq)
+true
+} catch {
+case e: Exception =>
+logInfo(e.getMessage)
+false
+}
+}
+
+def replaceAttributeReference(expr: Expression): Expression = expr match {
+case ar: AttributeReference if ar.dataType == BooleanType =>
+EqualNullSafe(ar, Literal.TrueLiteral)
+case e => e
+}

private def replacePythonUDFWithExpressionTransformer(
udf: PythonUDF,
attributeSeq: Seq[Attribute],
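
For context beyond the diff: a sketch of how the scan transformers above combine these two helpers to keep only the data filters that can be offloaded. The helper name and signature below are illustrative; the real call sites are in FileSourceScanExecTransformerBase and BatchScanExecTransformerBase.

import org.apache.gluten.expression.ExpressionConverter
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}

// Keep only the filters Gluten can translate to the native engine;
// anything else stays in a fallback FilterExec above the scan.
def offloadableFilters(dataFilters: Seq[Expression], output: Seq[Attribute]): Seq[Expression] =
  dataFilters.filter { expr =>
    ExpressionConverter.canReplaceWithExpressionTransformer(
      ExpressionConverter.replaceAttributeReference(expr),
      output)
  }
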
PushDownFilterToScan.scala
@@ -50,7 +50,7 @@ object PushDownFilterToScan extends Rule[SparkPlan] with PredicateHelper {
// If BatchScanExecTransformerBase's parent is filter, pushdownFilters can't be None.
batchScan.setPushDownFilters(Seq.empty)
val newScan = batchScan
-if (pushDownFilters.size > 0) {
+if (pushDownFilters.nonEmpty) {
newScan.setPushDownFilters(pushDownFilters)
if (newScan.doValidate().ok()) {
filter.withNewChildren(Seq(newScan))
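
For context beyond the diff: PushDownFilterToScan mixes in PredicateHelper, and a Filter condition is typically broken into conjuncts with splitConjunctivePredicates before individual predicates are handed to setPushDownFilters; the selection logic itself is outside this hunk. A hedged illustration with a hypothetical object name:

import org.apache.spark.sql.catalyst.expressions.{Expression, PredicateHelper}

object ConjunctSplit extends PredicateHelper {
  // (l_partkey + 1 > 5) AND (l_partkey - 1 < 8)  =>  Seq(l_partkey + 1 > 5, l_partkey - 1 < 8);
  // each conjunct can then be validated individually before being pushed to the native scan.
  def split(condition: Expression): Seq[Expression] = splitConjunctivePredicates(condition)
}
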