From d6f927cebc42856848491525ffa24e5b14126347 Mon Sep 17 00:00:00 2001 From: DanGould Date: Mon, 25 Nov 2024 14:19:14 -0500 Subject: [PATCH] Pass static size ohttp en/decapsulate arguments Take advantage of the edit to use `&[u8]` function signatures where applicable to reduce tech debt. --- payjoin-cli/src/app/v2.rs | 11 +++-------- payjoin-directory/src/lib.rs | 9 +++++++-- payjoin/src/ohttp.rs | 27 +++++++++++++++++++-------- payjoin/src/receive/v2/error.rs | 9 +++++++++ payjoin/src/receive/v2/mod.rs | 27 ++++++++++++++++++++------- payjoin/src/request.rs | 4 ++-- payjoin/src/send/error.rs | 6 ++++++ payjoin/src/send/mod.rs | 20 ++++++++++---------- payjoin/tests/integration.rs | 31 +++++++++++-------------------- 9 files changed, 87 insertions(+), 57 deletions(-) diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index c0fc3ade..045089e9 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -141,7 +141,7 @@ impl App { println!("Got a request from the sender. Responding with a Payjoin proposal."); let res = post_request(req).await?; payjoin_proposal - .process_res(res.bytes().await?.to_vec(), ohttp_ctx) + .process_res(&res.bytes().await?, ohttp_ctx) .map_err(|e| anyhow!("Failed to deserialize response {}", e))?; let payjoin_psbt = payjoin_proposal.psbt().clone(); println!( @@ -198,16 +198,11 @@ impl App { println!("Posting Original PSBT Payload request..."); let response = post_request(req).await?; println!("Sent fallback transaction"); - let v2_ctx = Arc::new( - ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, - ); + let v2_ctx = Arc::new(ctx.process_response(&response.bytes().await?)?); loop { let (req, ohttp_ctx) = v2_ctx.extract_req(self.config.ohttp_relay.clone())?; let response = post_request(req).await?; - match v2_ctx.process_response( - &mut response.bytes().await?.to_vec().as_slice(), - ohttp_ctx, - ) { + match v2_ctx.process_response(&response.bytes().await?, ohttp_ctx) { Ok(Some(psbt)) => return Ok(psbt), Ok(None) => { println!("No response yet."); diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index 7600089b..04a05de2 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -23,7 +23,11 @@ pub const DEFAULT_DIR_PORT: u16 = 8080; pub const DEFAULT_DB_HOST: &str = "localhost:6379"; pub const DEFAULT_TIMEOUT_SECS: u64 = 30; -const PADDED_BHTTP_BYTES: usize = 8192; +const ENCAPSULATED_MESSAGE_BYTES: usize = 8192; +const CHACHA20_POLY1305_NONCE_LEN: usize = 32; // chacha20poly1305 n_k +const POLY1305_TAG_SIZE: usize = 16; +pub const BHTTP_REQ_BYTES: usize = + ENCAPSULATED_MESSAGE_BYTES - (CHACHA20_POLY1305_NONCE_LEN + POLY1305_TAG_SIZE); const V1_MAX_BUFFER_SIZE: usize = 65536; const V1_REJECT_RES_JSON: &str = @@ -209,10 +213,11 @@ async fn handle_ohttp_gateway( bhttp_res .write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes) .map_err(|e| HandlerError::InternalServerError(e.into()))?; - bhttp_bytes.resize(PADDED_BHTTP_BYTES, 0); + bhttp_bytes.resize(BHTTP_REQ_BYTES, 0); let ohttp_res = res_ctx .encapsulate(&bhttp_bytes) .map_err(|e| HandlerError::InternalServerError(e.into()))?; + assert!(ohttp_res.len() == ENCAPSULATED_MESSAGE_BYTES, "Unexpected OHTTP response size"); Ok(Response::new(full(ohttp_res))) } diff --git a/payjoin/src/ohttp.rs b/payjoin/src/ohttp.rs index 2491cb21..dba9659b 100644 --- a/payjoin/src/ohttp.rs +++ b/payjoin/src/ohttp.rs @@ -3,15 +3,21 @@ use std::{error, fmt}; use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; use bitcoin::base64::Engine; +use bitcoin::key::constants::UNCOMPRESSED_PUBLIC_KEY_SIZE; -pub const PADDED_MESSAGE_BYTES: usize = 8192; +pub const ENCAPSULATED_MESSAGE_BYTES: usize = 8192; +const N_ENC: usize = UNCOMPRESSED_PUBLIC_KEY_SIZE; +const N_T: usize = crate::hpke::POLY1305_TAG_SIZE; +const OHTTP_REQ_HEADER_BYTES: usize = 7; +pub const PADDED_BHTTP_REQ_BYTES: usize = + ENCAPSULATED_MESSAGE_BYTES - (N_ENC + N_T + OHTTP_REQ_HEADER_BYTES); pub fn ohttp_encapsulate( ohttp_keys: &mut ohttp::KeyConfig, method: &str, target_resource: &str, body: Option<&[u8]>, -) -> Result<(Vec, ohttp::ClientResponse), OhttpEncapsulationError> { +) -> Result<([u8; ENCAPSULATED_MESSAGE_BYTES], ohttp::ClientResponse), OhttpEncapsulationError> { use std::fmt::Write; let ctx = ohttp::ClientRequest::from_config(ohttp_keys)?; @@ -33,17 +39,22 @@ pub fn ohttp_encapsulate( if let Some(body) = body { bhttp_message.write_content(body); } - let mut bhttp_req = Vec::new(); - let _ = bhttp_message.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_req); - bhttp_req.resize(PADDED_MESSAGE_BYTES, 0); - let encapsulated = ctx.encapsulate(&bhttp_req)?; - Ok(encapsulated) + + let mut bhttp_req = [0u8; PADDED_BHTTP_REQ_BYTES]; + let mut cursor = std::io::Cursor::new(&mut bhttp_req[..]); + let _ = bhttp_message.write_bhttp(bhttp::Mode::KnownLength, &mut cursor); + let (encapsulated, ohttp_ctx) = ctx.encapsulate(&bhttp_req)?; + + let mut buffer = [0u8; ENCAPSULATED_MESSAGE_BYTES]; + let len = encapsulated.len().min(ENCAPSULATED_MESSAGE_BYTES); + buffer[..len].copy_from_slice(&encapsulated[..len]); + Ok((buffer, ohttp_ctx)) } /// decapsulate ohttp, bhttp response and return http response body and status code pub fn ohttp_decapsulate( res_ctx: ohttp::ClientResponse, - ohttp_body: &[u8], + ohttp_body: &[u8; ENCAPSULATED_MESSAGE_BYTES], ) -> Result>, OhttpEncapsulationError> { let bhttp_body = res_ctx.decapsulate(ohttp_body)?; let mut r = std::io::Cursor::new(bhttp_body); diff --git a/payjoin/src/receive/v2/error.rs b/payjoin/src/receive/v2/error.rs index 1a934dd3..22785819 100644 --- a/payjoin/src/receive/v2/error.rs +++ b/payjoin/src/receive/v2/error.rs @@ -12,6 +12,8 @@ pub(crate) enum InternalSessionError { Expired(std::time::SystemTime), /// OHTTP Encapsulation failed OhttpEncapsulation(OhttpEncapsulationError), + /// Unexpected response size + UnexpectedResponseSize(usize), } impl fmt::Display for SessionError { @@ -20,6 +22,12 @@ impl fmt::Display for SessionError { InternalSessionError::Expired(expiry) => write!(f, "Session expired at {:?}", expiry), InternalSessionError::OhttpEncapsulation(e) => write!(f, "OHTTP Encapsulation Error: {}", e), + InternalSessionError::UnexpectedResponseSize(size) => write!( + f, + "Unexpected response size {}, expected {} bytes", + size, + crate::ohttp::ENCAPSULATED_MESSAGE_BYTES + ), } } } @@ -29,6 +37,7 @@ impl error::Error for SessionError { match &self.0 { InternalSessionError::Expired(_) => None, InternalSessionError::OhttpEncapsulation(e) => Some(e), + InternalSessionError::UnexpectedResponseSize(_) => None, } } } diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index c62bc8c4..b0b46342 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -113,13 +113,17 @@ impl Receiver { /// indicating no UncheckedProposal is available yet. pub fn process_res( &mut self, - mut body: impl std::io::Read, + body: &[u8], context: ohttp::ClientResponse, ) -> Result, Error> { - let mut buf = Vec::new(); - let _ = body.read_to_end(&mut buf); + let response_array: &[u8; crate::ohttp::ENCAPSULATED_MESSAGE_BYTES] = + body.try_into().map_err(|_| { + Error::Server(Box::new(SessionError::from( + InternalSessionError::UnexpectedResponseSize(body.len()), + ))) + })?; log::trace!("decapsulating directory response"); - let response = ohttp_decapsulate(context, &buf)?; + let response = ohttp_decapsulate(context, response_array)?; if response.body().is_empty() { log::debug!("response is empty"); return Ok(None); @@ -134,7 +138,10 @@ impl Receiver { fn fallback_req_body( &mut self, - ) -> Result<(Vec, ohttp::ClientResponse), OhttpEncapsulationError> { + ) -> Result< + ([u8; crate::ohttp::ENCAPSULATED_MESSAGE_BYTES], ohttp::ClientResponse), + OhttpEncapsulationError, + > { let fallback_target = self.pj_url(); ohttp_encapsulate(&mut self.context.ohttp_keys, "GET", fallback_target.as_str(), None) } @@ -509,10 +516,16 @@ impl PayjoinProposal { /// choose to broadcast the original PSBT. pub fn process_res( &self, - res: Vec, + res: &[u8], ohttp_context: ohttp::ClientResponse, ) -> Result<(), Error> { - let res = ohttp_decapsulate(ohttp_context, &res)?; + let response_array: &[u8; crate::ohttp::ENCAPSULATED_MESSAGE_BYTES] = + res.try_into().map_err(|_| { + Error::Server(Box::new(SessionError::from( + InternalSessionError::UnexpectedResponseSize(res.len()), + ))) + })?; + let res = ohttp_decapsulate(ohttp_context, response_array)?; if res.status().is_success() { Ok(()) } else { diff --git a/payjoin/src/request.rs b/payjoin/src/request.rs index a093aa10..efd8dac5 100644 --- a/payjoin/src/request.rs +++ b/payjoin/src/request.rs @@ -32,7 +32,7 @@ impl Request { } #[cfg(feature = "v2")] - pub fn new_v2(url: Url, body: Vec) -> Self { - Self { url, content_type: V2_REQ_CONTENT_TYPE, body } + pub fn new_v2(url: Url, body: [u8; crate::ohttp::ENCAPSULATED_MESSAGE_BYTES]) -> Self { + Self { url, content_type: V2_REQ_CONTENT_TYPE, body: body.to_vec() } } } diff --git a/payjoin/src/send/error.rs b/payjoin/src/send/error.rs index 6a377b78..7de18a1f 100644 --- a/payjoin/src/send/error.rs +++ b/payjoin/src/send/error.rs @@ -66,6 +66,8 @@ pub(crate) enum InternalValidationError { OhttpEncapsulation(crate::ohttp::OhttpEncapsulationError), #[cfg(feature = "v2")] UnexpectedStatusCode, + #[cfg(feature = "v2")] + UnexpectedResponseSize(usize), } impl From for ValidationError { @@ -119,6 +121,8 @@ impl fmt::Display for ValidationError { OhttpEncapsulation(e) => write!(f, "Ohttp encapsulation error: {}", e), #[cfg(feature = "v2")] UnexpectedStatusCode => write!(f, "unexpected status code"), + #[cfg(feature = "v2")] + UnexpectedResponseSize(size) => write!(f, "unexpected response size {}, expected {} bytes", size, crate::ohttp::ENCAPSULATED_MESSAGE_BYTES), } } } @@ -164,6 +168,8 @@ impl std::error::Error for ValidationError { OhttpEncapsulation(error) => Some(error), #[cfg(feature = "v2")] UnexpectedStatusCode => None, + #[cfg(feature = "v2")] + UnexpectedResponseSize(_) => None, } } } diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index f31bb6f0..35cd6ae5 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -358,12 +358,9 @@ pub struct V2PostContext { #[cfg(feature = "v2")] impl V2PostContext { - pub fn process_response( - self, - response: &mut impl std::io::Read, - ) -> Result { - let mut res_buf = Vec::new(); - response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; + pub fn process_response(self, response: &[u8]) -> Result { + let mut res_buf = [0u8; crate::ohttp::ENCAPSULATED_MESSAGE_BYTES]; + res_buf[..response.len()].copy_from_slice(response); let response = ohttp_decapsulate(self.ohttp_ctx, &res_buf) .map_err(InternalValidationError::OhttpEncapsulation)?; match response.status() { @@ -417,12 +414,15 @@ impl V2GetContext { pub fn process_response( &self, - response: &mut impl std::io::Read, + response: &[u8], ohttp_ctx: ohttp::ClientResponse, ) -> Result, ResponseError> { - let mut res_buf = Vec::new(); - response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; - let response = ohttp_decapsulate(ohttp_ctx, &res_buf) + let response_array: &[u8; crate::ohttp::ENCAPSULATED_MESSAGE_BYTES] = + response + .try_into() + .map_err(|_| InternalValidationError::UnexpectedResponseSize(response.len()))?; + + let response = ohttp_decapsulate(ohttp_ctx, response_array) .map_err(InternalValidationError::OhttpEncapsulation)?; let body = match response.status() { http::StatusCode::OK => response.body().to_vec(), diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index ef0f3cba..becf82a3 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -368,8 +368,7 @@ mod integration { .unwrap(); log::info!("Response: {:#?}", &response); assert!(response.status().is_success()); - let send_ctx = - send_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?; + let send_ctx = send_ctx.process_response(&response.bytes().await?)?; // POST Original PSBT // ********************** @@ -390,8 +389,7 @@ mod integration { .body(req.body) .send() .await?; - let res = response.bytes().await?.to_vec(); - payjoin_proposal.process_res(res, ctx)?; + payjoin_proposal.process_res(&response.bytes().await?, ctx)?; // ********************** // Inside the Sender: @@ -407,9 +405,8 @@ mod integration { .await .unwrap(); log::info!("Response: {:#?}", &response); - let checked_payjoin_proposal_psbt = send_ctx - .process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)? - .unwrap(); + let checked_payjoin_proposal_psbt = + send_ctx.process_response(&response.bytes().await?, ohttp_ctx)?.unwrap(); let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?; log::info!("sent"); @@ -503,8 +500,7 @@ mod integration { let (req, ctx) = session.extract_req()?; let response = agent.post(req.url).body(req.body).send().await?; assert!(response.status().is_success()); - let response_body = - session.process_res(response.bytes().await?.to_vec().as_slice(), ctx).unwrap(); + let response_body = session.process_res(&response.bytes().await?, ctx).unwrap(); // No proposal yet since sender has not responded assert!(response_body.is_none()); @@ -530,8 +526,7 @@ mod integration { .unwrap(); log::info!("Response: {:#?}", &response); assert!(response.status().is_success()); - let get_ctx = - post_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?; + let get_ctx = post_ctx.process_response(&response.bytes().await?)?; let (Request { url, body, content_type, .. }, ohttp_ctx) = get_ctx.extract_req(directory.to_owned())?; let response = agent @@ -541,9 +536,7 @@ mod integration { .send() .await?; // No response body yet since we are async and pushed fallback_psbt to the buffer - assert!(get_ctx - .process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)? - .is_none()); + assert!(get_ctx.process_response(&response.bytes().await?, ohttp_ctx)?.is_none()); // ********************** // Inside the Receiver: @@ -560,8 +553,7 @@ mod integration { assert!(!payjoin_proposal.is_output_substitution_disabled()); let (req, ctx) = payjoin_proposal.extract_v2_req()?; let response = agent.post(req.url).body(req.body).send().await?; - let res = response.bytes().await?.to_vec(); - payjoin_proposal.process_res(res, ctx)?; + payjoin_proposal.process_res(&response.bytes().await?, ctx)?; // ********************** // Inside the Sender: @@ -575,9 +567,8 @@ mod integration { .body(body.clone()) .send() .await?; - let checked_payjoin_proposal_psbt = get_ctx - .process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)? - .unwrap(); + let checked_payjoin_proposal_psbt = + get_ctx.process_response(&response.bytes().await?, ohttp_ctx)?.unwrap(); let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?; log::info!("sent"); @@ -739,7 +730,7 @@ mod integration { let (req, ctx) = payjoin_proposal.extract_v2_req().unwrap(); let response = agent_clone.post(req.url).body(req.body).send().await?; payjoin_proposal - .process_res(response.bytes().await?.to_vec(), ctx) + .process_res(&response.bytes().await?, ctx) .map_err(|e| e.to_string())?; Ok::<_, Box>(()) });