Skip to content

Commit

Permalink
Revert "[SPARK-39226][SQL] Fix the precision of the return type of ro…
Browse files Browse the repository at this point in the history
…und-like functions"

This reverts commit 77b1313.
  • Loading branch information
gleonSun committed Dec 31, 2024
1 parent e3c73cc commit 1c34d93
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,21 @@ object CeilExpressionBuilder extends CeilFloorExpressionBuilderBase {
}

case class RoundCeil(child: Expression, scale: Expression)
extends RoundBase(child, scale, BigDecimal.RoundingMode.CEILING, "ROUND_CEILING") {
extends RoundBase(child, scale, BigDecimal.RoundingMode.CEILING, "ROUND_CEILING")
with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, IntegerType)

override lazy val dataType: DataType = child.dataType match {
case DecimalType.Fixed(p, s) =>
if (_scale < 0) {
DecimalType(math.max(p, 1 - _scale), 0)
} else {
DecimalType(p, math.min(s, _scale))
}
case t => t
}

override def nodeName: String = "ceil"

override protected def withNewChildrenInternal(
Expand Down Expand Up @@ -552,10 +563,21 @@ object FloorExpressionBuilder extends CeilFloorExpressionBuilderBase {
}

case class RoundFloor(child: Expression, scale: Expression)
extends RoundBase(child, scale, BigDecimal.RoundingMode.FLOOR, "ROUND_FLOOR") {
extends RoundBase(child, scale, BigDecimal.RoundingMode.FLOOR, "ROUND_FLOOR")
with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, IntegerType)

override lazy val dataType: DataType = child.dataType match {
case DecimalType.Fixed(p, s) =>
if (_scale < 0) {
DecimalType(math.max(p, 1 - _scale), 0)
} else {
DecimalType(p, math.min(s, _scale))
}
case t => t
}

override def nodeName: String = "floor"

override protected def withNewChildrenInternal(
Expand Down Expand Up @@ -1425,21 +1447,9 @@ abstract class RoundBase(child: Expression, scale: Expression,
override def foldable: Boolean = child.foldable

override lazy val dataType: DataType = child.dataType match {
case DecimalType.Fixed(p, s) =>
// After rounding we may need one more digit in the integral part,
// e.g. `ceil(9.9, 0)` -> `10`, `ceil(99, -1)` -> `100`.
val integralLeastNumDigits = p - s + 1
if (_scale < 0) {
// negative scale means we need to adjust `-scale` number of digits before the decimal
// point, which means we need at lease `-scale + 1` digits (after rounding).
val newPrecision = math.max(integralLeastNumDigits, -_scale + 1)
// We have to accept the risk of overflow as we can't exceed the max precision.
DecimalType(math.min(newPrecision, DecimalType.MAX_PRECISION), 0)
} else {
val newScale = math.min(s, _scale)
// We have to accept the risk of overflow as we can't exceed the max precision.
DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
}
// if the new scale is bigger which means we are scaling up,
// keep the original scale as `Decimal` does
case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale)
case t => t
}

Expand Down Expand Up @@ -1606,14 +1616,13 @@ abstract class RoundBase(child: Expression, scale: Expression,
Examples:
> SELECT _FUNC_(2.5, 0);
3
> SELECT _FUNC_(25, -1);
30
""",
since = "1.5.0",
group = "math_funcs")
// scalastyle:on line.size.limit
case class Round(child: Expression, scale: Expression)
extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP") {
extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP")
with Serializable with ImplicitCastInputTypes {
def this(child: Expression) = this(child, Literal(0))
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Round =
copy(child = newLeft, scale = newRight)
Expand All @@ -1631,14 +1640,13 @@ case class Round(child: Expression, scale: Expression)
Examples:
> SELECT _FUNC_(2.5, 0);
2
> SELECT _FUNC_(25, -1);
20
""",
since = "2.0.0",
group = "math_funcs")
// scalastyle:on line.size.limit
case class BRound(child: Expression, scale: Expression)
extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN") {
extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN")
with Serializable with ImplicitCastInputTypes {
def this(child: Expression) = this(child, Literal(0))
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): BRound = copy(child = newLeft, scale = newRight)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -806,14 +806,12 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Round(-3.5, 0), -4.0)
checkEvaluation(Round(-0.35, 1), -0.4)
checkEvaluation(Round(-35, -1), -40)
checkEvaluation(Round(BigDecimal("45.00"), -1), BigDecimal(50))
checkEvaluation(BRound(2.5, 0), 2.0)
checkEvaluation(BRound(3.5, 0), 4.0)
checkEvaluation(BRound(-2.5, 0), -2.0)
checkEvaluation(BRound(-3.5, 0), -4.0)
checkEvaluation(BRound(-0.35, 1), -0.4)
checkEvaluation(BRound(-35, -1), -40)
checkEvaluation(BRound(BigDecimal("45.00"), -1), BigDecimal(40))
checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(2.5), Literal(0))), Decimal(2))
checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(3.5), Literal(0))), Decimal(3))
checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(-2.5), Literal(0))), Decimal(-3L))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ SELECT CEIL(-35, -1);
SELECT CEIL(-0.1, 0);
SELECT CEIL(5, 0);
SELECT CEIL(3.14115, -3);
SELECT CEIL(9.9, 0);
SELECT CEIL(CAST(99 AS DECIMAL(2, 0)), -1);
SELECT CEIL(2.5, null);
SELECT CEIL(2.5, 'a');
SELECT CEIL(2.5, 0, 0);
Expand All @@ -24,8 +22,6 @@ SELECT FLOOR(-35, -1);
SELECT FLOOR(-0.1, 0);
SELECT FLOOR(5, 0);
SELECT FLOOR(3.14115, -3);
SELECT FLOOR(-9.9, 0);
SELECT FLOOR(CAST(-99 AS DECIMAL(2, 0)), -1);
SELECT FLOOR(2.5, null);
SELECT FLOOR(2.5, 'a');
SELECT FLOOR(2.5, 0, 0);
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 28
-- Number of queries: 24


-- !query
Expand Down Expand Up @@ -45,7 +45,7 @@ struct<ceil(-0.35, 1):decimal(2,1)>
-- !query
SELECT CEIL(-35, -1)
-- !query schema
struct<ceil(-35, -1):decimal(11,0)>
struct<ceil(-35, -1):decimal(10,0)>
-- !query output
-30

Expand All @@ -61,35 +61,19 @@ struct<ceil(-0.1, 0):decimal(1,0)>
-- !query
SELECT CEIL(5, 0)
-- !query schema
struct<ceil(5, 0):decimal(11,0)>
struct<ceil(5, 0):decimal(10,0)>
-- !query output
5


-- !query
SELECT CEIL(3.14115, -3)
-- !query schema
struct<ceil(3.14115, -3):decimal(4,0)>
struct<ceil(3.14115, -3):decimal(6,0)>
-- !query output
1000


-- !query
SELECT CEIL(9.9, 0)
-- !query schema
struct<ceil(9.9, 0):decimal(2,0)>
-- !query output
10


-- !query
SELECT CEIL(CAST(99 AS DECIMAL(2, 0)), -1)
-- !query schema
struct<ceil(CAST(99 AS DECIMAL(2,0)), -1):decimal(3,0)>
-- !query output
100


-- !query
SELECT CEIL(2.5, null)
-- !query schema
Expand Down Expand Up @@ -160,7 +144,7 @@ struct<floor(-0.35, 1):decimal(2,1)>
-- !query
SELECT FLOOR(-35, -1)
-- !query schema
struct<floor(-35, -1):decimal(11,0)>
struct<floor(-35, -1):decimal(10,0)>
-- !query output
-40

Expand All @@ -176,35 +160,19 @@ struct<floor(-0.1, 0):decimal(1,0)>
-- !query
SELECT FLOOR(5, 0)
-- !query schema
struct<floor(5, 0):decimal(11,0)>
struct<floor(5, 0):decimal(10,0)>
-- !query output
5


-- !query
SELECT FLOOR(3.14115, -3)
-- !query schema
struct<floor(3.14115, -3):decimal(4,0)>
struct<floor(3.14115, -3):decimal(6,0)>
-- !query output
0


-- !query
SELECT FLOOR(-9.9, 0)
-- !query schema
struct<floor(-9.9, 0):decimal(2,0)>
-- !query output
-10


-- !query
SELECT FLOOR(CAST(-99 AS DECIMAL(2, 0)), -1)
-- !query schema
struct<floor(CAST(-99 AS DECIMAL(2,0)), -1):decimal(3,0)>
-- !query output
-100


-- !query
SELECT FLOOR(2.5, null)
-- !query schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4404,7 +4404,7 @@ struct<>
-- !query
SELECT a, ceil(a), ceiling(a), floor(a), round(a) FROM ceil_floor_round
-- !query schema
struct<a:decimal(38,18),CEIL(a):decimal(21,0),ceiling(a):decimal(21,0),FLOOR(a):decimal(21,0),round(a, 0):decimal(21,0)>
struct<a:decimal(38,18),CEIL(a):decimal(21,0),ceiling(a):decimal(21,0),FLOOR(a):decimal(21,0),round(a, 0):decimal(38,0)>
-- !query output
-0.000001000000000000 0 0 -1 0
-5.499999000000000000 -5 -5 -6 -5
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
-- Automatically generated by TPCDSQueryTestSuite

-- !query schema
struct<d_week_seq1:int,round((sun_sales1 / sun_sales2), 2):decimal(20,2),round((mon_sales1 / mon_sales2), 2):decimal(20,2),round((tue_sales1 / tue_sales2), 2):decimal(20,2),round((wed_sales1 / wed_sales2), 2):decimal(20,2),round((thu_sales1 / thu_sales2), 2):decimal(20,2),round((fri_sales1 / fri_sales2), 2):decimal(20,2),round((sat_sales1 / sat_sales2), 2):decimal(20,2)>
struct<d_week_seq1:int,round((sun_sales1 / sun_sales2), 2):decimal(37,2),round((mon_sales1 / mon_sales2), 2):decimal(37,2),round((tue_sales1 / tue_sales2), 2):decimal(37,2),round((wed_sales1 / wed_sales2), 2):decimal(37,2),round((thu_sales1 / thu_sales2), 2):decimal(37,2),round((fri_sales1 / fri_sales2), 2):decimal(37,2),round((sat_sales1 / sat_sales2), 2):decimal(37,2)>
-- !query output
5270 3.18 1.63 2.25 1.64 3.41 3.62 3.72
5270 3.18 1.63 2.25 1.64 3.41 3.62 3.72
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
types.StructType(Seq(types.StructField("a", types.LongType))))
assert(
spark.range(1).select(ceil(col("id"), lit(0)).alias("a")).schema ==
types.StructType(Seq(types.StructField("a", types.DecimalType(21, 0)))))
types.StructType(Seq(types.StructField("a", types.DecimalType(20, 0)))))
checkAnswer(
sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
Row(0L, 1L, 2L))
Expand Down Expand Up @@ -263,7 +263,7 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
types.StructType(Seq(types.StructField("a", types.LongType))))
assert(
spark.range(1).select(floor(col("id"), lit(0)).alias("a")).schema ==
types.StructType(Seq(types.StructField("a", types.DecimalType(21, 0)))))
types.StructType(Seq(types.StructField("a", types.DecimalType(20, 0)))))
}

test("factorial") {
Expand Down

0 comments on commit 1c34d93

Please sign in to comment.