Skip to content

Commit

Permalink
Run additional validation on InputPair::new
Browse files Browse the repository at this point in the history
Validate address type and, if it's P2SH, ensures there is a
redeem_script on the psbtin.
  • Loading branch information
spacebear21 committed Oct 30, 2024
1 parent 21d0cd6 commit e7228a5
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 37 deletions.
58 changes: 31 additions & 27 deletions payjoin/src/psbt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,39 +108,39 @@ const NESTED_P2WPKH_MAX: InputWeightPrediction = InputWeightPrediction::from_sli

#[derive(Clone, Debug)]
pub struct InputPair {
txin: TxIn,
psbtin: psbt::Input,
pub(crate) txin: TxIn,
pub(crate) psbtin: psbt::Input,
}

impl InputPair {
pub fn new(txin: TxIn, psbtin: psbt::Input) -> Result<Self, PsbtInputError> {
let input_pair = Self { txin, psbtin };
// TODO validate and document Input details required for Input Contribution fee estimation
// TODO Validate AddressType will return valid AddressType or an error
// TODO consider whether or not this should live in receive module since it's a baby of that state machine
InternalInputPair::from(&input_pair).validate_utxo(true)?;
let raw = InternalInputPair::from(&input_pair);
raw.validate_utxo(true)?;
let address_type = raw.address_type()?;
if address_type == AddressType::P2sh && input_pair.psbtin.redeem_script.is_none() {
return Err(PsbtInputError::NoRedeemScript);
}
Ok(input_pair)
}

pub(crate) fn txin(&self) -> &TxIn { &self.txin }

pub(crate) fn psbtin(&self) -> &psbt::Input { &self.psbtin }

pub(crate) fn address_type(&self) -> Result<AddressType, AddressTypeError> {
let raw = InternalInputPair { txin: &self.txin, psbtin: &self.psbtin };
raw.address_type()
pub(crate) fn address_type(&self) -> AddressType {
InternalInputPair::from(self)
.address_type()
.expect("address type should have been validated in InputPair::new")
}

pub(crate) fn previous_txout(&self) -> TxOut {
InternalInputPair::from(self)
.previous_txout()
.expect("missing UTXO information should have been validated in InputPair::new")
.expect("UTXO information should have been validated in InputPair::new")
.clone()
}
}

#[derive(Clone, Debug)]
pub struct InternalInputPair<'a> {
pub(crate) struct InternalInputPair<'a> {
pub txin: &'a TxIn,
pub psbtin: &'a psbt::Input,
}
Expand All @@ -149,12 +149,6 @@ impl<'a> From<&'a InputPair> for InternalInputPair<'a> {
fn from(pair: &'a InputPair) -> Self { Self { psbtin: &pair.psbtin, txin: &pair.txin } }
}

impl<'a> From<&InternalInputPair<'a>> for InputPair {
fn from(internal: &InternalInputPair<'a>) -> Self {
InputPair { txin: internal.txin.clone(), psbtin: internal.psbtin.clone() }
}
}

impl<'a> InternalInputPair<'a> {
/// Returns TxOut associated with the input
pub fn previous_txout(&self) -> Result<&TxOut, PrevTxOutError> {
Expand Down Expand Up @@ -294,24 +288,30 @@ pub enum PsbtInputError {
UnequalTxid,
/// TxOut provided in `segwit_utxo` doesn't match the one in `non_segwit_utxo`
SegWitTxOutMismatch,
AddressType(AddressTypeError),
NoRedeemScript,
}

impl fmt::Display for PsbtInputError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
PsbtInputError::PrevTxOut(_) => write!(f, "invalid previous transaction output"),
PsbtInputError::UnequalTxid => write!(f, "transaction ID of previous transaction doesn't match one specified in input spending it"),
PsbtInputError::SegWitTxOutMismatch => write!(f, "transaction output provided in SegWit UTXO field doesn't match the one in non-SegWit UTXO field"),
Self::PrevTxOut(_) => write!(f, "invalid previous transaction output"),
Self::UnequalTxid => write!(f, "transaction ID of previous transaction doesn't match one specified in input spending it"),
Self::SegWitTxOutMismatch => write!(f, "transaction output provided in SegWit UTXO field doesn't match the one in non-SegWit UTXO field"),
Self::AddressType(_) => write!(f, "invalid address type"),
Self::NoRedeemScript => write!(f, "provided p2sh PSBT input is missing a redeem_script"),
}
}
}

impl std::error::Error for PsbtInputError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
PsbtInputError::PrevTxOut(error) => Some(error),
PsbtInputError::UnequalTxid => None,
PsbtInputError::SegWitTxOutMismatch => None,
Self::PrevTxOut(error) => Some(error),
Self::UnequalTxid => None,
Self::SegWitTxOutMismatch => None,
Self::AddressType(error) => Some(error),
Self::NoRedeemScript => None,
}
}
}
Expand All @@ -320,6 +320,10 @@ impl From<PrevTxOutError> for PsbtInputError {
fn from(value: PrevTxOutError) -> Self { PsbtInputError::PrevTxOut(value) }
}

impl From<AddressTypeError> for PsbtInputError {
fn from(value: AddressTypeError) -> Self { Self::AddressType(value) }
}

#[derive(Debug)]
pub struct PsbtInputsError {
index: usize,
Expand All @@ -337,7 +341,7 @@ impl std::error::Error for PsbtInputsError {
}

#[derive(Debug)]
pub(crate) enum AddressTypeError {
pub enum AddressTypeError {
PrevTxOut(PrevTxOutError),
InvalidScript(FromScriptError),
UnknownAddressType,
Expand Down
17 changes: 7 additions & 10 deletions payjoin/src/receive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,21 +553,19 @@ impl WantsInputs {
let mut rng = rand::thread_rng();
let mut receiver_input_amount = Amount::ZERO;
for input_pair in inputs.into_iter() {
let input_type =
input_pair.address_type().map_err(InternalInputContributionError::AddressType)?;

let input_type = input_pair.address_type();
if self.params.v == 1 {
// v1 payjoin proposals must not introduce mixed input script types
self.check_mixed_input_types(input_type, uniform_sender_input_type)?;
}

receiver_input_amount += input_pair.previous_txout().value;
let index = rng.gen_range(0..=self.payjoin_psbt.unsigned_tx.input.len());
payjoin_psbt.inputs.insert(index, input_pair.psbtin().clone());
payjoin_psbt.inputs.insert(index, input_pair.psbtin);
payjoin_psbt
.unsigned_tx
.input
.insert(index, TxIn { sequence: original_sequence, ..input_pair.txin().clone() });
.insert(index, TxIn { sequence: original_sequence, ..input_pair.txin });
}

// Add the receiver change amount to the receiver change output, if applicable
Expand Down Expand Up @@ -964,11 +962,10 @@ mod test {
proposal.params.min_feerate = FeeRate::from_sat_per_vb_unchecked(1000);
// Input contribution for the receiver, from the BIP78 test vector
let proposal_psbt = Psbt::from_str("cHNidP8BAJwCAAAAAo8nutGgJdyYGXWiBEb45Hoe9lWGbkxh/6bNiOJdCDuDAAAAAAD+////jye60aAl3JgZdaIERvjkeh72VYZuTGH/ps2I4l0IO4MBAAAAAP7///8CJpW4BQAAAAAXqRQd6EnwadJ0FQ46/q6NcutaawlEMIcACT0AAAAAABepFHdAltvPSGdDwi9DR+m0af6+i2d6h9MAAAAAAAEBIICEHgAAAAAAF6kUyPLL+cphRyyI5GTUazV0hF2R2NWHAQcXFgAUX4BmVeWSTJIEwtUb5TlPS/ntohABCGsCRzBEAiBnu3tA3yWlT0WBClsXXS9j69Bt+waCs9JcjWtNjtv7VgIge2VYAaBeLPDB6HGFlpqOENXMldsJezF9Gs5amvDQRDQBIQJl1jz1tBt8hNx2owTm+4Du4isx0pmdKNMNIjjaMHFfrQAAAA==").unwrap();
let input = InputPair::new(
proposal_psbt.unsigned_tx.input[1].clone(),
proposal_psbt.inputs[1].clone(),
)
.expect("Input pair should be valid");
let input = InputPair {
txin: proposal_psbt.unsigned_tx.input[1].clone(),
psbtin: proposal_psbt.inputs[1].clone(),
};
let mut payjoin = proposal
.assume_interactive_receiver()
.check_inputs_not_owned(|_| Ok(false))
Expand Down

0 comments on commit e7228a5

Please sign in to comment.