Skip to content

Commit

Permalink
[VL] RAS: Avoid adding R2C whose schema contains complex data types (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored Sep 14, 2024
1 parent ffe2be6 commit 8b65630
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class VeloxRoughCostModelSuite extends VeloxWholeStageTransformerSuite {
super.beforeAll()
spark
.range(100)
.selectExpr("cast(id % 3 as int) as c1", "id as c2")
.selectExpr("cast(id % 3 as int) as c1", "id as c2", "array(id, id + 1) as c3")
.write
.format("parquet")
.saveAsTable("tmp1")
Expand All @@ -43,6 +43,7 @@ class VeloxRoughCostModelSuite extends VeloxWholeStageTransformerSuite {
override protected def sparkConf: SparkConf = super.sparkConf
.set(GlutenConfig.RAS_ENABLED.key, "true")
.set(GlutenConfig.RAS_COST_MODEL.key, "rough")
.set(GlutenConfig.VANILLA_VECTORIZED_READERS_ENABLED.key, "false")

test("fallback trivial project if its neighbor nodes fell back") {
withSQLConf(GlutenConfig.COLUMNAR_FILESCAN_ENABLED.key -> "false") {
Expand All @@ -51,4 +52,12 @@ class VeloxRoughCostModelSuite extends VeloxWholeStageTransformerSuite {
}
}
}

test("avoid adding r2c whose schema contains complex data types") {
  // c3 is the array-typed column written into tmp1 by beforeAll.
  val query = "select array_contains(c3, 0) as list from tmp1"
  // With columnar file scan disabled, the project over the array column is expected to
  // remain a vanilla Spark ProjectExec.
  withSQLConf(GlutenConfig.COLUMNAR_FILESCAN_ENABLED.key -> "false") {
    runQueryAndCompare(query) {
      checkSparkOperatorMatch[ProjectExec]
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ import org.apache.gluten.utils.PlanUtil

import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan}

/**
* A cost model that is supposed to drive the RAS planner to create the same query plan as the
* legacy planner.
*/
class LegacyCostModel extends LongCostModel {

// A very rough estimation as of now. The cost model basically considers any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
*/
package org.apache.gluten.planner.cost

import org.apache.gluten.execution.RowToColumnarExecBase
import org.apache.gluten.extension.columnar.enumerated.RemoveFilter
import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, RowToColumnarLike}
import org.apache.gluten.utils.PlanUtil

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
import org.apache.spark.sql.execution.{ColumnarToRowExec, ProjectExec, RowToColumnarExec, SparkPlan}
import org.apache.spark.sql.types.{ArrayType, MapType, StructType}

/** A rough cost model with some empirical heuristics. */
class RoughCostModel extends LongCostModel {

override def selfLongCostOf(node: SparkPlan): Long = {
Expand All @@ -34,6 +37,10 @@ class RoughCostModel extends LongCostModel {
// Make trivial ProjectExec have the same cost as ProjectExecTransform to reduce unnecessary
// c2r and r2c.
10L
case r2c: RowToColumnarExecBase if hasComplexTypes(r2c.schema) =>
// Avoid moving computation back to native when transition has complex types in schema.
// Such transitions are observed to be extremely expensive as of now.
Long.MaxValue
case ColumnarToRowExec(_) => 10L
case RowToColumnarExec(_) => 10L
case ColumnarToRowLike(_) => 10L
Expand All @@ -50,4 +57,13 @@ class RoughCostModel extends LongCostModel {
case _: Attribute => true
case _ => false
}

/**
 * Tells whether at least one field of the given schema has a complex data type, i.e. a struct,
 * an array or a map.
 */
private def hasComplexTypes(schema: StructType): Boolean = {
  schema.exists {
    field =>
      field.dataType match {
        case _: StructType | _: ArrayType | _: MapType => true
        case _ => false
      }
  }
}
}

0 comments on commit 8b65630

Please sign in to comment.