From ac2b18f33d75333a4a344539cf7599a5179e8fa9 Mon Sep 17 00:00:00 2001
From: exmy
Date: Thu, 11 May 2023 13:46:45 +0800
Subject: [PATCH] [GLUTEN-1577][CORE][Fix] Respect spark's config for case sensitive when get attribute name (#1578)

---
 .../expression/ConverterUtils.scala            | 16 ++++++++++++----
 .../clickhouse/ClickHouseTestSettings.scala    |  3 ++-
 .../org/apache/spark/sql/GlutenJoinSuite.scala | 13 +++++++++++++
 3 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala
index c877a7302b6d..3263346608a0 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala
@@ -17,6 +17,8 @@
 
 package io.glutenproject.expression
 
+import java.util.Locale
+
 import io.glutenproject.execution.{BasicScanExecTransformer, BatchScanExecTransformer, FileSourceScanExecTransformer}
 import io.glutenproject.substrait.`type`._
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
@@ -26,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer._
 import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -76,11 +79,16 @@ object ConverterUtils extends Logging {
   }
 
   def getShortAttributeName(attr: Attribute): String = {
-    val subIndex = attr.name.indexOf("(")
+    val name = if (SQLConf.get.caseSensitiveAnalysis) {
+      attr.name
+    } else {
+      attr.name.toLowerCase(Locale.ROOT)
+    }
+    val subIndex = name.indexOf("(")
     if (subIndex != -1) {
-      attr.name.substring(0, subIndex)
+      name.substring(0, subIndex)
     } else {
-      attr.name
+      name
     }
   }
 
@@ -89,7 +97,7 @@ object ConverterUtils extends Logging {
   }
 
   def isNullable(nullability: Type.Nullability): Boolean = {
-    return nullability == Type.Nullability.NULLABILITY_NULLABLE
+    nullability == Type.Nullability.NULLABILITY_NULLABLE
   }
 
   def parseFromSubstraitType(substraitType: Type): (DataType, Boolean) = {
diff --git a/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala
index be7df409e4f5..e3b2ee607aff 100644
--- a/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala
@@ -221,7 +221,8 @@ class ClickHouseTestSettings extends BackendTestSettings {
       "inner join where, one match per row",
       "left semi join",
       "multiple-key equi-join is hash-join",
-      "full outer join"
+      "full outer join",
+      GlutenTestConstants.GLUTEN_TEST + "test case sensitive for BHJ"
     )
 
   enableSuite[GlutenHashExpressionsSuite]
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenJoinSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenJoinSuite.scala
index 6243e3626e65..b733c1accb51 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenJoinSuite.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenJoinSuite.scala
@@ -40,4 +40,17 @@ class GlutenJoinSuite extends JoinSuite with GlutenSQLTestsTrait {
     // NaN is not supported currently, just skip.
     "NaN and -0.0 in join keys"
   )
+
+  test(GlutenTestConstants.GLUTEN_TEST + "test case sensitive for BHJ") {
+    spark.sql("create table t_bhj(a int, b int, C int) using parquet")
+    spark.sql("insert overwrite t_bhj select id as a, (id+1) as b, (id+2) as c from range(3)")
+    val sql =
+      """
+        |select /*+ BROADCAST(t1) */ t0.a, t0.b
+        |from t_bhj as t0 join t_bhj as t1 on t0.a = t1.a and t0.b = t1.b and t0.c = t1.c
+        |group by t0.a, t0.b
+        |order by t0.a, t0.b
+        |""".stripMargin
+    checkAnswer(spark.sql(sql), Seq(Row(0, 1), Row(1, 2), Row(2, 3)))
+  }
 }