diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala index b9d956b290df..a55d7586955b 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -49,6 +49,7 @@ import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildS import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.utils.CHExecUtil import org.apache.spark.sql.extension.ClickHouseAnalysis +import org.apache.spark.sql.extension.CommonSubexpressionEliminateRule import org.apache.spark.sql.extension.RewriteDateTimestampComparisonRule import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -340,7 +341,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { */ override def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]] = { val analyzers = List(spark => new ClickHouseAnalysis(spark, spark.sessionState.conf)) - if (GlutenConfig.getConf.enableDateTimestampComparison) { + if (GlutenConfig.getConf.enableRewriteDateTimestampComparison) { analyzers :+ (spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf)) } else { analyzers @@ -353,7 +354,12 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { * @return */ override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = { - List.empty + var optimizers = List.empty[SparkSession => Rule[LogicalPlan]] + if (GlutenConfig.getConf.enableCommonSubexpressionEliminate) { + optimizers = optimizers :+ ( + spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf)) + } + optimizers } /** diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetColumnarShuffleSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetColumnarShuffleSuite.scala index c45a2cd7bdd9..0aeef7ee7c69 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetColumnarShuffleSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetColumnarShuffleSuite.scala @@ -39,6 +39,7 @@ class GlutenClickHouseTPCDSParquetColumnarShuffleSuite extends GlutenClickHouseT // .set("spark.sql.files.maxPartitionBytes", "134217728") // .set("spark.sql.files.openCostInBytes", "134217728") .set("spark.memory.offHeap.size", "4g") + // .set("spark.sql.planChangeLog.level", "error") } executeTPCDSTest(false) diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetSuite.scala index 077731375f4f..0f7f8d5500b1 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetSuite.scala @@ -48,6 +48,7 @@ class GlutenClickHouseTPCHParquetSuite extends GlutenClickHouseTPCHAbstractSuite .set("spark.sql.autoBroadcastJoinThreshold", "10MB") .set("spark.gluten.sql.columnar.backend.ch.use.v2", "false") .set("spark.gluten.supported.scala.udfs", "my_add") + // .set("spark.sql.planChangeLog.level", "error") } override protected val createNullableTables = true diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala index ef64d3782f10..6d0a7a2642c5 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala @@ -19,7 +19,9 @@ package io.glutenproject.execution import io.glutenproject.GlutenConfig import io.glutenproject.utils.UTSystemParameters +import org.apache.spark.SPARK_VERSION_SHORT import org.apache.spark.SparkConf +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation} import org.apache.spark.sql.internal.SQLConf @@ -29,6 +31,7 @@ import java.nio.file.Files import java.sql.Date import scala.collection.immutable.Seq +import scala.reflect.ClassTag class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerSuite { override protected val resourcePath: String = { @@ -39,6 +42,11 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS protected val rootPath: String = getClass.getResource("/").getPath protected val basePath: String = rootPath + "unit-tests-working-home" + protected lazy val sparkVersion: String = { + val version = SPARK_VERSION_SHORT.split("\\.") + version(0) + "." + version(1) + } + protected val tablesPath: String = basePath + "/tpch-data" protected val tpchQueries: String = rootPath + "../../../../gluten-core/src/test/resources/tpch-queries" @@ -527,4 +535,51 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS )(checkOperatorMatch[ProjectExecTransformer]) } } + + test("test common subexpression eliminate") { + def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit + tag: ClassTag[T]): Unit = { + if (sparkVersion.equals("3.3")) { + assert( + getExecutedPlan(df).count( + plan => { + plan.getClass == tag.runtimeClass + }) == count, + s"executed plan: ${getExecutedPlan(df)}") + } + } + + withSQLConf(("spark.gluten.sql.commonSubexpressionEliminate", "true")) { + // CSE in project + runQueryAndCompare("select hash(id), hash(id)+1, hash(id)-1 from range(10)") { + df => checkOperatorCount[ProjectExecTransformer](2)(df) + } + + // CSE in filter(not work yet) + // runQueryAndCompare( + // "select id from range(10) " + + // "where hex(id) != '' and upper(hex(id)) != '' and lower(hex(id)) != ''") { _ => } + + // CSE in window + runQueryAndCompare( + "SELECT id, AVG(id) OVER (PARTITION BY id % 2 ORDER BY id) as avg_id, " + + "SUM(id) OVER (PARTITION BY id % 2 ORDER BY id) as sum_id FROM range(10)") { + df => checkOperatorCount[ProjectExecTransformer](4)(df) + } + + // CSE in aggregate + runQueryAndCompare( + "select id % 2, max(hash(id)), min(hash(id)) " + + "from range(10) group by id % 2") { + df => checkOperatorCount[ProjectExecTransformer](1)(df) + } + + // CSE in sort + runQueryAndCompare( + "select id from range(10) " + + "order by hash(id%10), hash(hash(id%10))") { + df => checkOperatorCount[ProjectExecTransformer](2)(df) + } + } + } } diff --git a/cpp-ch/local-engine/Parser/TypeParser.cpp b/cpp-ch/local-engine/Parser/TypeParser.cpp index 5fb8cb8dd4d8..958a5fb4518f 100644 --- a/cpp-ch/local-engine/Parser/TypeParser.cpp +++ b/cpp-ch/local-engine/Parser/TypeParser.cpp @@ -254,8 +254,9 @@ DB::Block TypeParser::buildBlockFromNamedStruct(const substrait::NamedStruct & s // This is a partial aggregate data column. // It's type is special, must be a struct type contains all arguments types. + // Notice: there are some coincidence cases in which the type is not a struct type, e.g. name is "_1#913 + _2#914#928". We need to handle it. Poco::StringTokenizer name_parts(name, "#"); - if (name_parts.count() >= 4) + if (name_parts.count() >= 4 && !name.contains(' ')) { auto nested_data_type = DB::removeNullable(ch_type); const auto * tuple_type = typeid_cast(nested_data_type.get()); diff --git a/cpp-ch/local-engine/Rewriter/ExpressionRewriter.h b/cpp-ch/local-engine/Rewriter/ExpressionRewriter.h index 8c0bc0e0d981..4e7e17b9a13b 100644 --- a/cpp-ch/local-engine/Rewriter/ExpressionRewriter.h +++ b/cpp-ch/local-engine/Rewriter/ExpressionRewriter.h @@ -216,7 +216,7 @@ class GetJsonObjectFunctionWriter : public RelRewriter arg0->CopyFrom(scalar_function_pb.arguments(0)); auto * arg1 = decoded_json_function.add_arguments(); arg1->mutable_value()->mutable_literal()->set_string(required_fields_str); - + substrait::Expression new_get_json_object_arg0; new_get_json_object_arg0.mutable_scalar_function()->CopyFrom(decoded_json_function); *scalar_function_pb.mutable_arguments()->Mutable(0)->mutable_value() = new_get_json_object_arg0; diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/CommonSubexpressionEliminateRule.scala b/gluten-core/src/main/scala/io/glutenproject/extension/CommonSubexpressionEliminateRule.scala new file mode 100644 index 000000000000..7b4c4b04fb33 --- /dev/null +++ b/gluten-core/src/main/scala/io/glutenproject/extension/CommonSubexpressionEliminateRule.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.extension + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf + +import scala.collection.mutable + +class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf) + extends Rule[LogicalPlan] + with Logging { + + private var lastPlan: LogicalPlan = null + + override def apply(plan: LogicalPlan): LogicalPlan = { + val newPlan = if (plan.resolved && !plan.fastEquals(lastPlan)) { + lastPlan = plan + visitPlan(plan) + } else { + plan + } + newPlan + } + + private case class AliasAndAttribute(alias: Alias, attribute: Attribute) + + private case class RewriteContext(exprs: Seq[Expression], child: LogicalPlan) + + private def visitPlan(plan: LogicalPlan): LogicalPlan = { + var newPlan = plan match { + case project: Project => visitProject(project) + // TODO: CSE in Filter doesn't work for unknown reason, need to fix it later + // case filter: Filter => visitFilter(filter) + case window: Window => visitWindow(window) + case aggregate: Aggregate => visitAggregate(aggregate) + case sort: Sort => visitSort(sort) + case other => + val children = other.children.map(visitPlan) + other.withNewChildren(children) + } + + if (newPlan.output.size == plan.output.size) { + return newPlan + } + + // Add a Project to trim unnecessary attributes(which are always at the end of the output) + val postProjectList = newPlan.output.take(plan.output.size) + Project(postProjectList, newPlan) + } + + private def replaceCommonExprWithAttribute( + expr: Expression, + commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute]): Expression = { + val exprEquals = commonExprMap.get(ExpressionEquals(expr)) + if (exprEquals.isDefined) { + exprEquals.get.attribute + } else { + expr.mapChildren(replaceCommonExprWithAttribute(_, commonExprMap)) + } + } + + private def isValidCommonExpr(expr: Expression): Boolean = { + if ( + (expr.isInstanceOf[Unevaluable] && !expr.isInstanceOf[AttributeReference]) + || expr.isInstanceOf[AggregateFunction] + || (expr.isInstanceOf[AttributeReference] + && expr.asInstanceOf[AttributeReference].name == VirtualColumn.groupingIdName) + ) { + logTrace(s"Check common expression failed $expr class ${expr.getClass.toString}") + return false + } + + expr.children.forall(isValidCommonExpr(_)) + } + + private def rewrite(inputCtx: RewriteContext): RewriteContext = { + logTrace(s"Start rewrite with input exprs:${inputCtx.exprs} input child:${inputCtx.child}") + val equivalentExpressions = new EquivalentExpressions + inputCtx.exprs.foreach(equivalentExpressions.addExprTree(_)) + + // Get all the expressions that appear at least twice + val newChild = visitPlan(inputCtx.child) + val commonExprs = equivalentExpressions.getCommonSubexpressions + + // Put the common expressions into a hash map + val commonExprMap = mutable.HashMap.empty[ExpressionEquals, AliasAndAttribute] + commonExprs.foreach { + expr => + if (!expr.foldable && !expr.isInstanceOf[Attribute] && isValidCommonExpr(expr)) { + logTrace(s"Common subexpression $expr class ${expr.getClass.toString}") + val exprEquals = ExpressionEquals(expr) + val alias = Alias(expr, expr.toString)() + val attribute = alias.toAttribute + commonExprMap.put(exprEquals, AliasAndAttribute(alias, attribute)) + } + } + + if (commonExprMap.isEmpty) { + logTrace(s"commonExprMap is empty, all exprs: ${equivalentExpressions.debugString(true)}") + return RewriteContext(inputCtx.exprs, newChild) + } + + // Generate pre-project as new child + var preProjectList = newChild.output ++ commonExprMap.values.map(_.alias) + val preProject = Project(preProjectList, newChild) + logTrace(s"newChild after rewrite: $preProject") + + // Replace the common expressions with the first expression that produces it. + try { + var newExprs = inputCtx.exprs + .map(replaceCommonExprWithAttribute(_, commonExprMap)) + logTrace(s"newExprs after rewrite: $newExprs") + RewriteContext(newExprs, preProject) + } catch { + case e: Exception => + logWarning( + s"Common subexpression eliminate failed with exception: ${e.getMessage}" + + s" while replace ${inputCtx.exprs} with $commonExprMap, fallback now") + RewriteContext(inputCtx.exprs, newChild) + } + } + + private def visitProject(project: Project): Project = { + val inputCtx = RewriteContext(project.projectList, project.child) + val outputCtx = rewrite(inputCtx) + Project(outputCtx.exprs.map(_.asInstanceOf[NamedExpression]), outputCtx.child) + } + + private def visitFilter(filter: Filter): Filter = { + val inputCtx = RewriteContext(Seq(filter.condition), filter.child) + val outputCtx = rewrite(inputCtx) + Filter(outputCtx.exprs.head, outputCtx.child) + } + + private def visitWindow(window: Window): Window = { + val inputCtx = RewriteContext(window.windowExpressions, window.child) + val outputCtx = rewrite(inputCtx) + Window( + outputCtx.exprs.map(_.asInstanceOf[NamedExpression]), + window.partitionSpec, + window.orderSpec, + outputCtx.child) + } + + private def visitAggregate(aggregate: Aggregate): Aggregate = { + logTrace( + s"aggregate groupingExpressions: ${aggregate.groupingExpressions} " + + s"aggregateExpressions: ${aggregate.aggregateExpressions}") + val groupingSize = aggregate.groupingExpressions.size + val aggregateSize = aggregate.aggregateExpressions.size + + val inputCtx = RewriteContext( + aggregate.groupingExpressions ++ aggregate.aggregateExpressions, + aggregate.child) + val outputCtx = rewrite(inputCtx) + Aggregate( + outputCtx.exprs.slice(0, groupingSize), + outputCtx.exprs + .slice(groupingSize, groupingSize + aggregateSize) + .map(_.asInstanceOf[NamedExpression]), + outputCtx.child + ) + } + + private def visitSort(sort: Sort): Sort = { + val exprs = sort.order.flatMap(_.children) + val inputCtx = RewriteContext(exprs, sort.child) + val outputCtx = rewrite(inputCtx) + + var start = 0; + var newOrder = Seq.empty[SortOrder] + sort.order.foreach( + order => { + val childrenSize = order.children.size + val newChildren = outputCtx.exprs.slice(start, start + childrenSize) + newOrder = newOrder :+ order.withNewChildren(newChildren).asInstanceOf[SortOrder] + start += childrenSize + }) + + Sort(newOrder, sort.global, outputCtx.child) + } +} diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala index 9f57acf6596e..512678375ba8 100644 --- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala +++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala @@ -86,7 +86,11 @@ class GlutenConfig(conf: SQLConf) extends Logging { def columnarTableCacheEnabled: Boolean = conf.getConf(COLUMNAR_TABLE_CACHE_ENABLED) - def enableDateTimestampComparison: Boolean = conf.getConf(ENABLE_DATE_TIMESTAMP_COMPARISON) + def enableRewriteDateTimestampComparison: Boolean = + conf.getConf(ENABLE_REWRITE_DATE_TIMESTAMP_COMPARISON) + + def enableCommonSubexpressionEliminate: Boolean = + conf.getConf(ENABLE_COMMON_SUBEXPRESSION_ELIMINATE) // whether to use ColumnarShuffleManager def isUseColumnarShuffleManager: Boolean = @@ -1388,7 +1392,7 @@ object GlutenConfig { .intConf .createOptional - val ENABLE_DATE_TIMESTAMP_COMPARISON = + val ENABLE_REWRITE_DATE_TIMESTAMP_COMPARISON = buildConf("spark.gluten.sql.rewrite.dateTimestampComparison") .internal() .doc("Rewrite the comparision between date and timestamp to timestamp comparison." @@ -1403,6 +1407,15 @@ object GlutenConfig { .booleanConf .createWithDefault(true) + val ENABLE_COMMON_SUBEXPRESSION_ELIMINATE = + buildConf("spark.gluten.sql.commonSubexpressionEliminate") + .internal() + .doc( + "Eliminate common subexpressions in logical plan to avoid multiple evaluation of the same" + + "expression, may improve performance") + .booleanConf + .createWithDefault(true) + val COLUMNAR_VELOX_BLOOM_FILTER_EXPECTED_NUM_ITEMS = buildConf("spark.gluten.sql.columnar.backend.velox.bloomFilter.expectedNumItems") .internal()