Skip to content

Commit

Permalink
Introduce monadic AmountOpResult
Browse files Browse the repository at this point in the history
We would like to return an error when doing match ops on amount types.
We cannot however use the stdlib `Result` or `Option` because we want to
implement ops on the result type.

Add an `AmountOpResult` type. Return this type from all math operations
on `Amount` and `SignedAmount`.

Implement `core::iter::Sum` for the new type to allow summing iterators
of amounts - somewhat ugly to use, see tests for example usage.
  • Loading branch information
tcharding committed Feb 5, 2025
1 parent bd27670 commit 90271c6
Show file tree
Hide file tree
Showing 10 changed files with 516 additions and 201 deletions.
6 changes: 3 additions & 3 deletions bitcoin/examples/taproot-psbt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ impl BenefactorWallet {
taproot_spend_info.internal_key(),
taproot_spend_info.merkle_root(),
);
let value = input_utxo.amount - ABSOLUTE_FEES;
let value = (input_utxo.amount - ABSOLUTE_FEES).unwrap_valid();

// Spend a normal BIP86-like output as an input in our inheritance funding transaction
let tx = generate_bip86_key_spend_tx(
Expand Down Expand Up @@ -476,7 +476,7 @@ impl BenefactorWallet {
let mut psbt = self.next_psbt.clone().expect("should have next_psbt");
let input = &mut psbt.inputs[0];
let input_value = input.witness_utxo.as_ref().unwrap().value;
let output_value = input_value - ABSOLUTE_FEES;
let output_value = (input_value - ABSOLUTE_FEES).into_result()?;

// We use some other derivation path in this example for our inheritance protocol. The important thing is to ensure
// that we use an unhardened path so we can make use of xpubs.
Expand Down Expand Up @@ -649,7 +649,7 @@ impl BeneficiaryWallet {
psbt.unsigned_tx.lock_time = lock_time;
psbt.unsigned_tx.output = vec![TxOut {
script_pubkey: to_address.script_pubkey(),
value: input_value - ABSOLUTE_FEES,
value: (input_value - ABSOLUTE_FEES).unwrap_valid(),
}];
psbt.outputs = vec![Output::default()];
let unsigned_tx = psbt.unsigned_tx.clone();
Expand Down
2 changes: 1 addition & 1 deletion bitcoin/src/blockdata/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1673,7 +1673,7 @@ mod tests {

// 10 sat/kwu * (204wu + BASE_WEIGHT) = 4 sats
let expected_fee = "4 sats".parse::<SignedAmount>().unwrap();
let expected_effective_value = value.to_signed() - expected_fee;
let expected_effective_value = (value.to_signed() - expected_fee).unwrap_valid();
assert_eq!(effective_value, expected_effective_value);
}

Expand Down
2 changes: 1 addition & 1 deletion bitcoin/src/psbt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2249,7 +2249,7 @@ mod tests {
};
assert_eq!(
t.fee().expect("fee calculation"),
prev_output_val - (output_0_val + output_1_val)
(prev_output_val - (output_0_val + output_1_val).unwrap_valid()).unwrap_valid()
);
// no previous output
let mut t2 = t.clone();
Expand Down
3 changes: 3 additions & 0 deletions units/src/amount/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//! We refer to the documentation on the types for more information.
mod error;
mod result;
#[cfg(feature = "serde")]
pub mod serde;

Expand Down Expand Up @@ -37,6 +38,8 @@ pub use self::{
signed::SignedAmount,
unsigned::Amount,
};
pub(in crate::amount) use self::result::OptionExt;
pub(crate) use self::result::{AmountOpError, AmountOpResult};

/// A set of denominations in which amounts can be expressed.
///
Expand Down
322 changes: 322 additions & 0 deletions units/src/amount/result.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
// SPADIX-License-Identifier: CC0-1.0

//! Provides a monodic result type that is used to return the result of
//! doing mathematical operations (`core::ops`) on amount types.
use core::{fmt, ops};

use super::{Amount, SignedAmount};

/// Result of an operation on [`Amount`] or [`SignedAmount`].
///
/// The type parameter `T` should be normally `Amout` or `SignedAmount`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AmountOpResult<T> {
/// Result of a successful mathematical operation.
Valid(T),
/// Result of an unsuccessful mathematical operation.
Error(AmountOpError),
}

impl<T> AmountOpResult<T> {
/// Returns the contained valid amount, consuming `self`.
///
/// # Panics
///
/// Panics if the result is an `Error`.
pub fn unwrap_valid(self) -> T {
use AmountOpResult as R;

match self {
R::Valid(amount) => amount,
R::Error(_) => panic!("tried to unwrap an invalid result"),
}
}

/// Returns the contained error, consuming `self`.
///
/// # Panics
///
/// Panics if the result is a valid amount.
pub fn unwrap_error(self) -> AmountOpError {
use AmountOpResult as R;

match self {
R::Error(e) => e,
R::Valid(_) => panic!("tried to unwrap a valid result"),
}
}

/// Converts this `AmountOpResult` to a `Result<T, AmountOpError>`.
pub fn into_result(self) -> Result<T, AmountOpError> {
use AmountOpResult as R;

match self {
R::Valid(amount) => Ok(amount),
R::Error(e) => Err(e),
}
}

/// Returns `true` if this result is a valid amount.
pub fn is_valid(&self) -> bool {
use AmountOpResult as R;

match self {
R::Valid(_) => true,
R::Error(_) => false,
}
}

/// Returns `true` if this result is an invalid amount.
pub fn is_error(&self) -> bool { !self.is_valid() }
}

impl From<Amount> for AmountOpResult<Amount> {
fn from(a: Amount) -> Self { Self::Valid(a) }
}

impl From<SignedAmount> for AmountOpResult<SignedAmount> {
fn from(a: SignedAmount) -> Self { Self::Valid(a) }
}

impl<T> ops::Add for AmountOpResult<T>
where
T: ops::Add<Output = AmountOpResult<T>>,
{
type Output = AmountOpResult<T>;

fn add(self, rhs: Self) -> Self::Output {
use AmountOpResult as R;

match (self, rhs) {
(R::Valid(lhs), R::Valid(rhs)) => lhs + rhs,
(_, _) => R::Error(AmountOpError {}),
}
}
}
impl<T> ops::Add<AmountOpResult<T>> for &AmountOpResult<T>
where
T: ops::Add<Output = AmountOpResult<T>> + Copy,
{
type Output = AmountOpResult<T>;

fn add(self, rhs: AmountOpResult<T>) -> Self::Output {
use AmountOpResult as R;

match (self, rhs) {
(R::Valid(lhs), R::Valid(rhs)) => *lhs + rhs,
(_, _) => R::Error(AmountOpError {}),
}
}
}
impl<T> ops::Add<&AmountOpResult<T>> for AmountOpResult<T>
where
T: ops::Add<Output = AmountOpResult<T>> + Copy,
{
type Output = AmountOpResult<T>;

fn add(self, rhs: &AmountOpResult<T>) -> Self::Output {
use AmountOpResult as R;

match (self, rhs) {
(R::Valid(lhs), R::Valid(rhs)) => lhs + *rhs,
(_, _) => R::Error(AmountOpError {}),
}
}
}
impl<T> ops::Add for &AmountOpResult<T>
where
T: ops::Add<Output = AmountOpResult<T>> + Copy,
{
type Output = AmountOpResult<T>;

fn add(self, rhs: &AmountOpResult<T>) -> Self::Output {
use AmountOpResult as R;

match (self, rhs) {
(R::Valid(lhs), R::Valid(rhs)) => *lhs + *rhs,
(_, _) => R::Error(AmountOpError {}),
}
}
}

impl<T> ops::Sub for AmountOpResult<T>
where
T: ops::Sub<Output = AmountOpResult<T>>,
{
type Output = AmountOpResult<T>;

fn sub(self, rhs: Self) -> Self::Output {
use AmountOpResult as R;

match (self, rhs) {
(R::Valid(lhs), R::Valid(rhs)) => lhs - rhs,
(_, _) => R::Error(AmountOpError {}),
}
}
}
impl<T> ops::Sub<AmountOpResult<T>> for &AmountOpResult<T>
where
T: ops::Sub<Output = AmountOpResult<T>> + Copy,
{
type Output = AmountOpResult<T>;

fn sub(self, rhs: AmountOpResult<T>) -> Self::Output {
use AmountOpResult as R;

match (self, rhs) {
(R::Valid(lhs), R::Valid(rhs)) => *lhs - rhs,
(_, _) => R::Error(AmountOpError {}),
}
}
}
impl<T> ops::Sub<&AmountOpResult<T>> for AmountOpResult<T>
where
T: ops::Sub<Output = AmountOpResult<T>> + Copy,
{
type Output = AmountOpResult<T>;

fn sub(self, rhs: &AmountOpResult<T>) -> Self::Output {
use AmountOpResult as R;

match (self, rhs) {
(R::Valid(lhs), R::Valid(rhs)) => lhs - *rhs,
(_, _) => R::Error(AmountOpError {}),
}
}
}
impl<T> ops::Sub for &AmountOpResult<T>
where
T: ops::Sub<Output = AmountOpResult<T>> + Copy,
{
type Output = AmountOpResult<T>;

fn sub(self, rhs: Self) -> Self::Output {
use AmountOpResult as R;

match (self, rhs) {
(R::Valid(lhs), R::Valid(rhs)) => *lhs - *rhs,
(_, _) => R::Error(AmountOpError {}),
}
}
}

impl core::iter::Sum<AmountOpResult<Amount>> for AmountOpResult<Amount> {
fn sum<I>(iter: I) -> Self
where
I: Iterator<Item = AmountOpResult<Amount>>,
{
use AmountOpResult as R;

iter.fold(R::Valid(Amount::ZERO), |acc, amount| match (acc, amount) {
(R::Valid(lhs), R::Valid(rhs)) => lhs + rhs,
(_, _) => R::Error(AmountOpError {}),
})
}
}
impl<'a> core::iter::Sum<AmountOpResult<&'a Amount>> for AmountOpResult<Amount> {
fn sum<I>(iter: I) -> Self
where
I: Iterator<Item = AmountOpResult<&'a Amount>>,
{
use AmountOpResult as R;

iter.fold(R::Valid(Amount::ZERO), |acc, amount| match (acc, amount) {
(R::Valid(lhs), R::Valid(rhs)) => lhs + rhs,
(_, _) => R::Error(AmountOpError {}),
})
}
}
impl<'a> core::iter::Sum<&'a AmountOpResult<Amount>> for AmountOpResult<Amount> {
fn sum<I>(iter: I) -> Self
where
I: Iterator<Item = &'a AmountOpResult<Amount>>,
{
use AmountOpResult as R;

iter.fold(R::Valid(Amount::ZERO), |acc, amount| match (acc, amount) {
(R::Valid(lhs), R::Valid(rhs)) => lhs + rhs,
(_, _) => R::Error(AmountOpError {}),
})
}
}

impl core::iter::Sum<SignedAmount> for AmountOpResult<SignedAmount> {
fn sum<I>(iter: I) -> AmountOpResult<SignedAmount>
where
I: Iterator<Item = SignedAmount>,
{
use AmountOpResult as R;

iter.fold(R::Valid(SignedAmount::ZERO), |acc, amount| match acc {
R::Valid(lhs) => lhs + amount,
R::Error(e) => R::Error(e),
})
}
}
impl<'a> core::iter::Sum<AmountOpResult<&'a SignedAmount>> for AmountOpResult<SignedAmount> {
fn sum<I>(iter: I) -> Self
where
I: Iterator<Item = AmountOpResult<&'a SignedAmount>>,
{
use AmountOpResult as R;

iter.fold(R::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) {
(R::Valid(lhs), R::Valid(rhs)) => lhs + rhs,
(_, _) => R::Error(AmountOpError {}),
})
}
}
impl<'a> core::iter::Sum<&'a AmountOpResult<SignedAmount>> for AmountOpResult<SignedAmount> {
fn sum<I>(iter: I) -> Self
where
I: Iterator<Item = &'a AmountOpResult<SignedAmount>>,
{
use AmountOpResult as R;

iter.fold(R::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) {
(R::Valid(lhs), R::Valid(rhs)) => lhs + rhs,
(_, _) => R::Error(AmountOpError {}),
})
}
}

pub(in crate::amount) trait OptionExt<T> {
fn unwrap_or_amount_op_error(self) -> AmountOpResult<T>;
}

impl OptionExt<Amount> for Option<Amount> {
fn unwrap_or_amount_op_error(self) -> AmountOpResult<Amount> {
use AmountOpResult as R;

match self {
Some(amount) => R::Valid(amount),
None => R::Error(AmountOpError {}),
}
}
}

impl OptionExt<SignedAmount> for Option<SignedAmount> {
fn unwrap_or_amount_op_error(self) -> AmountOpResult<SignedAmount> {
use AmountOpResult as R;

match self {
Some(amount) => R::Valid(amount),
None => R::Error(AmountOpError {}),
}
}
}

/// An error occurred while doing a mathematical operation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct AmountOpError;

impl fmt::Display for AmountOpError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "a math operation on amounts gave an invalid result")
}
}

#[cfg(feature = "std")]
impl std::error::Error for AmountOpError {}
Loading

0 comments on commit 90271c6

Please sign in to comment.