Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove BigDecimal from HighPrecisionMoney.fromPreciseAmount #443

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions util/src/main/scala/Money.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import language.implicitConversions
import java.math.MathContext
import java.text.NumberFormat
import java.util.{Currency, Locale}

import cats.Monoid
import cats.data.ValidatedNel
import cats.syntax.validated._
Expand Down Expand Up @@ -351,8 +350,7 @@ case class HighPrecisionMoney private (

val `type`: String = TypeName

lazy val amount: BigDecimal =
(BigDecimal(preciseAmount) * factor(fractionDigits)).setScale(fractionDigits)
lazy val amount: BigDecimal = preciseAmountToAmount(preciseAmount, fractionDigits)

def withFractionDigits(fd: Int)(implicit mode: RoundingMode): HighPrecisionMoney = {
val scaledAmount = amount.setScale(fd, mode)
Expand Down Expand Up @@ -474,6 +472,8 @@ case class HighPrecisionMoney private (
}

object HighPrecisionMoney {
import MoneyRounding._

object ImplicitsDecimal {
final implicit class HighPrecisionMoneyNotation(val amount: BigDecimal) extends AnyVal {
def EUR: HighPrecisionMoney = HighPrecisionMoney.EUR(amount)
Expand Down Expand Up @@ -566,6 +566,9 @@ object HighPrecisionMoney {
private def amountToPreciseAmount(amount: BigDecimal, fractionDigits: Int): Long =
(amount * Money.cachedCentPower(fractionDigits)).toLong

def preciseAmountToAmount(preciseAmount: Long, fractionDigits: Int): BigDecimal =
(BigDecimal(preciseAmount) * factor(fractionDigits)).setScale(fractionDigits)

def fromDecimalAmount(amount: BigDecimal, fractionDigits: Int, currency: Currency)(implicit
mode: RoundingMode): HighPrecisionMoney = {
val scaledAmount = amount.setScale(fractionDigits, mode)
Expand Down Expand Up @@ -606,12 +609,9 @@ object HighPrecisionMoney {
centAmount: Option[Long]): ValidatedNel[String, HighPrecisionMoney] =
for {
fd <- validateFractionDigits(fractionDigits, currency)
amount = BigDecimal(preciseAmount) * factor(fd)
scaledAmount = amount.setScale(fd, BigDecimal.RoundingMode.UNNECESSARY)
ca <- validateCentAmount(scaledAmount, centAmount, currency)
ca <- validateCentAmount(preciseAmount, fractionDigits, centAmount, currency)
// TODO: revisit this part! the rounding mode might be dynamic and configured elsewhere
clemniem marked this conversation as resolved.
Show resolved Hide resolved
actualCentAmount = ca.getOrElse(
roundToCents(scaledAmount, currency)(BigDecimal.RoundingMode.HALF_EVEN))
actualCentAmount = ca.getOrElse(roundHalfEven(preciseAmount, fractionDigits, currency))
} yield HighPrecisionMoney(preciseAmount, fd, actualCentAmount, currency)

private def validateFractionDigits(
Expand All @@ -625,13 +625,14 @@ object HighPrecisionMoney {
fractionDigits.validNel

private def validateCentAmount(
amount: BigDecimal,
preciseAmount: Long,
fractionDigits: Int,
centAmount: Option[Long],
currency: Currency): ValidatedNel[String, Option[Long]] =
centAmount match {
case Some(actual) =>
val min = roundToCents(amount, currency)(RoundingMode.FLOOR)
val max = roundToCents(amount, currency)(RoundingMode.CEILING)
val min = roundFloor(preciseAmount, fractionDigits, currency)
val max = roundCeiling(preciseAmount, fractionDigits, currency)

if (actual < min || actual > max)
s"centAmount must be correctly rounded preciseAmount (a number between $min and $max).".invalidNel
Expand Down
84 changes: 84 additions & 0 deletions util/src/main/scala/MoneyRounding.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package io.sphere.util

import java.util.Currency
import scala.annotation.tailrec

/** This object contains rounding algorithms for the Money classes. So far we used BigDecimal for
* this purpose, but BigDecimal is slower and consumes more memory than this approach.
*/
object MoneyRounding {

private def pow10(n: Int): Long = Math.pow(10, n).toLong

/** @return
* Floor rounded (preciseAmount, fractionDigits) to the cent value of the given currency
*/
def roundFloor(preciseAmount: Long, fractionDigits: Int, currency: Currency): Long =
if (preciseAmount < 0L) {
val power = pow10(fractionDigits - currency.getDefaultFractionDigits)
val floor = preciseAmount / power
val remainder = preciseAmount % power
if (remainder == 0L) floor else floor - 1L
} else
preciseAmount / pow10(fractionDigits - currency.getDefaultFractionDigits)

/** @return
* Ceiling rounded (preciseAmount, fractionDigits) to the cent value of the given currency
*/
def roundCeiling(preciseAmount: Long, fractionDigits: Int, currency: Currency): Long =
if (preciseAmount < 0L)
preciseAmount / pow10(fractionDigits - currency.getDefaultFractionDigits)
else {
val power = pow10(fractionDigits - currency.getDefaultFractionDigits)
val floor = preciseAmount / power
val remainder = preciseAmount % power
if (remainder == 0L) floor else floor + 1L
}

private def getFractionDigits(
fractionWithoutLeadingZeros: Long,
fractionDigits: Int): List[Int] = {
@tailrec
def loop(remainder: Long, acc: List[Int]): List[Int] = {
val lastDigit = (remainder % 10L).toInt
val newRemainder = remainder / 10L
val newAcc = lastDigit :: acc
if (newRemainder == 0L) newAcc
else loop(newRemainder, newAcc)
}
val digits = loop(fractionWithoutLeadingZeros, List.empty)

if (digits.length < fractionDigits) List.fill(fractionDigits - digits.length)(0) ::: digits
else digits
}

/** @return
* half even rounded (preciseAmount, fractionDigits) to the cent value of the given currency
*/
def roundHalfEven(preciseAmount: Long, fractionDigits: Int, currency: Currency): Long = {
val centFractionDigits = fractionDigits - currency.getDefaultFractionDigits
val power = pow10(centFractionDigits)
val integer = preciseAmount / power
val fraction = preciseAmount % power

// Eg: 3 for 123.456
val leastSignificantDigitOfInt = integer % 10L

val fractionDigitsList = getFractionDigits(fraction, centFractionDigits)

// Eg: 4 for 123.456
val mostSignificantDigitOfFraction :: rest = fractionDigitsList

if (mostSignificantDigitOfFraction == 5 && rest.forall(_ == 0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure where this number 5 is coming from. Can you explain?

Copy link
Contributor Author

@benko-balog benko-balog Oct 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the half even rounding we are interested in if we have exactly .5 in the fractional part. If it's lower or bigger it's down or up rounding, but if it's exactly something.5 we need to check for the integer's least significant digit (that's the next line in the code).
So this line is checking if the first digit after the decimal point is 5 and all the others are non-existent or zeros (to handle the cases of .50, .500, etc )

(Maybe firstDigitAfterDecimalPoint would be a better name for mostSignificantDigitOfFraction) 🤔

if (leastSignificantDigitOfInt % 2L == 0L) integer else integer + 1L
else if (mostSignificantDigitOfFraction >= 5)
integer + 1L
else if (mostSignificantDigitOfFraction == -5 && rest.forall(_ == 0))
if (leastSignificantDigitOfInt % 2L == 0L) integer else integer - 1L
else if (mostSignificantDigitOfFraction <= -5)
integer - 1L
else
integer
}

}
3 changes: 2 additions & 1 deletion util/src/test/scala/DomainObjectsGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ object DomainObjectsGen {

val highPrecisionMoney: Gen[HighPrecisionMoney] = for {
money <- money
} yield HighPrecisionMoney.fromMoney(money, money.currency.getDefaultFractionDigits)
fractionDigits <- Gen.oneOf(money.currency.getDefaultFractionDigits to 10)
} yield HighPrecisionMoney.fromMoney(money, fractionDigits)

val baseMoney: Gen[BaseMoney] = Gen.oneOf(money, highPrecisionMoney)

Expand Down
78 changes: 78 additions & 0 deletions util/src/test/scala/MoneyRoundingSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package io.sphere.util

import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.must.Matchers
import org.scalatest.prop.TableDrivenPropertyChecks
import org.scalatest.prop.TableDrivenPropertyChecks.Table
import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks

import java.util.Currency
import scala.math.BigDecimal.RoundingMode

class MoneyRoundingSpec extends AnyFunSpec with Matchers with ScalaCheckDrivenPropertyChecks {
val Euro: Currency = Currency.getInstance("EUR")
val ZWL: Currency = Currency.getInstance("ZWL")
val JPY: Currency = Currency.getInstance("JPY")

describe("Money Rounding") {
it("roundFloor should behave similarly to BigDecimal rounding") {
ScalaCheckDrivenPropertyChecks.forAll(DomainObjectsGen.highPrecisionMoney) { h =>
val bdRes = HighPrecisionMoney.roundToCents(h.amount, h.currency)(RoundingMode.FLOOR)
val longRes =
MoneyRounding.roundFloor(h.preciseAmount, h.fractionDigits, h.currency)
bdRes must be(longRes)
}
}

it("roundCeiling should behave similarly to BigDecimal rounding") {
ScalaCheckDrivenPropertyChecks.forAll(DomainObjectsGen.highPrecisionMoney) { h =>
benko-balog marked this conversation as resolved.
Show resolved Hide resolved
val bdRes = HighPrecisionMoney.roundToCents(h.amount, h.currency)(RoundingMode.CEILING)
val longRes =
MoneyRounding.roundCeiling(h.preciseAmount, h.fractionDigits, h.currency)
bdRes must be(longRes)
}
}

it("roundHalfEven should behave similarly to BigDecimal rounding") {
benko-balog marked this conversation as resolved.
Show resolved Hide resolved
// I used random generated values later, but I needed these very specific values too to check the
// edge cases of the half even rounding
val data = Table(
("preciseAmount", "fraction", "currency"),
(1119L, 3, Euro),
(1111L, 3, Euro),
(1115L, 3, Euro),
(1125L, 3, Euro),
(112500L, 5, Euro),
(11250001L, 7, Euro),
(11000004L, 7, Euro),
(11249999L, 7, Euro),
(-1119L, 3, Euro),
(-1111L, 3, Euro),
(-1115L, 3, Euro),
(-1125L, 3, Euro),
(-112500L, 5, Euro),
(5721482481806080960L, 6, ZWL),
(123L, 0, JPY)
)

TableDrivenPropertyChecks.forAll(data) { (preciseAmount, fd, cur) =>
val amount = HighPrecisionMoney.preciseAmountToAmount(preciseAmount, fd)
val bdRes = HighPrecisionMoney.roundToCents(amount, cur)(RoundingMode.HALF_EVEN)

val longRes =
MoneyRounding.roundHalfEven(preciseAmount, fd, cur)

bdRes must be(longRes)
}

ScalaCheckDrivenPropertyChecks.forAll(DomainObjectsGen.highPrecisionMoney) { h =>
val bdRes = HighPrecisionMoney.roundToCents(h.amount, h.currency)(RoundingMode.HALF_EVEN)

val longRes =
MoneyRounding.roundHalfEven(h.preciseAmount, h.fractionDigits, h.currency)

bdRes must be(longRes)
}
}
}
}