Skip to content

Commit

Permalink
[VL] RAS: Add a new built-in cost model that avoids offloading trivia…
Browse files Browse the repository at this point in the history
…l projects if its neighbor nodes fell back (apache#6756)
  • Loading branch information
zml1206 authored Aug 9, 2024
1 parent 58f1cf6 commit 9251078
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 64 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.gluten.GlutenConfig

import org.apache.spark.SparkConf
import org.apache.spark.sql.execution.ProjectExec

/**
 * Exercises RAS with the built-in "rough" cost model selected through
 * `GlutenConfig.RAS_COST_MODEL`.
 *
 * The suite creates a small parquet table `tmp1` (columns `c1`, `c2`) before the tests run and
 * drops it afterwards.
 */
class VeloxRoughCostModelSuite extends VeloxWholeStageTransformerSuite {
  override protected val resourcePath: String = "/tpch-data-parquet-velox"
  override protected val fileFormat: String = "parquet"

  /** Creates the 100-row fixture table `tmp1` used by the tests in this suite. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    spark
      .range(100)
      .selectExpr("cast(id % 3 as int) as c1", "id as c2")
      .write
      .format("parquet")
      .saveAsTable("tmp1")
  }

  /**
   * Drops the fixture table and then always delegates to the parent teardown.
   *
   * `if exists` keeps the cleanup from failing when the table was never created (e.g. because
   * `beforeAll` itself failed), and the `finally` guarantees that `super.afterAll()` still runs
   * even if the drop statement throws.
   */
  override protected def afterAll(): Unit = {
    try {
      spark.sql("drop table if exists tmp1")
    } finally {
      super.afterAll()
    }
  }

  // Enable RAS and pick the "rough" cost model alias for every test in this suite.
  override protected def sparkConf: SparkConf = super.sparkConf
    .set(GlutenConfig.RAS_ENABLED.key, "true")
    .set(GlutenConfig.RAS_COST_MODEL.key, "rough")

  test("fallback trivial project if its neighbor nodes fell back") {
    // With columnar file scan disabled, the alias-only project is expected to show up as a
    // vanilla Spark ProjectExec in the executed plan.
    withSQLConf(GlutenConfig.COLUMNAR_FILESCAN_ENABLED.key -> "false") {
      runQueryAndCompare("select c1 as c3 from tmp1") {
        checkSparkOperatorMatch[ProjectExec]
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ object VeloxRasSuite {
}

class UserCostModel1 extends CostModel[SparkPlan] {
private val base = GlutenCostModel.rough()
private val base = GlutenCostModel.legacy()
override def costOf(node: SparkPlan): Cost = node match {
case _: RowUnary => base.makeInfCost()
case other => base.costOf(other)
Expand All @@ -205,7 +205,7 @@ object VeloxRasSuite {
}

class UserCostModel2 extends CostModel[SparkPlan] {
private val base = GlutenCostModel.rough()
private val base = GlutenCostModel.legacy()
override def costOf(node: SparkPlan): Cost = node match {
case _: ColumnarUnary => base.makeInfCost()
case other => base.costOf(other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,16 @@
package org.apache.gluten.planner.cost

import org.apache.gluten.GlutenConfig
import org.apache.gluten.extension.columnar.enumerated.RemoveFilter
import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, RowToColumnarLike}
import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
import org.apache.gluten.ras.{Cost, CostModel}
import org.apache.gluten.utils.PlanUtil
import org.apache.gluten.ras.CostModel

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.utils.ReflectionUtil

object GlutenCostModel extends Logging {
def find(): CostModel[SparkPlan] = {
val aliases: Map[String, Class[_ <: CostModel[SparkPlan]]] = Map(
"rough" -> classOf[RoughCostModel])
val aliases: Map[String, Class[_ <: CostModel[SparkPlan]]] =
Map("legacy" -> classOf[LegacyCostModel], "rough" -> classOf[RoughCostModel])
val aliasOrClass = GlutenConfig.getConf.rasCostModel
val clazz: Class[_ <: CostModel[SparkPlan]] = if (aliases.contains(aliasOrClass)) {
aliases(aliasOrClass)
Expand All @@ -45,55 +41,5 @@ object GlutenCostModel extends Logging {
model
}

def rough(): CostModel[SparkPlan] = new RoughCostModel()

private class RoughCostModel extends CostModel[SparkPlan] {
private val infLongCost = Long.MaxValue

override def costOf(node: SparkPlan): GlutenCost = node match {
case _: GroupLeafExec => throw new IllegalStateException()
case _ => GlutenCost(longCostOf(node))
}

private def longCostOf(node: SparkPlan): Long = node match {
case n =>
val selfCost = selfLongCostOf(n)

// Sum with ceil to avoid overflow.
def safeSum(a: Long, b: Long): Long = {
assert(a >= 0)
assert(b >= 0)
val sum = a + b
if (sum < a || sum < b) Long.MaxValue else sum
}

(n.children.map(longCostOf).toList :+ selfCost).reduce(safeSum)
}

// A very rough estimation as of now. The cost model basically considers any
// fallen back ops as having extreme high cost so offloads computations as
// much as possible.
private def selfLongCostOf(node: SparkPlan): Long = {
node match {
case _: RemoveFilter.NoopFilter =>
// To make planner choose the tree that has applied rule PushFilterToScan.
0L
case ColumnarToRowExec(child) => 10L
case RowToColumnarExec(child) => 10L
case ColumnarToRowLike(child) => 10L
case RowToColumnarLike(child) => 10L
case p if PlanUtil.isGlutenColumnarOp(p) => 10L
case p if PlanUtil.isVanillaColumnarOp(p) => 1000L
// Other row ops. Usually a vanilla row op.
case _ => 1000L
}
}

override def costComparator(): Ordering[Cost] = Ordering.Long.on {
case GlutenCost(value) => value
case _ => throw new IllegalStateException("Unexpected cost type")
}

override def makeInfCost(): Cost = GlutenCost(infLongCost)
}
def legacy(): CostModel[SparkPlan] = new LegacyCostModel()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.planner.cost

import org.apache.gluten.extension.columnar.enumerated.RemoveFilter
import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, RowToColumnarLike}
import org.apache.gluten.utils.PlanUtil

import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan}

/**
 * A [[LongCostModel]] that assigns each plan node a fixed constant cost.
 *
 * The estimation is intentionally coarse: fallen-back (row-based or vanilla columnar) operators
 * are priced far higher than Gluten columnar operators, so the planner offloads as much
 * computation as it can.
 */
class LegacyCostModel extends LongCostModel {

  /** Returns the cost of `node` itself, excluding the cost of its children. */
  override def selfLongCostOf(node: SparkPlan): Long = node match {
    // Free of charge, so that the planner prefers the tree on which rule PushFilterToScan
    // has been applied.
    case _: RemoveFilter.NoopFilter => 0L
    // Row <-> columnar transitions.
    case ColumnarToRowExec(_) | RowToColumnarExec(_) => 10L
    case ColumnarToRowLike(_) | RowToColumnarLike(_) => 10L
    // Operators offloaded to Gluten.
    case plan if PlanUtil.isGlutenColumnarOp(plan) => 10L
    // Columnar operators that are not Gluten's.
    case plan if PlanUtil.isVanillaColumnarOp(plan) => 1000L
    // Everything else; usually a vanilla row-based operator.
    case _ => 1000L
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.planner.cost

import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
import org.apache.gluten.ras.{Cost, CostModel}

import org.apache.spark.sql.execution.SparkPlan

/**
 * Base class for cost models whose cost is a single `Long` wrapped in a `GlutenCost`.
 *
 * The cost of a plan is the saturating sum of the self cost of every node in the tree.
 * Subclasses only provide [[selfLongCostOf]].
 */
abstract class LongCostModel extends CostModel[SparkPlan] {
  private val infiniteLongCost = Long.MaxValue

  /**
   * Computes the accumulated cost of the tree rooted at `node`.
   *
   * Throws `IllegalStateException` when `node` is a `GroupLeafExec`.
   */
  override def costOf(node: SparkPlan): GlutenCost = node match {
    case _: GroupLeafExec => throw new IllegalStateException()
    case plan => GlutenCost(longCostOf(plan))
  }

  // Self cost is evaluated first, then the children are visited recursively; the child costs
  // are folded left-to-right with the self cost added last.
  private def longCostOf(node: SparkPlan): Long = {
    val ownCost = selfLongCostOf(node)
    val childCosts = node.children.map(longCostOf)
    (childCosts :+ ownCost).reduceLeft(saturatingAdd)
  }

  // Adds two non-negative costs, clamping the result to Long.MaxValue on overflow.
  private def saturatingAdd(lhs: Long, rhs: Long): Long = {
    assert(lhs >= 0)
    assert(rhs >= 0)
    val total = lhs + rhs
    if (total >= lhs && total >= rhs) total else Long.MaxValue
  }

  /** Returns the cost of `node` itself, excluding the cost of its children. */
  def selfLongCostOf(node: SparkPlan): Long

  /** Orders costs by their wrapped `Long` value; any non-`GlutenCost` is rejected. */
  override def costComparator(): Ordering[Cost] = {
    val longValueOf: Cost => Long = {
      case GlutenCost(value) => value
      case _ => throw new IllegalStateException("Unexpected cost type")
    }
    Ordering.Long.on(longValueOf)
  }

  /** The "infinite" cost, represented by `Long.MaxValue`. */
  override def makeInfCost(): Cost = GlutenCost(infiniteLongCost)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.planner.cost

import org.apache.gluten.extension.columnar.enumerated.RemoveFilter
import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, RowToColumnarLike}
import org.apache.gluten.utils.PlanUtil

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
import org.apache.spark.sql.execution.{ColumnarToRowExec, ProjectExec, RowToColumnarExec, SparkPlan}

/**
 * A [[LongCostModel]] with fixed per-node costs that additionally treats a trivial vanilla
 * `ProjectExec` as cheap.
 *
 * A project is trivial when every entry of its project list is either a plain attribute or an
 * alias of a plain attribute.
 */
class RoughCostModel extends LongCostModel {

  /** Returns the cost of `node` itself, excluding the cost of its children. */
  override def selfLongCostOf(node: SparkPlan): Long = node match {
    // Free of charge, so that the planner prefers the tree on which rule PushFilterToScan
    // has been applied.
    case _: RemoveFilter.NoopFilter => 0L
    // Give a trivial ProjectExec the same cost as a Gluten columnar operator, to reduce
    // unnecessary c2r and r2c transitions.
    case ProjectExec(exprs, _) if exprs.forall(isCheapExpression) => 10L
    // Row <-> columnar transitions.
    case ColumnarToRowExec(_) | RowToColumnarExec(_) => 10L
    case ColumnarToRowLike(_) | RowToColumnarLike(_) => 10L
    // Operators offloaded to Gluten.
    case plan if PlanUtil.isGlutenColumnarOp(plan) => 10L
    // Columnar operators that are not Gluten's.
    case plan if PlanUtil.isVanillaColumnarOp(plan) => 1000L
    // Everything else; usually a vanilla row-based operator.
    case _ => 1000L
  }

  // True for a bare attribute reference or an alias that directly wraps one.
  private def isCheapExpression(ne: NamedExpression): Boolean = ne match {
    case Alias(_: Attribute, _) | _: Attribute => true
    case _ => false
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1317,10 +1317,11 @@ object GlutenConfig {
val RAS_COST_MODEL =
buildConf("spark.gluten.ras.costModel")
.doc(
"Experimental: The class name of user-defined cost model that will be used by RAS. " +
"If not specified, a rough built-in cost model will be used.")
"Experimental: The class name of user-defined cost model that will be used by RAS. If " +
"not specified, a legacy built-in cost model that exhaustively offloads computations " +
"will be used.")
.stringConf
.createWithDefaultString("rough")
.createWithDefaultString("legacy")

// velox caching options.
val COLUMNAR_VELOX_CACHE_ENABLED =
Expand Down

0 comments on commit 9251078

Please sign in to comment.