diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index b6a535118..96dbad89c 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -896,17 +896,38 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("cast between decimals with different precision and scale") {
-    // cast between default Decimal(38, 18) to Decimal(9,1)
+    // cast between default Decimal(38, 18) to Decimal(7,2)
     val values = Seq(BigDecimal("12345.6789"), BigDecimal("9876.5432"), BigDecimal("123.4567"))
     val df = withNulls(values).toDF("a")
     castTest(df, DataTypes.createDecimalType(7, 2))
   }
 
-  test("cast two between decimals with different precision and scale") {
+  test("cast between decimals with lower precision and scale") {
     // cast between Decimal(10, 2) to Decimal(9,1)
     castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(9, 1))
   }
 
+  test("cast between decimals with higher precision than source") {
+    // cast between Decimal(10, 2) to Decimal(10,4)
+    withSQLConf("spark.comet.explainFallback.enabled" -> "true") {
+      castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 4))
+    }
+  }
+
+  test("cast between decimals with negative precision") {
+    // cast to negative scale
+    checkSparkMaybeThrows(
+      spark.sql("select a, cast(a as DECIMAL(10,-4)) from t order by a")) match {
+      case (expected, actual) =>
+        assert(expected.contains("PARSE_SYNTAX_ERROR") === actual.contains("PARSE_SYNTAX_ERROR"))
+    }
+  }
+
+  test("cast between decimals with zero precision") {
+    // cast between Decimal(10, 2) to Decimal(10,0)
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 0))
+  }
+
   private def generateFloats(): DataFrame = {
     withNulls(gen.generateFloats(dataSize)).toDF("a")
   }
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 1709cce61..213ec7efe 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -231,11 +231,9 @@ abstract class CometTestBase
       df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
     var expected: Option[Throwable] = None
     withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-      val dfSpark = Dataset.ofRows(spark, df.logicalPlan)
-      expected = Try(dfSpark.collect()).failed.toOption
+      expected = Try(Dataset.ofRows(spark, df.logicalPlan).collect()).failed.toOption
     }
-    val dfComet = Dataset.ofRows(spark, df.logicalPlan)
-    val actual = Try(dfComet.collect()).failed.toOption
+    val actual = Try(Dataset.ofRows(spark, df.logicalPlan).collect()).failed.toOption
     (expected, actual)
   }
 