diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxRoughCostModelSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxRoughCostModelSuite.scala index ca3bbb0b1e72..180a8febb649 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxRoughCostModelSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxRoughCostModelSuite.scala @@ -29,7 +29,7 @@ class VeloxRoughCostModelSuite extends VeloxWholeStageTransformerSuite { super.beforeAll() spark .range(100) - .selectExpr("cast(id % 3 as int) as c1", "id as c2") + .selectExpr("cast(id % 3 as int) as c1", "id as c2", "array(id, id + 1) as c3") .write .format("parquet") .saveAsTable("tmp1") @@ -43,6 +43,7 @@ class VeloxRoughCostModelSuite extends VeloxWholeStageTransformerSuite { override protected def sparkConf: SparkConf = super.sparkConf .set(GlutenConfig.RAS_ENABLED.key, "true") .set(GlutenConfig.RAS_COST_MODEL.key, "rough") + .set(GlutenConfig.VANILLA_VECTORIZED_READERS_ENABLED.key, "false") test("fallback trivial project if its neighbor nodes fell back") { withSQLConf(GlutenConfig.COLUMNAR_FILESCAN_ENABLED.key -> "false") { @@ -51,4 +52,12 @@ class VeloxRoughCostModelSuite extends VeloxWholeStageTransformerSuite { } } } + + test("avoid adding r2c whose schema contains complex data types") { + withSQLConf(GlutenConfig.COLUMNAR_FILESCAN_ENABLED.key -> "false") { + runQueryAndCompare("select array_contains(c3, 0) as list from tmp1") { + checkSparkOperatorMatch[ProjectExec] + } + } + } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/planner/cost/LegacyCostModel.scala b/gluten-substrait/src/main/scala/org/apache/gluten/planner/cost/LegacyCostModel.scala index 3b631872caa6..86f3f6e0b5e2 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/planner/cost/LegacyCostModel.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/planner/cost/LegacyCostModel.scala @@ -22,6 +22,10 @@ import 
org.apache.gluten.utils.PlanUtil import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan} +/** + * A cost model that is supposed to drive the RAS planner to create the same query plan as + * the legacy planner. + */ class LegacyCostModel extends LongCostModel { // A very rough estimation as of now. The cost model basically considers any diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/planner/cost/RoughCostModel.scala b/gluten-substrait/src/main/scala/org/apache/gluten/planner/cost/RoughCostModel.scala index d621c3010c16..086bc4c0ce13 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/planner/cost/RoughCostModel.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/planner/cost/RoughCostModel.scala @@ -16,13 +16,16 @@ */ package org.apache.gluten.planner.cost +import org.apache.gluten.execution.RowToColumnarExecBase import org.apache.gluten.extension.columnar.enumerated.RemoveFilter import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, RowToColumnarLike} import org.apache.gluten.utils.PlanUtil import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression} import org.apache.spark.sql.execution.{ColumnarToRowExec, ProjectExec, RowToColumnarExec, SparkPlan} +import org.apache.spark.sql.types.{ArrayType, MapType, StructType} +/** A rough cost model with some empirical heuristics. */ class RoughCostModel extends LongCostModel { override def selfLongCostOf(node: SparkPlan): Long = { @@ -34,6 +37,10 @@ class RoughCostModel extends LongCostModel { // Make trivial ProjectExec has the same cost as ProjectExecTransform to reduce unnecessary // c2r and r2c. 10L + case r2c: RowToColumnarExecBase if hasComplexTypes(r2c.schema) => + // Avoid moving computation back to native when transition has complex types in schema. + // Such transitions are observed to be extremely expensive as of now. 
+ Long.MaxValue case ColumnarToRowExec(_) => 10L case RowToColumnarExec(_) => 10L case ColumnarToRowLike(_) => 10L @@ -50,4 +57,13 @@ class RoughCostModel extends LongCostModel { case _: Attribute => true case _ => false } + + private def hasComplexTypes(schema: StructType): Boolean = { + schema.exists(_.dataType match { + case _: StructType => true + case _: ArrayType => true + case _: MapType => true + case _ => false + }) + } }