[GLUTEN-4000][CORE] Apply Basic Common Subexpression Elimination for Spark Logical Plan (#4016)

[CORE] Apply Basic Common Subexpression Elimination for Spark Logical Plan
Currently only enabled for the CH (ClickHouse) backend.
taiyang-li authored Dec 26, 2023
1 parent 57f4df1 commit 42e6990
Showing 8 changed files with 285 additions and 6 deletions.
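For context, the sketch below is not part of the commit; it only illustrates the kind of query the new rule targets. The config key spark.gluten.sql.commonSubexpressionEliminate and the query with the repeated hash(id) come from the test suite added in this commit, while the session setup (object name, appName, local master) is a hypothetical Spark application that assumes a Gluten + ClickHouse backend is available.

import org.apache.spark.sql.SparkSession

object CseDemo {
  def main(args: Array[String]): Unit = {
    // Hypothetical local session; in practice Gluten with the CH backend must be
    // configured for the new rule to be injected at all.
    val spark = SparkSession
      .builder()
      .appName("cse-demo")
      .master("local[*]")
      .config("spark.gluten.sql.commonSubexpressionEliminate", "true")
      .getOrCreate()

    // hash(id) appears three times; with the rule enabled it should be computed
    // once in an injected pre-project and reused by the outer project.
    spark.sql("select hash(id), hash(id) + 1, hash(id) - 1 from range(10)").show()

    spark.stop()
  }
}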
@@ -49,6 +49,7 @@ import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildS
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.utils.CHExecUtil
import org.apache.spark.sql.extension.ClickHouseAnalysis
import org.apache.spark.sql.extension.CommonSubexpressionEliminateRule
import org.apache.spark.sql.extension.RewriteDateTimestampComparisonRule
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -340,7 +341,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
*/
override def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]] = {
val analyzers = List(spark => new ClickHouseAnalysis(spark, spark.sessionState.conf))
if (GlutenConfig.getConf.enableDateTimestampComparison) {
if (GlutenConfig.getConf.enableRewriteDateTimestampComparison) {
analyzers :+ (spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf))
} else {
analyzers
@@ -353,7 +354,12 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
* @return
*/
override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = {
List.empty
var optimizers = List.empty[SparkSession => Rule[LogicalPlan]]
if (GlutenConfig.getConf.enableCommonSubexpressionEliminate) {
optimizers = optimizers :+ (
spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf))
}
optimizers
}

/**
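As an aside, the factories returned by genExtendedOptimizers have the shape Spark's extension API expects. The sketch below is not Gluten's actual wiring; it only shows, under that assumption, how such a SparkSession => Rule[LogicalPlan] builder would be registered by a standalone extension. The extension class name is hypothetical.

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.extension.CommonSubexpressionEliminateRule

// Hypothetical extension; Gluten performs the equivalent registration internally
// when it collects the backend's genExtendedOptimizers list.
class CseExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectOptimizerRule {
      spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf)
    }
  }
}

// Enabled with: --conf spark.sql.extensions=CseExtensions (fully qualified class name)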
@@ -39,6 +39,7 @@ class GlutenClickHouseTPCDSParquetColumnarShuffleSuite extends GlutenClickHouseT
// .set("spark.sql.files.maxPartitionBytes", "134217728")
// .set("spark.sql.files.openCostInBytes", "134217728")
.set("spark.memory.offHeap.size", "4g")
// .set("spark.sql.planChangeLog.level", "error")
}

executeTPCDSTest(false)
@@ -48,6 +48,7 @@ class GlutenClickHouseTPCHParquetSuite extends GlutenClickHouseTPCHAbstractSuite
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
.set("spark.gluten.sql.columnar.backend.ch.use.v2", "false")
.set("spark.gluten.supported.scala.udfs", "my_add")
// .set("spark.sql.planChangeLog.level", "error")
}

override protected val createNullableTables = true
@@ -19,7 +19,9 @@ package io.glutenproject.execution
import io.glutenproject.GlutenConfig
import io.glutenproject.utils.UTSystemParameters

import org.apache.spark.SPARK_VERSION_SHORT
import org.apache.spark.SparkConf
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation}
import org.apache.spark.sql.internal.SQLConf
@@ -29,6 +31,7 @@ import java.nio.file.Files
import java.sql.Date

import scala.collection.immutable.Seq
import scala.reflect.ClassTag

class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerSuite {
override protected val resourcePath: String = {
@@ -39,6 +42,11 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
protected val rootPath: String = getClass.getResource("/").getPath
protected val basePath: String = rootPath + "unit-tests-working-home"

protected lazy val sparkVersion: String = {
val version = SPARK_VERSION_SHORT.split("\\.")
version(0) + "." + version(1)
}

protected val tablesPath: String = basePath + "/tpch-data"
protected val tpchQueries: String =
rootPath + "../../../../gluten-core/src/test/resources/tpch-queries"
@@ -527,4 +535,51 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
)(checkOperatorMatch[ProjectExecTransformer])
}
}

test("test common subexpression eliminate") {
def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit
tag: ClassTag[T]): Unit = {
if (sparkVersion.equals("3.3")) {
assert(
getExecutedPlan(df).count(
plan => {
plan.getClass == tag.runtimeClass
}) == count,
s"executed plan: ${getExecutedPlan(df)}")
}
}

withSQLConf(("spark.gluten.sql.commonSubexpressionEliminate", "true")) {
// CSE in project
runQueryAndCompare("select hash(id), hash(id)+1, hash(id)-1 from range(10)") {
df => checkOperatorCount[ProjectExecTransformer](2)(df)
}

// CSE in filter (does not work yet)
// runQueryAndCompare(
// "select id from range(10) " +
// "where hex(id) != '' and upper(hex(id)) != '' and lower(hex(id)) != ''") { _ => }

// CSE in window
runQueryAndCompare(
"SELECT id, AVG(id) OVER (PARTITION BY id % 2 ORDER BY id) as avg_id, " +
"SUM(id) OVER (PARTITION BY id % 2 ORDER BY id) as sum_id FROM range(10)") {
df => checkOperatorCount[ProjectExecTransformer](4)(df)
}

// CSE in aggregate
runQueryAndCompare(
"select id % 2, max(hash(id)), min(hash(id)) " +
"from range(10) group by id % 2") {
df => checkOperatorCount[ProjectExecTransformer](1)(df)
}

// CSE in sort
runQueryAndCompare(
"select id from range(10) " +
"order by hash(id%10), hash(hash(id%10))") {
df => checkOperatorCount[ProjectExecTransformer](2)(df)
}
}
}
}
3 changes: 2 additions & 1 deletion cpp-ch/local-engine/Parser/TypeParser.cpp
@@ -254,8 +254,9 @@ DB::Block TypeParser::buildBlockFromNamedStruct(const substrait::NamedStruct & s

// This is a partial aggregate data column.
// Its type is special: it must be a struct type containing all argument types.
// Notice: there are corner cases in which the type is not a struct type, e.g. the name is "_1#913 + _2#914#928". We need to handle them.
Poco::StringTokenizer name_parts(name, "#");
if (name_parts.count() >= 4)
if (name_parts.count() >= 4 && !name.contains(' '))
{
auto nested_data_type = DB::removeNullable(ch_type);
const auto * tuple_type = typeid_cast<const DB::DataTypeTuple *>(nested_data_type.get());
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Rewriter/ExpressionRewriter.h
@@ -216,7 +216,7 @@ class GetJsonObjectFunctionWriter : public RelRewriter
arg0->CopyFrom(scalar_function_pb.arguments(0));
auto * arg1 = decoded_json_function.add_arguments();
arg1->mutable_value()->mutable_literal()->set_string(required_fields_str);

substrait::Expression new_get_json_object_arg0;
new_get_json_object_arg0.mutable_scalar_function()->CopyFrom(decoded_json_function);
*scalar_function_pb.mutable_arguments()->Mutable(0)->mutable_value() = new_get_json_object_arg0;
@@ -0,0 +1,202 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.extension

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

import scala.collection.mutable

class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
extends Rule[LogicalPlan]
with Logging {

private var lastPlan: LogicalPlan = null

override def apply(plan: LogicalPlan): LogicalPlan = {
val newPlan = if (plan.resolved && !plan.fastEquals(lastPlan)) {
lastPlan = plan
visitPlan(plan)
} else {
plan
}
newPlan
}

private case class AliasAndAttribute(alias: Alias, attribute: Attribute)

private case class RewriteContext(exprs: Seq[Expression], child: LogicalPlan)

private def visitPlan(plan: LogicalPlan): LogicalPlan = {
var newPlan = plan match {
case project: Project => visitProject(project)
// TODO: CSE in Filter doesn't work for unknown reason, need to fix it later
// case filter: Filter => visitFilter(filter)
case window: Window => visitWindow(window)
case aggregate: Aggregate => visitAggregate(aggregate)
case sort: Sort => visitSort(sort)
case other =>
val children = other.children.map(visitPlan)
other.withNewChildren(children)
}

if (newPlan.output.size == plan.output.size) {
return newPlan
}

// Add a Project to trim unnecessary attributes (which are always at the end of the output)
val postProjectList = newPlan.output.take(plan.output.size)
Project(postProjectList, newPlan)
}

private def replaceCommonExprWithAttribute(
expr: Expression,
commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute]): Expression = {
val exprEquals = commonExprMap.get(ExpressionEquals(expr))
if (exprEquals.isDefined) {
exprEquals.get.attribute
} else {
expr.mapChildren(replaceCommonExprWithAttribute(_, commonExprMap))
}
}

private def isValidCommonExpr(expr: Expression): Boolean = {
if (
(expr.isInstanceOf[Unevaluable] && !expr.isInstanceOf[AttributeReference])
|| expr.isInstanceOf[AggregateFunction]
|| (expr.isInstanceOf[AttributeReference]
&& expr.asInstanceOf[AttributeReference].name == VirtualColumn.groupingIdName)
) {
logTrace(s"Check common expression failed $expr class ${expr.getClass.toString}")
return false
}

expr.children.forall(isValidCommonExpr(_))
}

private def rewrite(inputCtx: RewriteContext): RewriteContext = {
logTrace(s"Start rewrite with input exprs:${inputCtx.exprs} input child:${inputCtx.child}")
val equivalentExpressions = new EquivalentExpressions
inputCtx.exprs.foreach(equivalentExpressions.addExprTree(_))

// Get all the expressions that appear at least twice
val newChild = visitPlan(inputCtx.child)
val commonExprs = equivalentExpressions.getCommonSubexpressions

// Put the common expressions into a hash map
val commonExprMap = mutable.HashMap.empty[ExpressionEquals, AliasAndAttribute]
commonExprs.foreach {
expr =>
if (!expr.foldable && !expr.isInstanceOf[Attribute] && isValidCommonExpr(expr)) {
logTrace(s"Common subexpression $expr class ${expr.getClass.toString}")
val exprEquals = ExpressionEquals(expr)
val alias = Alias(expr, expr.toString)()
val attribute = alias.toAttribute
commonExprMap.put(exprEquals, AliasAndAttribute(alias, attribute))
}
}

if (commonExprMap.isEmpty) {
logTrace(s"commonExprMap is empty, all exprs: ${equivalentExpressions.debugString(true)}")
return RewriteContext(inputCtx.exprs, newChild)
}

// Generate pre-project as new child
var preProjectList = newChild.output ++ commonExprMap.values.map(_.alias)
val preProject = Project(preProjectList, newChild)
logTrace(s"newChild after rewrite: $preProject")

// Replace the common expressions with the first expression that produces it.
try {
var newExprs = inputCtx.exprs
.map(replaceCommonExprWithAttribute(_, commonExprMap))
logTrace(s"newExprs after rewrite: $newExprs")
RewriteContext(newExprs, preProject)
} catch {
case e: Exception =>
logWarning(
s"Common subexpression eliminate failed with exception: ${e.getMessage}" +
s" while replace ${inputCtx.exprs} with $commonExprMap, fallback now")
RewriteContext(inputCtx.exprs, newChild)
}
}

private def visitProject(project: Project): Project = {
val inputCtx = RewriteContext(project.projectList, project.child)
val outputCtx = rewrite(inputCtx)
Project(outputCtx.exprs.map(_.asInstanceOf[NamedExpression]), outputCtx.child)
}

private def visitFilter(filter: Filter): Filter = {
val inputCtx = RewriteContext(Seq(filter.condition), filter.child)
val outputCtx = rewrite(inputCtx)
Filter(outputCtx.exprs.head, outputCtx.child)
}

private def visitWindow(window: Window): Window = {
val inputCtx = RewriteContext(window.windowExpressions, window.child)
val outputCtx = rewrite(inputCtx)
Window(
outputCtx.exprs.map(_.asInstanceOf[NamedExpression]),
window.partitionSpec,
window.orderSpec,
outputCtx.child)
}

private def visitAggregate(aggregate: Aggregate): Aggregate = {
logTrace(
s"aggregate groupingExpressions: ${aggregate.groupingExpressions} " +
s"aggregateExpressions: ${aggregate.aggregateExpressions}")
val groupingSize = aggregate.groupingExpressions.size
val aggregateSize = aggregate.aggregateExpressions.size

val inputCtx = RewriteContext(
aggregate.groupingExpressions ++ aggregate.aggregateExpressions,
aggregate.child)
val outputCtx = rewrite(inputCtx)
Aggregate(
outputCtx.exprs.slice(0, groupingSize),
outputCtx.exprs
.slice(groupingSize, groupingSize + aggregateSize)
.map(_.asInstanceOf[NamedExpression]),
outputCtx.child
)
}

private def visitSort(sort: Sort): Sort = {
val exprs = sort.order.flatMap(_.children)
val inputCtx = RewriteContext(exprs, sort.child)
val outputCtx = rewrite(inputCtx)

var start = 0
var newOrder = Seq.empty[SortOrder]
sort.order.foreach(
order => {
val childrenSize = order.children.size
val newChildren = outputCtx.exprs.slice(start, start + childrenSize)
newOrder = newOrder :+ order.withNewChildren(newChildren).asInstanceOf[SortOrder]
start += childrenSize
})

Sort(newOrder, sort.global, outputCtx.child)
}
}
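To make the rewrite shape concrete, here is a hedged sketch that is not part of the commit: the common hash(id) is aliased once in an injected pre-Project, the duplicates are replaced by that alias's attribute, and visitPlan appends a trimming Project only when the rewritten node exposes more attributes than the original output. The object name and the plan node names in the comments are illustrative, and a Gluten + CH session is assumed so the rule is actually injected.

import org.apache.spark.sql.SparkSession

object CsePlanShape {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("cse-plan-shape")
      .master("local[*]")
      .config("spark.gluten.sql.commonSubexpressionEliminate", "true")
      .getOrCreate()

    // Expected rough shape of the optimized plan (names are illustrative):
    //   Project [cse AS a, (cse + 1) AS b]     <- duplicates replaced by the alias attribute
    //   +- Project [id, hash(id) AS cse]       <- injected pre-project evaluating hash(id) once
    //      +- Range (0, 10, ...)
    val df = spark.sql("select hash(id) as a, hash(id) + 1 as b from range(10)")
    println(df.queryExecution.optimizedPlan.treeString)

    spark.stop()
  }
}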
17 changes: 15 additions & 2 deletions shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
@@ -86,7 +86,11 @@ class GlutenConfig(conf: SQLConf) extends Logging {

def columnarTableCacheEnabled: Boolean = conf.getConf(COLUMNAR_TABLE_CACHE_ENABLED)

def enableDateTimestampComparison: Boolean = conf.getConf(ENABLE_DATE_TIMESTAMP_COMPARISON)
def enableRewriteDateTimestampComparison: Boolean =
conf.getConf(ENABLE_REWRITE_DATE_TIMESTAMP_COMPARISON)

def enableCommonSubexpressionEliminate: Boolean =
conf.getConf(ENABLE_COMMON_SUBEXPRESSION_ELIMINATE)

// whether to use ColumnarShuffleManager
def isUseColumnarShuffleManager: Boolean =
@@ -1388,7 +1392,7 @@ object GlutenConfig {
.intConf
.createOptional

val ENABLE_DATE_TIMESTAMP_COMPARISON =
val ENABLE_REWRITE_DATE_TIMESTAMP_COMPARISON =
buildConf("spark.gluten.sql.rewrite.dateTimestampComparison")
.internal()
.doc("Rewrite the comparision between date and timestamp to timestamp comparison."
@@ -1403,6 +1407,15 @@
.booleanConf
.createWithDefault(true)

val ENABLE_COMMON_SUBEXPRESSION_ELIMINATE =
buildConf("spark.gluten.sql.commonSubexpressionEliminate")
.internal()
.doc(
"Eliminate common subexpressions in the logical plan to avoid multiple evaluations of the"
+ " same expression, which may improve performance")
.booleanConf
.createWithDefault(true)

val COLUMNAR_VELOX_BLOOM_FILTER_EXPECTED_NUM_ITEMS =
buildConf("spark.gluten.sql.columnar.backend.velox.bloomFilter.expectedNumItems")
.internal()
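Since the new spark.gluten.sql.commonSubexpressionEliminate flag added above is internal and defaults to true, its main practical use is switching the rewrite off to compare plans while debugging. A small hedged sketch, mirroring the test suite's withSQLConf toggling, written for spark-shell where a Gluten + CH session named spark is assumed to be predefined:

// Toggle the flag at runtime and compare the optimized plans of the same query.
spark.conf.set("spark.gluten.sql.commonSubexpressionEliminate", "false")
println(spark.sql("select hash(id), hash(id) + 1 from range(10)").queryExecution.optimizedPlan.treeString)

spark.conf.set("spark.gluten.sql.commonSubexpressionEliminate", "true")
println(spark.sql("select hash(id), hash(id) + 1 from range(10)").queryExecution.optimizedPlan.treeString)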
