From d9f26b4bc9a8b508269be245e87ea189920d01a6 Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Sun, 21 May 2023 04:11:27 +0800 Subject: [PATCH] Spark 3.4: Decimal pushdown --- .../clickhouse/single/ClickHouseDataTypeSuite.scala | 10 ++++++++-- .../src/main/scala/xenon/clickhouse/SQLHelper.scala | 4 ++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/spark-3.4/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseDataTypeSuite.scala b/spark-3.4/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseDataTypeSuite.scala index b6881019..7730a61c 100644 --- a/spark-3.4/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseDataTypeSuite.scala +++ b/spark-3.4/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseDataTypeSuite.scala @@ -159,7 +159,8 @@ class ClickHouseDataTypeSuite extends SparkClickHouseSingleTest { testDataType(dataType) { (db, tbl) => runClickHouseSQL( s"""INSERT INTO $db.$tbl VALUES - |(1, '11.1') + |(1, '11.1'), + |(2, '22.2') |""".stripMargin ) } { df => @@ -167,7 +168,12 @@ class ClickHouseDataTypeSuite extends SparkClickHouseSingleTest { assert(df.schema.fields(1).dataType === DecimalType(p, s)) checkAnswer( df, - Row(1, BigDecimal("11.1", new MathContext(p))) :: Nil + Row(1, BigDecimal("11.1", new MathContext(p))) :: + Row(2, BigDecimal("22.2", new MathContext(p))) :: Nil + ) + checkAnswer( + df.filter("value > 20"), + Row(2, BigDecimal("22.2", new MathContext(p))) :: Nil ) } } diff --git a/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/SQLHelper.scala b/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/SQLHelper.scala index cd84abfd..cb2c137c 100644 --- a/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/SQLHelper.scala +++ b/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/SQLHelper.scala @@ -20,6 +20,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.connector.expressions.aggregate._ import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.DecimalType import org.apache.spark.unsafe.types.UTF8String import xenon.clickhouse.Utils._ @@ -38,6 +39,9 @@ trait SQLHelper { case localDateTime: LocalDateTime => s"'${dateTimeFmt.format(localDateTime)}'" case legacyDate: Date => s"'${legacyDateFmt.format(legacyDate)}'" case localDate: LocalDate => s"'${dateFmt.format(localDate)}'" + case decimal: DecimalType if decimal.precision <= 9 => s"toDecimal32($value, ${decimal.scale})" + case decimal: DecimalType if decimal.precision <= 18 => s"toDecimal64($value, ${decimal.scale})" + case decimal: DecimalType => s"toDecimal128($value, ${decimal.scale})" case array: Array[Any] => array.map(compileValue).mkString(",") case _ => value }