diff --git a/payjoin-directory/src/db.rs b/payjoin-directory/src/db.rs index 41b9b6dc..6165abf9 100644 --- a/payjoin-directory/src/db.rs +++ b/payjoin-directory/src/db.rs @@ -7,19 +7,6 @@ use tracing::debug; const DEFAULT_COLUMN: &str = ""; const PJ_V1_COLUMN: &str = "pjv1"; -// TODO move to payjoin crate as pub? -// TODO impl From for ShortId -// TODO impl Display for ShortId (Base64) -// TODO impl TryFrom<&str> for ShortId (Base64) -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) struct ShortId(pub [u8; 8]); - -impl ShortId { - pub fn column_key(&self, column: &str) -> Vec { - self.0.iter().chain(column.as_bytes()).copied().collect() - } -} - #[derive(Debug, Clone)] pub(crate) struct DbPool { client: Client, @@ -32,30 +19,30 @@ impl DbPool { Ok(Self { client, timeout }) } - pub async fn push_default(&self, pubkey_id: &ShortId, data: Vec) -> RedisResult<()> { - self.push(pubkey_id, DEFAULT_COLUMN, data).await + pub async fn push_default(&self, subdirectory_id: &str, data: Vec) -> RedisResult<()> { + self.push(subdirectory_id, DEFAULT_COLUMN, data).await } - pub async fn peek_default(&self, pubkey_id: &ShortId) -> Option>> { - self.peek_with_timeout(pubkey_id, DEFAULT_COLUMN).await + pub async fn peek_default(&self, subdirectory_id: &str) -> Option>> { + self.peek_with_timeout(subdirectory_id, DEFAULT_COLUMN).await } - pub async fn push_v1(&self, pubkey_id: &ShortId, data: Vec) -> RedisResult<()> { - self.push(pubkey_id, PJ_V1_COLUMN, data).await + pub async fn push_v1(&self, subdirectory_id: &str, data: Vec) -> RedisResult<()> { + self.push(subdirectory_id, PJ_V1_COLUMN, data).await } - pub async fn peek_v1(&self, pubkey_id: &ShortId) -> Option>> { - self.peek_with_timeout(pubkey_id, PJ_V1_COLUMN).await + pub async fn peek_v1(&self, subdirectory_id: &str) -> Option>> { + self.peek_with_timeout(subdirectory_id, PJ_V1_COLUMN).await } async fn push( &self, - pubkey_id: &ShortId, + subdirectory_id: &str, channel_type: &str, data: Vec, ) -> RedisResult<()> { let mut conn = self.client.get_async_connection().await?; - let key = pubkey_id.column_key(channel_type); + let key = channel_name(subdirectory_id, channel_type); () = conn.set(&key, data.clone()).await?; () = conn.publish(&key, "updated").await?; Ok(()) @@ -63,17 +50,17 @@ impl DbPool { async fn peek_with_timeout( &self, - pubkey_id: &ShortId, + subdirectory_id: &str, channel_type: &str, ) -> Option>> { - tokio::time::timeout(self.timeout, self.peek(pubkey_id, channel_type)).await.ok() + tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await.ok() } - async fn peek(&self, pubkey_id: &ShortId, channel_type: &str) -> RedisResult> { + async fn peek(&self, subdirectory_id: &str, channel_type: &str) -> RedisResult> { let mut conn = self.client.get_async_connection().await?; - let key = pubkey_id.column_key(channel_type); + let key = channel_name(subdirectory_id, channel_type); - // Attempt to fetch existing content for the given pubkey_id and channel_type + // Attempt to fetch existing content for the given subdirectory_id and channel_type if let Ok(data) = conn.get::<_, Vec>(&key).await { if !data.is_empty() { return Ok(data); @@ -83,7 +70,7 @@ impl DbPool { // Set up a temporary listener for changes let mut pubsub_conn = self.client.get_async_connection().await?.into_pubsub(); - let channel_name = pubkey_id.column_key(channel_type); + let channel_name = channel_name(subdirectory_id, channel_type); pubsub_conn.subscribe(&channel_name).await?; // Use a block to limit the scope of the mutable borrow @@ -116,3 +103,7 @@ impl DbPool { Ok(data) } } + +fn channel_name(subdirectory_id: &str, channel_type: &str) -> Vec { + (subdirectory_id.to_owned() + channel_type).into_bytes() +} diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index 04a05de2..0fde99a4 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -3,8 +3,6 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Result; -use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; -use bitcoin::base64::Engine; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Empty, Full}; use hyper::body::{Body, Bytes, Incoming}; @@ -17,8 +15,6 @@ use tokio::net::TcpListener; use tokio::sync::Mutex; use tracing::{debug, error, info, trace}; -use crate::db::ShortId; - pub const DEFAULT_DIR_PORT: u16 = 8080; pub const DEFAULT_DB_HOST: &str = "localhost:6379"; pub const DEFAULT_TIMEOUT_SECS: u64 = 30; @@ -34,6 +30,9 @@ const V1_REJECT_RES_JSON: &str = r#"{{"errorCode": "original-psbt-rejected ", "message": "Body is not a string"}}"#; const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message": "V2 receiver offline. V1 sends require synchronous communications."}}"#; +// 8 bytes as bech32 is 12.8 characters +const ID_LENGTH: usize = 13; + mod db; use crate::db::DbPool; @@ -306,11 +305,11 @@ async fn post_fallback_v1( }; let v2_compat_body = format!("{}\n{}", body_str, query); - let id = decode_short_id(id)?; - pool.push_default(&id, v2_compat_body.into()) + let id = check_id_length(id)?; + pool.push_default(id, v2_compat_body.into()) .await .map_err(|e| HandlerError::BadRequest(e.into()))?; - match pool.peek_v1(&id).await { + match pool.peek_v1(id).await { Some(result) => match result { Ok(buffered_req) => Ok(Response::new(full(buffered_req))), Err(e) => Err(HandlerError::BadRequest(e.into())), @@ -327,19 +326,29 @@ async fn put_payjoin_v1( trace!("Put_payjoin_v1"); let ok_response = Response::builder().status(StatusCode::OK).body(empty())?; - let id = decode_short_id(id)?; + let id = check_id_length(id)?; let req = body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); if req.len() > V1_MAX_BUFFER_SIZE { return Err(HandlerError::PayloadTooLarge); } - match pool.push_v1(&id, req.into()).await { + match pool.push_v1(id, req.into()).await { Ok(_) => Ok(ok_response), Err(e) => Err(HandlerError::BadRequest(e.into())), } } +fn check_id_length(id: &str) -> Result<&str, HandlerError> { + if id.len() != ID_LENGTH { + return Err(HandlerError::BadRequest(anyhow::anyhow!( + "subdirectory ID must be 13 bech32 characters", + ))); + } + + Ok(id) +} + async fn post_subdir( id: &str, body: BoxBody, @@ -348,14 +357,15 @@ async fn post_subdir( let none_response = Response::builder().status(StatusCode::OK).body(empty())?; trace!("post_subdir"); - let id = decode_short_id(id)?; + let id = check_id_length(id)?; + let req = body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); if req.len() > V1_MAX_BUFFER_SIZE { return Err(HandlerError::PayloadTooLarge); } - match pool.push_default(&id, req.into()).await { + match pool.push_default(id, req.into()).await { Ok(_) => Ok(none_response), Err(e) => Err(HandlerError::BadRequest(e.into())), } @@ -366,8 +376,8 @@ async fn get_subdir( pool: DbPool, ) -> Result>, HandlerError> { trace!("get_subdir"); - let id = decode_short_id(id)?; - match pool.peek_default(&id).await { + let id = check_id_length(id)?; + match pool.peek_default(id).await { Some(result) => match result { Ok(buffered_req) => Ok(Response::new(full(buffered_req))), Err(e) => Err(HandlerError::BadRequest(e.into())), @@ -396,16 +406,6 @@ async fn get_ohttp_keys( Ok(res) } -fn decode_short_id(input: &str) -> Result { - let decoded = - BASE64_URL_SAFE_NO_PAD.decode(input).map_err(|e| HandlerError::BadRequest(e.into()))?; - - decoded[..8] - .try_into() - .map_err(|_| HandlerError::BadRequest(anyhow::anyhow!("Invalid subdirectory ID"))) - .map(ShortId) -} - fn empty() -> BoxBody { Empty::::new().map_err(|never| match never {}).boxed() } diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 39cd2a1a..b2d619f9 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -1,8 +1,6 @@ use std::str::FromStr; use std::time::{Duration, SystemTime}; -use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; -use bitcoin::base64::Engine; use bitcoin::hashes::{sha256, Hash}; use bitcoin::psbt::Psbt; use bitcoin::{Address, FeeRate, OutPoint, Script, TxOut}; @@ -20,6 +18,7 @@ use crate::ohttp::{ohttp_decapsulate, ohttp_encapsulate, OhttpEncapsulationError use crate::psbt::PsbtExt; use crate::receive::optional_parameters::Params; use crate::receive::InputPair; +use crate::uri::ShortId; use crate::{PjUriBuilder, Request}; pub(crate) mod error; @@ -48,9 +47,8 @@ where Ok(address.assume_checked()) } -fn subdir_path_from_pubkey(pubkey: &HpkePublicKey) -> String { - let hash = sha256::Hash::hash(&pubkey.to_compressed_bytes()); - BASE64_URL_SAFE_NO_PAD.encode(&hash.as_byte_array()[..8]) +fn subdir_path_from_pubkey(pubkey: &HpkePublicKey) -> ShortId { + sha256::Hash::hash(&pubkey.to_compressed_bytes()).into() } /// A payjoin V2 receiver, allowing for polled requests to the @@ -200,22 +198,18 @@ impl Receiver { // The contents of the `&pj=` query parameter. // This identifies a session at the payjoin directory server. pub fn pj_url(&self) -> Url { - let id_base64 = BASE64_URL_SAFE_NO_PAD.encode(self.id()); let mut url = self.context.directory.clone(); { let mut path_segments = url.path_segments_mut().expect("Payjoin Directory URL cannot be a base"); - path_segments.push(&id_base64); + path_segments.push(&self.id().to_string()); } url } /// The per-session identifier - pub fn id(&self) -> [u8; 8] { - let hash = sha256::Hash::hash(&self.context.s.public_key().to_compressed_bytes()); - hash.as_byte_array()[..8] - .try_into() - .expect("truncating SHA256 to 8 bytes should always succeed") + pub fn id(&self) -> ShortId { + sha256::Hash::hash(&self.context.s.public_key().to_compressed_bytes()).into() } } @@ -479,8 +473,11 @@ impl PayjoinProposal { // Prepare v2 payload let payjoin_bytes = self.inner.payjoin_psbt.serialize(); let sender_subdir = subdir_path_from_pubkey(e); - target_resource = - self.context.directory.join(&sender_subdir).map_err(|e| Error::Server(e.into()))?; + target_resource = self + .context + .directory + .join(&sender_subdir.to_string()) + .map_err(|e| Error::Server(e.into()))?; body = encrypt_message_b(payjoin_bytes, &self.context.s, e)?; method = "POST"; } else { @@ -490,7 +487,7 @@ impl PayjoinProposal { target_resource = self .context .directory - .join(&receiver_subdir) + .join(&receiver_subdir.to_string()) .map_err(|e| Error::Server(e.into()))?; method = "PUT"; } diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 07ef0946..2292a00a 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -23,8 +23,6 @@ use std::str::FromStr; -#[cfg(feature = "v2")] -use bitcoin::base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; #[cfg(feature = "v2")] use bitcoin::hashes::{sha256, Hash}; use bitcoin::psbt::Psbt; @@ -41,6 +39,8 @@ use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeKeyPair, HpkePublicK use crate::ohttp::{ohttp_decapsulate, ohttp_encapsulate}; use crate::psbt::PsbtExt; use crate::request::Request; +#[cfg(feature = "v2")] +use crate::uri::ShortId; use crate::PjUri; // See usize casts @@ -405,8 +405,8 @@ impl V2GetContext { // TODO unify with receiver's fn subdir_path_from_pubkey let hash = sha256::Hash::hash(&self.hpke_ctx.reply_pair.public_key().to_compressed_bytes()); - let subdir = BASE64_URL_SAFE_NO_PAD.encode(&hash.as_byte_array()[..8]); - url.set_path(&subdir); + let subdir: ShortId = hash.into(); + url.set_path(&subdir.to_string()); let body = encrypt_message_a( Vec::new(), &self.hpke_ctx.reply_pair.public_key().clone(), diff --git a/payjoin/src/uri/mod.rs b/payjoin/src/uri/mod.rs index 6281446c..7acf6853 100644 --- a/payjoin/src/uri/mod.rs +++ b/payjoin/src/uri/mod.rs @@ -17,6 +17,64 @@ pub mod error; #[cfg(feature = "v2")] pub(crate) mod url_ext; +#[cfg(feature = "v2")] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ShortId(pub [u8; 8]); + +#[cfg(feature = "v2")] +impl ShortId { + pub fn as_bytes(&self) -> &[u8] { &self.0 } + pub fn as_slice(&self) -> &[u8] { &self.0 } +} + +#[cfg(feature = "v2")] +impl std::fmt::Display for ShortId { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let id_hrp = bech32::Hrp::parse("ID").unwrap(); + f.write_str( + crate::bech32::nochecksum::encode(id_hrp, &self.0) + .expect("bech32 encoding of short ID must succeed") + .strip_prefix("ID1") + .expect("human readable part must be ID1"), + ) + } +} + +#[cfg(feature = "v2")] +#[derive(Debug)] +pub enum ShortIdError { + DecodeBech32(bech32::primitives::decode::CheckedHrpstringError), + IncorrectLength(std::array::TryFromSliceError), +} + +#[cfg(feature = "v2")] +impl std::convert::From for ShortId { + fn from(h: bitcoin::hashes::sha256::Hash) -> Self { + bitcoin::hashes::Hash::as_byte_array(&h)[..8] + .try_into() + .expect("truncating SHA256 to 8 bytes should always succeed") + } +} + +#[cfg(feature = "v2")] +impl std::convert::TryFrom<&[u8]> for ShortId { + type Error = ShortIdError; + fn try_from(bytes: &[u8]) -> Result { + let bytes: [u8; 8] = bytes.try_into().map_err(ShortIdError::IncorrectLength)?; + Ok(Self(bytes)) + } +} + +#[cfg(feature = "v2")] +impl std::str::FromStr for ShortId { + type Err = ShortIdError; + fn from_str(s: &str) -> Result { + let (_, bytes) = crate::bech32::nochecksum::decode(&("ID1".to_string() + s)) + .map_err(ShortIdError::DecodeBech32)?; + (&bytes[..]).try_into() + } +} + #[derive(Debug, Clone)] pub enum MaybePayjoinExtras { Supported(PayjoinExtras),