feat: enable decimal to decimal cast of different precision and scale (#1086)

* enable decimal to decimal cast of different precision and scale

* add more test cases for negative scale and higher scale

* add compatibility check for decimal to decimal cast

* fix code style

* Update spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Co-authored-by: Andy Grove <[email protected]>

* fix the nit in comment

---------

Co-authored-by: himadripal <[email protected]>
Co-authored-by: Andy Grove <[email protected]>
3 people authored Nov 22, 2024
1 parent 9990b34 commit 500895d
Showing 3 changed files with 37 additions and 7 deletions.
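Before the diffs, a hedged end-to-end sketch of what this change enables (assumes a live SparkSession named spark with the Comet extension loaded and COMET_ENABLED on; illustrative only, not code from this PR):

    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.DecimalType
    import spark.implicits._

    // A precision-widening decimal cast: Decimal(10, 2) -> Decimal(12, 2).
    // After this commit such casts are marked Compatible, so Comet can
    // execute them natively instead of falling back to Spark.
    val df = Seq(BigDecimal("123.45"))
      .toDF("a")
      .select(col("a").cast(DecimalType(10, 2)).cast(DecimalType(12, 2)))
    df.show()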
10 changes: 7 additions & 3 deletions spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -70,9 +70,13 @@ object CometCast {
         case _ =>
           Unsupported
       }
-      case (_: DecimalType, _: DecimalType) =>
-        // https://github.com/apache/datafusion-comet/issues/375
-        Incompatible()
+      case (from: DecimalType, to: DecimalType) =>
+        if (to.precision < from.precision) {
+          // https://github.com/apache/datafusion/issues/13492
+          Incompatible(Some("Casting to smaller precision is not supported"))
+        } else {
+          Compatible()
+        }
       case (DataTypes.StringType, _) =>
         canCastFromString(toType, timeZoneId, evalMode)
       case (_, DataTypes.StringType) =>
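The new rule restated as a standalone sketch. describeDecimalCast is a hypothetical helper, not part of CometCast (which returns the Compatible/Incompatible values shown above): only a narrowing of precision is unsupported natively.

    import org.apache.spark.sql.types.DecimalType

    // Hypothetical mirror of the match arm above, for illustration only.
    def describeDecimalCast(from: DecimalType, to: DecimalType): String =
      if (to.precision < from.precision) "incompatible: smaller target precision"
      else "compatible: runs natively in Comet"

    describeDecimalCast(DecimalType(38, 18), DecimalType(6, 2)) // incompatible
    describeDecimalCast(DecimalType(10, 2), DecimalType(10, 4)) // compatible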
28 changes: 28 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -892,6 +892,34 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("cast between decimals with different precision and scale") {
+    // cast from the default Decimal(38, 18) to Decimal(6, 2)
+    val values = Seq(BigDecimal("12345.6789"), BigDecimal("9876.5432"), BigDecimal("123.4567"))
+    val df = withNulls(values)
+      .toDF("b")
+      .withColumn("a", col("b").cast(DecimalType(6, 2)))
+    checkSparkAnswer(df)
+  }
+
+  test("cast between decimals with higher scale than source") {
+    // cast from Decimal(10, 2) to Decimal(10, 4)
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 4))
+  }
+
+  test("cast between decimals with negative scale") {
+    // casting to a negative scale is rejected at parse time in both engines
+    checkSparkMaybeThrows(
+      spark.sql("select a, cast(a as DECIMAL(10,-4)) from t order by a")) match {
+      case (expected, actual) =>
+        assert(expected.exists(_.getMessage.contains("PARSE_SYNTAX_ERROR")) === actual.exists(_.getMessage.contains("PARSE_SYNTAX_ERROR")))
+    }
+  }
+
+  test("cast between decimals with zero scale") {
+    // cast from Decimal(10, 2) to Decimal(10, 0)
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 0))
+  }
+
   private def generateFloats(): DataFrame = {
     withNulls(gen.generateFloats(dataSize)).toDF("a")
   }
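The new tests call generateDecimalsPrecision10Scale2, which is not part of this diff. A hypothetical sketch of such a generator, modeled on the generateFloats pattern above (the real one lives elsewhere in CometCastSuite and may differ):

    // Hypothetical sketch only; values chosen to fit Decimal(10, 2).
    private def generateDecimalsPrecision10Scale2(): DataFrame = {
      val values =
        Seq(BigDecimal("12345678.91"), BigDecimal("-9999999.99"), BigDecimal("0.01"))
      withNulls(values).toDF("a")
    }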
6 changes: 2 additions & 4 deletions spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -231,11 +231,9 @@ abstract class CometTestBase
       df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
     var expected: Option[Throwable] = None
     withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-      val dfSpark = Dataset.ofRows(spark, df.logicalPlan)
-      expected = Try(dfSpark.collect()).failed.toOption
+      expected = Try(Dataset.ofRows(spark, df.logicalPlan).collect()).failed.toOption
     }
-    val dfComet = Dataset.ofRows(spark, df.logicalPlan)
-    val actual = Try(dfComet.collect()).failed.toOption
+    val actual = Try(Dataset.ofRows(spark, df.logicalPlan).collect()).failed.toOption
     (expected, actual)
   }
 
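A hedged usage sketch of the simplified helper (names and return type taken from the diff above; the failing query is illustrative):

    // The helper evaluates the same logical plan twice -- once with Comet
    // disabled, once enabled -- and returns the failure, if any, from each.
    val (sparkErr, cometErr) =
      checkSparkMaybeThrows(spark.sql("select cast(a as DECIMAL(10,-4)) from t"))
    // Callers can then assert both engines agree on whether the query fails.
    assert(sparkErr.isDefined === cometErr.isDefined)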
