From e7228a50d5471fe02c39369036cbfd455985e811 Mon Sep 17 00:00:00 2001 From: spacebear Date: Wed, 30 Oct 2024 14:31:38 -0400 Subject: [PATCH] Run additional validation on InputPair::new Validate address type and, if it's P2SH, ensures there is a redeem_script on the psbtin. --- payjoin/src/psbt.rs | 58 ++++++++++++++++++++------------------ payjoin/src/receive/mod.rs | 17 +++++------ 2 files changed, 38 insertions(+), 37 deletions(-) diff --git a/payjoin/src/psbt.rs b/payjoin/src/psbt.rs index 8f178fd8..9821fb69 100644 --- a/payjoin/src/psbt.rs +++ b/payjoin/src/psbt.rs @@ -108,39 +108,39 @@ const NESTED_P2WPKH_MAX: InputWeightPrediction = InputWeightPrediction::from_sli #[derive(Clone, Debug)] pub struct InputPair { - txin: TxIn, - psbtin: psbt::Input, + pub(crate) txin: TxIn, + pub(crate) psbtin: psbt::Input, } impl InputPair { pub fn new(txin: TxIn, psbtin: psbt::Input) -> Result { let input_pair = Self { txin, psbtin }; - // TODO validate and document Input details required for Input Contribution fee estimation - // TODO Validate AddressType will return valid AddressType or an error // TODO consider whether or not this should live in receive module since it's a baby of that state machine - InternalInputPair::from(&input_pair).validate_utxo(true)?; + let raw = InternalInputPair::from(&input_pair); + raw.validate_utxo(true)?; + let address_type = raw.address_type()?; + if address_type == AddressType::P2sh && input_pair.psbtin.redeem_script.is_none() { + return Err(PsbtInputError::NoRedeemScript); + } Ok(input_pair) } - pub(crate) fn txin(&self) -> &TxIn { &self.txin } - - pub(crate) fn psbtin(&self) -> &psbt::Input { &self.psbtin } - - pub(crate) fn address_type(&self) -> Result { - let raw = InternalInputPair { txin: &self.txin, psbtin: &self.psbtin }; - raw.address_type() + pub(crate) fn address_type(&self) -> AddressType { + InternalInputPair::from(self) + .address_type() + .expect("address type should have been validated in InputPair::new") } pub(crate) fn previous_txout(&self) -> TxOut { InternalInputPair::from(self) .previous_txout() - .expect("missing UTXO information should have been validated in InputPair::new") + .expect("UTXO information should have been validated in InputPair::new") .clone() } } #[derive(Clone, Debug)] -pub struct InternalInputPair<'a> { +pub(crate) struct InternalInputPair<'a> { pub txin: &'a TxIn, pub psbtin: &'a psbt::Input, } @@ -149,12 +149,6 @@ impl<'a> From<&'a InputPair> for InternalInputPair<'a> { fn from(pair: &'a InputPair) -> Self { Self { psbtin: &pair.psbtin, txin: &pair.txin } } } -impl<'a> From<&InternalInputPair<'a>> for InputPair { - fn from(internal: &InternalInputPair<'a>) -> Self { - InputPair { txin: internal.txin.clone(), psbtin: internal.psbtin.clone() } - } -} - impl<'a> InternalInputPair<'a> { /// Returns TxOut associated with the input pub fn previous_txout(&self) -> Result<&TxOut, PrevTxOutError> { @@ -294,14 +288,18 @@ pub enum PsbtInputError { UnequalTxid, /// TxOut provided in `segwit_utxo` doesn't match the one in `non_segwit_utxo` SegWitTxOutMismatch, + AddressType(AddressTypeError), + NoRedeemScript, } impl fmt::Display for PsbtInputError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - PsbtInputError::PrevTxOut(_) => write!(f, "invalid previous transaction output"), - PsbtInputError::UnequalTxid => write!(f, "transaction ID of previous transaction doesn't match one specified in input spending it"), - PsbtInputError::SegWitTxOutMismatch => write!(f, "transaction output provided in SegWit UTXO field doesn't match the one in non-SegWit UTXO field"), + Self::PrevTxOut(_) => write!(f, "invalid previous transaction output"), + Self::UnequalTxid => write!(f, "transaction ID of previous transaction doesn't match one specified in input spending it"), + Self::SegWitTxOutMismatch => write!(f, "transaction output provided in SegWit UTXO field doesn't match the one in non-SegWit UTXO field"), + Self::AddressType(_) => write!(f, "invalid address type"), + Self::NoRedeemScript => write!(f, "provided p2sh PSBT input is missing a redeem_script"), } } } @@ -309,9 +307,11 @@ impl fmt::Display for PsbtInputError { impl std::error::Error for PsbtInputError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { - PsbtInputError::PrevTxOut(error) => Some(error), - PsbtInputError::UnequalTxid => None, - PsbtInputError::SegWitTxOutMismatch => None, + Self::PrevTxOut(error) => Some(error), + Self::UnequalTxid => None, + Self::SegWitTxOutMismatch => None, + Self::AddressType(error) => Some(error), + Self::NoRedeemScript => None, } } } @@ -320,6 +320,10 @@ impl From for PsbtInputError { fn from(value: PrevTxOutError) -> Self { PsbtInputError::PrevTxOut(value) } } +impl From for PsbtInputError { + fn from(value: AddressTypeError) -> Self { Self::AddressType(value) } +} + #[derive(Debug)] pub struct PsbtInputsError { index: usize, @@ -337,7 +341,7 @@ impl std::error::Error for PsbtInputsError { } #[derive(Debug)] -pub(crate) enum AddressTypeError { +pub enum AddressTypeError { PrevTxOut(PrevTxOutError), InvalidScript(FromScriptError), UnknownAddressType, diff --git a/payjoin/src/receive/mod.rs b/payjoin/src/receive/mod.rs index 21b3dbf6..90febead 100644 --- a/payjoin/src/receive/mod.rs +++ b/payjoin/src/receive/mod.rs @@ -553,9 +553,7 @@ impl WantsInputs { let mut rng = rand::thread_rng(); let mut receiver_input_amount = Amount::ZERO; for input_pair in inputs.into_iter() { - let input_type = - input_pair.address_type().map_err(InternalInputContributionError::AddressType)?; - + let input_type = input_pair.address_type(); if self.params.v == 1 { // v1 payjoin proposals must not introduce mixed input script types self.check_mixed_input_types(input_type, uniform_sender_input_type)?; @@ -563,11 +561,11 @@ impl WantsInputs { receiver_input_amount += input_pair.previous_txout().value; let index = rng.gen_range(0..=self.payjoin_psbt.unsigned_tx.input.len()); - payjoin_psbt.inputs.insert(index, input_pair.psbtin().clone()); + payjoin_psbt.inputs.insert(index, input_pair.psbtin); payjoin_psbt .unsigned_tx .input - .insert(index, TxIn { sequence: original_sequence, ..input_pair.txin().clone() }); + .insert(index, TxIn { sequence: original_sequence, ..input_pair.txin }); } // Add the receiver change amount to the receiver change output, if applicable @@ -964,11 +962,10 @@ mod test { proposal.params.min_feerate = FeeRate::from_sat_per_vb_unchecked(1000); // Input contribution for the receiver, from the BIP78 test vector let proposal_psbt = Psbt::from_str("cHNidP8BAJwCAAAAAo8nutGgJdyYGXWiBEb45Hoe9lWGbkxh/6bNiOJdCDuDAAAAAAD+////jye60aAl3JgZdaIERvjkeh72VYZuTGH/ps2I4l0IO4MBAAAAAP7///8CJpW4BQAAAAAXqRQd6EnwadJ0FQ46/q6NcutaawlEMIcACT0AAAAAABepFHdAltvPSGdDwi9DR+m0af6+i2d6h9MAAAAAAAEBIICEHgAAAAAAF6kUyPLL+cphRyyI5GTUazV0hF2R2NWHAQcXFgAUX4BmVeWSTJIEwtUb5TlPS/ntohABCGsCRzBEAiBnu3tA3yWlT0WBClsXXS9j69Bt+waCs9JcjWtNjtv7VgIge2VYAaBeLPDB6HGFlpqOENXMldsJezF9Gs5amvDQRDQBIQJl1jz1tBt8hNx2owTm+4Du4isx0pmdKNMNIjjaMHFfrQAAAA==").unwrap(); - let input = InputPair::new( - proposal_psbt.unsigned_tx.input[1].clone(), - proposal_psbt.inputs[1].clone(), - ) - .expect("Input pair should be valid"); + let input = InputPair { + txin: proposal_psbt.unsigned_tx.input[1].clone(), + psbtin: proposal_psbt.inputs[1].clone(), + }; let mut payjoin = proposal .assume_interactive_receiver() .check_inputs_not_owned(|_| Ok(false))