diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 735b1039..e3d3ed34 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -285,7 +285,7 @@ impl Sender { ohttp_relay: Url, ) -> Result<(Request, V2PostContext), CreateRequestError> { use crate::uri::UrlExt; - if let Some(expiry) = self.endpoint.exp() { + if let Ok(expiry) = self.endpoint.exp() { if std::time::SystemTime::now() > expiry { return Err(InternalCreateRequestError::Expired(expiry).into()); } diff --git a/payjoin/src/uri/error.rs b/payjoin/src/uri/error.rs index 34a05481..84794d49 100644 --- a/payjoin/src/uri/error.rs +++ b/payjoin/src/uri/error.rs @@ -18,6 +18,29 @@ pub(crate) enum ParseOhttpKeysParamError { InvalidOhttpKeys(crate::ohttp::ParseOhttpKeysError), } +#[cfg(feature = "v2")] +#[derive(Debug)] +pub(crate) enum ParseExpParamError { + MissingExp, + InvalidHrp(bitcoin::bech32::Hrp), + DecodeBech32(bitcoin::bech32::primitives::decode::CheckedHrpstringError), + InvalidExp(bitcoin::consensus::encode::Error), +} + +#[cfg(feature = "v2")] +impl std::fmt::Display for ParseExpParamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use ParseExpParamError::*; + + match &self { + MissingExp => write!(f, "exp is missing"), + InvalidHrp(h) => write!(f, "incorrect hrp for exp: {}", h), + DecodeBech32(d) => write!(f, "exp is not valid bech32: {}", d), + InvalidExp(i) => write!(f, "invalid exp: {}", i), + } + } +} + #[cfg(feature = "v2")] impl std::fmt::Display for ParseOhttpKeysParamError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/payjoin/src/uri/url_ext.rs b/payjoin/src/uri/url_ext.rs index 8ad080a9..2e3d255e 100644 --- a/payjoin/src/uri/url_ext.rs +++ b/payjoin/src/uri/url_ext.rs @@ -5,7 +5,7 @@ use bitcoin::consensus::encode::Decodable; use bitcoin::consensus::Encodable; use url::Url; -use super::error::{ParseOhttpKeysParamError, ParseReceiverPubkeyParamError}; +use super::error::{ParseExpParamError, ParseOhttpKeysParamError, ParseReceiverPubkeyParamError}; use crate::hpke::HpkePublicKey; use crate::ohttp::OhttpKeys; @@ -15,7 +15,7 @@ pub(crate) trait UrlExt { fn set_receiver_pubkey(&mut self, exp: HpkePublicKey); fn ohttp(&self) -> Result; fn set_ohttp(&mut self, ohttp: OhttpKeys); - fn exp(&self) -> Option; + fn exp(&self) -> Result; fn set_exp(&mut self, exp: std::time::SystemTime); } @@ -60,22 +60,24 @@ impl UrlExt for Url { fn set_ohttp(&mut self, ohttp: OhttpKeys) { set_param(self, "OH1", &ohttp.to_string()) } /// Retrieve the exp parameter from the URL fragment - fn exp(&self) -> Option { - get_param(self, "EX1", |value| { - let (hrp, bytes) = crate::bech32::nochecksum::decode(value).ok()?; + fn exp(&self) -> Result { + let value = + get_param(self, "EX1", |v| Some(v.to_owned())).ok_or(ParseExpParamError::MissingExp)?; - let ex_hrp: Hrp = Hrp::parse("EX").unwrap(); - if hrp != ex_hrp { - return None; - } + let (hrp, bytes) = + crate::bech32::nochecksum::decode(&value).map_err(ParseExpParamError::DecodeBech32)?; + + let ex_hrp: Hrp = Hrp::parse("EX").unwrap(); + if hrp != ex_hrp { + return Err(ParseExpParamError::InvalidHrp(hrp)); + } - let mut cursor = &bytes[..]; - u32::consensus_decode(&mut cursor) - .map(|timestamp| { - std::time::UNIX_EPOCH + std::time::Duration::from_secs(timestamp as u64) - }) - .ok() - }) + let mut cursor = &bytes[..]; + u32::consensus_decode(&mut cursor) + .map(|timestamp| { + std::time::UNIX_EPOCH + std::time::Duration::from_secs(timestamp as u64) + }) + .map_err(ParseExpParamError::InvalidExp) } /// Set the exp parameter in the URL fragment @@ -173,7 +175,26 @@ mod tests { url.set_exp(exp_time); assert_eq!(url.fragment(), Some("EX1C4UC6ES")); - assert_eq!(url.exp(), Some(exp_time)); + assert_eq!(url.exp().unwrap(), exp_time); + } + + #[test] + fn test_errors_when_parsing_exp() { + let missing_exp_url = Url::parse("http://example.com").unwrap(); + assert!(matches!(missing_exp_url.exp(), Err(ParseExpParamError::MissingExp))); + + let invalid_bech32_exp_url = + Url::parse("http://example.com?pj=https://test-payjoin-url#EX1invalid_bech_32") + .unwrap(); + assert!(matches!(invalid_bech32_exp_url.exp(), Err(ParseExpParamError::DecodeBech32(_)))); + + let invalid_hrp_exp_url = + Url::parse("http://example.com?pj=https://test-payjoin-url#EX1010").unwrap(); + assert!(matches!(invalid_hrp_exp_url.exp(), Err(ParseExpParamError::InvalidHrp(_)))); + + let invalid_timestamp_exp_url = + Url::parse("http://example.com?pj=https://test-payjoin-url#EX10").unwrap(); + assert!(matches!(invalid_timestamp_exp_url.exp(), Err(ParseExpParamError::InvalidExp(_)))) } #[test]