diff --git a/payjoin-cli/src/app/mod.rs b/payjoin-cli/src/app/mod.rs index 5d8c000b..468a4b66 100644 --- a/payjoin-cli/src/app/mod.rs +++ b/payjoin-cli/src/app/mod.rs @@ -7,7 +7,7 @@ use bitcoin::TxIn; use bitcoincore_rpc::bitcoin::Amount; use bitcoincore_rpc::RpcApi; use payjoin::bitcoin::psbt::Psbt; -use payjoin::send::RequestContext; +use payjoin::send::Sender; use payjoin::{bitcoin, PjUri}; pub mod config; @@ -30,7 +30,7 @@ pub trait App { async fn send_payjoin(&self, bip21: &str, fee_rate: &f32) -> Result<()>; async fn receive_payjoin(self, amount_arg: &str) -> Result<()>; - fn create_pj_request(&self, uri: &PjUri, fee_rate: &f32) -> Result { + fn create_pj_request(&self, uri: &PjUri, fee_rate: &f32) -> Result { let amount = uri.amount.ok_or_else(|| anyhow!("please specify the amount in the Uri"))?; // wallet_create_funded_psbt requires a HashMap @@ -66,7 +66,7 @@ pub trait App { .psbt; let psbt = Psbt::from_str(&psbt).with_context(|| "Failed to load PSBT from base64")?; log::debug!("Original psbt: {:#?}", psbt); - let req_ctx = payjoin::send::RequestBuilder::from_psbt_and_uri(psbt, uri.clone()) + let req_ctx = payjoin::send::SenderBuilder::from_psbt_and_uri(psbt, uri.clone()) .with_context(|| "Failed to build payjoin request")? .build_recommended(fee_rate) .with_context(|| "Failed to build payjoin request")?; diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index 3a8c8a18..49a1e26b 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -7,8 +7,8 @@ use bitcoincore_rpc::RpcApi; use payjoin::bitcoin::consensus::encode::serialize_hex; use payjoin::bitcoin::psbt::Psbt; use payjoin::bitcoin::{Amount, FeeRate}; -use payjoin::receive::v2::ActiveSession; -use payjoin::send::RequestContext; +use payjoin::receive::v2::Receiver; +use payjoin::send::Sender; use payjoin::{bitcoin, Error, Uri}; use tokio::signal; use tokio::sync::watch; @@ -75,39 +75,23 @@ impl AppTrait for App { } async fn receive_payjoin(self, amount_arg: &str) -> Result<()> { - use payjoin::receive::v2::SessionInitializer; - let address = self.bitcoind()?.get_new_address(None, None)?.assume_checked(); let amount = Amount::from_sat(amount_arg.parse()?); let ohttp_keys = unwrap_ohttp_keys_or_else_fetch(&self.config).await?; - let mut initializer = SessionInitializer::new( + let session = Receiver::new( address, self.config.pj_directory.clone(), ohttp_keys.clone(), self.config.ohttp_relay.clone(), None, ); - let (req, ctx) = - initializer.extract_req().map_err(|e| anyhow!("Failed to extract request {}", e))?; - println!("Starting new Payjoin session with {}", self.config.pj_directory); - let http = http_agent()?; - let ohttp_response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; - let session = initializer - .process_res(ohttp_response.bytes().await?.to_vec().as_slice(), ctx) - .map_err(|e| anyhow!("Enrollment failed {}", e))?; self.db.insert_recv_session(session.clone())?; self.spawn_payjoin_receiver(session, Some(amount)).await } } impl App { - async fn spawn_payjoin_sender(&self, mut req_ctx: RequestContext) -> Result<()> { + async fn spawn_payjoin_sender(&self, mut req_ctx: Sender) -> Result<()> { let mut interrupt = self.interrupt.clone(); tokio::select! { res = self.long_poll_post(&mut req_ctx) => { @@ -123,7 +107,7 @@ impl App { async fn spawn_payjoin_receiver( &self, - mut session: ActiveSession, + mut session: Receiver, amount: Option, ) -> Result<()> { println!("Receive session established"); @@ -213,30 +197,57 @@ impl App { Ok(()) } - async fn long_poll_post(&self, req_ctx: &mut payjoin::send::RequestContext) -> Result { - loop { - let (req, ctx) = req_ctx.extract_v2(self.config.ohttp_relay.clone())?; - println!("Polling send request..."); - let http = http_agent()?; - let response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; - - println!("Sent fallback transaction"); - match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { - Ok(Some(psbt)) => return Ok(psbt), - Ok(None) => { - println!("No response yet."); - tokio::time::sleep(std::time::Duration::from_secs(5)).await; + async fn long_poll_post(&self, req_ctx: &mut payjoin::send::Sender) -> Result { + let (req, ctx) = req_ctx.extract_highest_version(self.config.ohttp_relay.clone())?; + println!("Posting Original PSBT Payload request..."); + let http = http_agent()?; + let response = http + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(map_reqwest_err)?; + println!("Sent fallback transaction"); + match ctx { + payjoin::send::Context::V2(ctx) => { + let v2_ctx = Arc::new( + ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, + ); + loop { + let (req, ohttp_ctx) = v2_ctx.extract_req(self.config.ohttp_relay.clone())?; + let response = http + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(map_reqwest_err)?; + match v2_ctx.process_response( + &mut response.bytes().await?.to_vec().as_slice(), + ohttp_ctx, + ) { + Ok(Some(psbt)) => return Ok(psbt), + Ok(None) => { + println!("No response yet."); + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + } + Err(re) => { + println!("{}", re); + log::debug!("{:?}", re); + return Err(anyhow!("Response error").context(re)); + } + } } - Err(re) => { - println!("{}", re); - log::debug!("{:?}", re); - return Err(anyhow!("Response error").context(re)); + } + payjoin::send::Context::V1(ctx) => { + match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { + Ok(psbt) => Ok(psbt), + Err(re) => { + println!("{}", re); + log::debug!("{:?}", re); + Err(anyhow!("Response error").context(re)) + } } } } @@ -244,7 +255,7 @@ impl App { async fn long_poll_fallback( &self, - session: &mut payjoin::receive::v2::ActiveSession, + session: &mut payjoin::receive::v2::Receiver, ) -> Result { loop { let (req, context) = session.extract_req()?; diff --git a/payjoin-cli/src/db/v2.rs b/payjoin-cli/src/db/v2.rs index 8ec7250b..a2168647 100644 --- a/payjoin-cli/src/db/v2.rs +++ b/payjoin-cli/src/db/v2.rs @@ -1,13 +1,13 @@ use bitcoincore_rpc::jsonrpc::serde_json; -use payjoin::receive::v2::ActiveSession; -use payjoin::send::RequestContext; +use payjoin::receive::v2::Receiver; +use payjoin::send::Sender; use sled::{IVec, Tree}; use url::Url; use super::*; impl Database { - pub(crate) fn insert_recv_session(&self, session: ActiveSession) -> Result<()> { + pub(crate) fn insert_recv_session(&self, session: Receiver) -> Result<()> { let recv_tree = self.0.open_tree("recv_sessions")?; let key = &session.id(); let value = serde_json::to_string(&session).map_err(Error::Serialize)?; @@ -16,13 +16,12 @@ impl Database { Ok(()) } - pub(crate) fn get_recv_sessions(&self) -> Result> { + pub(crate) fn get_recv_sessions(&self) -> Result> { let recv_tree = self.0.open_tree("recv_sessions")?; let mut sessions = Vec::new(); for item in recv_tree.iter() { let (_, value) = item?; - let session: ActiveSession = - serde_json::from_slice(&value).map_err(Error::Deserialize)?; + let session: Receiver = serde_json::from_slice(&value).map_err(Error::Deserialize)?; sessions.push(session); } Ok(sessions) @@ -35,11 +34,7 @@ impl Database { Ok(()) } - pub(crate) fn insert_send_session( - &self, - session: &mut RequestContext, - pj_url: &Url, - ) -> Result<()> { + pub(crate) fn insert_send_session(&self, session: &mut Sender, pj_url: &Url) -> Result<()> { let send_tree: Tree = self.0.open_tree("send_sessions")?; let value = serde_json::to_string(session).map_err(Error::Serialize)?; send_tree.insert(pj_url.to_string(), IVec::from(value.as_str()))?; @@ -47,23 +42,21 @@ impl Database { Ok(()) } - pub(crate) fn get_send_sessions(&self) -> Result> { + pub(crate) fn get_send_sessions(&self) -> Result> { let send_tree: Tree = self.0.open_tree("send_sessions")?; let mut sessions = Vec::new(); for item in send_tree.iter() { let (_, value) = item?; - let session: RequestContext = - serde_json::from_slice(&value).map_err(Error::Deserialize)?; + let session: Sender = serde_json::from_slice(&value).map_err(Error::Deserialize)?; sessions.push(session); } Ok(sessions) } - pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result> { + pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result> { let send_tree = self.0.open_tree("send_sessions")?; if let Some(val) = send_tree.get(pj_url.to_string())? { - let session: RequestContext = - serde_json::from_slice(&val).map_err(Error::Deserialize)?; + let session: Sender = serde_json::from_slice(&val).map_err(Error::Deserialize)?; Ok(Some(session)) } else { Ok(None) diff --git a/payjoin-cli/tests/e2e.rs b/payjoin-cli/tests/e2e.rs index 7eae8b43..5db46539 100644 --- a/payjoin-cli/tests/e2e.rs +++ b/payjoin-cli/tests/e2e.rs @@ -482,14 +482,7 @@ mod e2e { let db = docker.run(Redis::default()); let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); println!("Database running on {}", db.get_host_port_ipv4(6379)); - payjoin_directory::listen_tcp_with_tls( - format!("http://localhost:{}", port), - port, - db_host, - timeout, - local_cert_key, - ) - .await + payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await } // generates or gets a DER encoded localhost cert and key. diff --git a/payjoin-directory/src/db.rs b/payjoin-directory/src/db.rs index 26b69864..679a0f40 100644 --- a/payjoin-directory/src/db.rs +++ b/payjoin-directory/src/db.rs @@ -4,8 +4,8 @@ use futures::StreamExt; use redis::{AsyncCommands, Client, ErrorKind, RedisError, RedisResult}; use tracing::debug; -const RES_COLUMN: &str = "res"; -const REQ_COLUMN: &str = "req"; +const DEFAULT_COLUMN: &str = ""; +const PJ_V1_COLUMN: &str = "pjv1"; #[derive(Debug, Clone)] pub(crate) struct DbPool { @@ -19,20 +19,20 @@ impl DbPool { Ok(Self { client, timeout }) } - pub async fn peek_req(&self, pubkey_id: &str) -> Option>> { - self.peek_with_timeout(pubkey_id, REQ_COLUMN).await + pub async fn push_default(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { + self.push(pubkey_id, DEFAULT_COLUMN, data).await } - pub async fn peek_res(&self, pubkey_id: &str) -> Option>> { - self.peek_with_timeout(pubkey_id, RES_COLUMN).await + pub async fn peek_default(&self, pubkey_id: &str) -> Option>> { + self.peek_with_timeout(pubkey_id, DEFAULT_COLUMN).await } - pub async fn push_req(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { - self.push(pubkey_id, REQ_COLUMN, data).await + pub async fn push_v1(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { + self.push(pubkey_id, PJ_V1_COLUMN, data).await } - pub async fn push_res(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { - self.push(pubkey_id, RES_COLUMN, data).await + pub async fn peek_v1(&self, pubkey_id: &str) -> Option>> { + self.peek_with_timeout(pubkey_id, PJ_V1_COLUMN).await } async fn push(&self, pubkey_id: &str, channel_type: &str, data: Vec) -> RedisResult<()> { diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index 875ff9fb..ef267c86 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -3,12 +3,10 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Result; -use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; -use bitcoin::base64::Engine; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Empty, Full}; use hyper::body::{Body, Bytes, Incoming}; -use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE, LOCATION}; +use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE}; use hyper::server::conn::http1; use hyper::service::service_fn; use hyper::{Method, Request, Response, StatusCode, Uri}; @@ -20,7 +18,6 @@ use tracing::{debug, error, info, trace}; pub const DEFAULT_DIR_PORT: u16 = 8080; pub const DEFAULT_DB_HOST: &str = "localhost:6379"; pub const DEFAULT_TIMEOUT_SECS: u64 = 30; -pub const DEFAULT_BASE_URL: &str = "https://localhost"; const MAX_BUFFER_SIZE: usize = 65536; @@ -32,7 +29,6 @@ mod db; use crate::db::DbPool; pub async fn listen_tcp( - base_url: String, port: u16, db_host: String, timeout: Duration, @@ -44,14 +40,13 @@ pub async fn listen_tcp( while let Ok((stream, _)) = listener.accept().await { let pool = pool.clone(); let ohttp = ohttp.clone(); - let base_url = base_url.clone(); let io = TokioIo::new(stream); tokio::spawn(async move { if let Err(err) = http1::Builder::new() .serve_connection( io, service_fn(move |req| { - serve_payjoin_directory(req, pool.clone(), ohttp.clone(), base_url.clone()) + serve_payjoin_directory(req, pool.clone(), ohttp.clone()) }), ) .with_upgrades() @@ -67,7 +62,6 @@ pub async fn listen_tcp( #[cfg(feature = "danger-local-https")] pub async fn listen_tcp_with_tls( - base_url: String, port: u16, db_host: String, timeout: Duration, @@ -81,7 +75,6 @@ pub async fn listen_tcp_with_tls( while let Ok((stream, _)) = listener.accept().await { let pool = pool.clone(); let ohttp = ohttp.clone(); - let base_url = base_url.clone(); let tls_acceptor = tls_acceptor.clone(); tokio::spawn(async move { let tls_stream = match tls_acceptor.accept(stream).await { @@ -95,7 +88,7 @@ pub async fn listen_tcp_with_tls( .serve_connection( TokioIo::new(tls_stream), service_fn(move |req| { - serve_payjoin_directory(req, pool.clone(), ohttp.clone(), base_url.clone()) + serve_payjoin_directory(req, pool.clone(), ohttp.clone()) }), ) .with_upgrades() @@ -146,7 +139,6 @@ async fn serve_payjoin_directory( req: Request, pool: DbPool, ohttp: Arc>, - base_url: String, ) -> Result>> { let path = req.uri().path().to_string(); let query = req.uri().query().unwrap_or_default().to_string(); @@ -155,7 +147,7 @@ async fn serve_payjoin_directory( let path_segments: Vec<&str> = path.split('/').collect(); debug!("serve_payjoin_directory: {:?}", &path_segments); let mut response = match (parts.method, path_segments.as_slice()) { - (Method::POST, ["", ""]) => handle_ohttp_gateway(body, pool, ohttp, base_url).await, + (Method::POST, ["", ""]) => handle_ohttp_gateway(body, pool, ohttp).await, (Method::GET, ["", "ohttp-keys"]) => get_ohttp_keys(&ohttp).await, (Method::POST, ["", id]) => post_fallback_v1(id, query, body, pool).await, (Method::GET, ["", "health"]) => health_check().await, @@ -173,7 +165,6 @@ async fn handle_ohttp_gateway( body: Incoming, pool: DbPool, ohttp: Arc>, - base_url: String, ) -> Result>, HandlerError> { // decapsulate let ohttp_body = @@ -199,7 +190,7 @@ async fn handle_ohttp_gateway( } let request = http_req.body(full(body))?; - let response = handle_v2(pool, base_url, request).await?; + let response = handle_v2(pool, request).await?; let (parts, body) = response.into_parts(); let mut bhttp_res = bhttp::Message::response(parts.status.as_u16()); @@ -221,7 +212,6 @@ async fn handle_ohttp_gateway( async fn handle_v2( pool: DbPool, - base_url: String, req: Request>, ) -> Result>, HandlerError> { let path = req.uri().path().to_string(); @@ -230,10 +220,9 @@ async fn handle_v2( let path_segments: Vec<&str> = path.split('/').collect(); debug!("handle_v2: {:?}", &path_segments); match (parts.method, path_segments.as_slice()) { - (Method::POST, &["", ""]) => post_session(base_url, body).await, - (Method::POST, &["", id]) => post_fallback_v2(id, body, pool).await, - (Method::GET, &["", id]) => get_fallback(id, pool).await, - (Method::PUT, &["", id]) => post_payjoin(id, body, pool).await, + (Method::POST, &["", id]) => post_subdir(id, body, pool).await, + (Method::GET, &["", id]) => get_subdir(id, pool).await, + (Method::PUT, &["", id]) => put_payjoin_v1(id, body, pool).await, _ => Ok(not_found()), } } @@ -282,24 +271,6 @@ impl From for HandlerError { fn from(e: hyper::http::Error) -> Self { HandlerError::InternalServerError(e.into()) } } -async fn post_session( - base_url: String, - body: BoxBody, -) -> Result>, HandlerError> { - let bytes = body.collect().await.map_err(|e| HandlerError::BadRequest(e.into()))?.to_bytes(); - let base64_id = - String::from_utf8(bytes.to_vec()).map_err(|e| HandlerError::BadRequest(e.into()))?; - let pubkey_bytes: Vec = - BASE64_URL_SAFE_NO_PAD.decode(base64_id).map_err(|e| HandlerError::BadRequest(e.into()))?; - let pubkey = bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes) - .map_err(|e| HandlerError::BadRequest(e.into()))?; - tracing::info!("Initialized session with pubkey: {:?}", pubkey); - Ok(Response::builder() - .header(LOCATION, format!("{}/{}", base_url, pubkey)) - .status(StatusCode::CREATED) - .body(empty())?) -} - async fn post_fallback_v1( id: &str, query: String, @@ -323,27 +294,49 @@ async fn post_fallback_v1( Err(_) => return Ok(bad_request_body_res), }; - let v2_compat_body = full(format!("{}\n{}", body_str, query)); - post_fallback(id, v2_compat_body, pool, none_response).await + let v2_compat_body = format!("{}\n{}", body_str, query); + let id = shorten_string(id); + pool.push_default(&id, v2_compat_body.into()) + .await + .map_err(|e| HandlerError::BadRequest(e.into()))?; + match pool.peek_v1(&id).await { + Some(result) => match result { + Ok(buffered_req) => Ok(Response::new(full(buffered_req))), + Err(e) => Err(HandlerError::BadRequest(e.into())), + }, + None => Ok(none_response), + } } -async fn post_fallback_v2( +async fn put_payjoin_v1( id: &str, body: BoxBody, pool: DbPool, ) -> Result>, HandlerError> { - trace!("Post fallback v2"); - let none_response = Response::builder().status(StatusCode::ACCEPTED).body(empty())?; - post_fallback(id, body, pool, none_response).await + trace!("Put_payjoin_v1"); + let ok_response = Response::builder().status(StatusCode::OK).body(empty())?; + + let id = shorten_string(id); + let req = + body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); + if req.len() > MAX_BUFFER_SIZE { + return Err(HandlerError::PayloadTooLarge); + } + + match pool.push_v1(&id, req.into()).await { + Ok(_) => Ok(ok_response), + Err(e) => Err(HandlerError::BadRequest(e.into())), + } } -async fn post_fallback( +async fn post_subdir( id: &str, body: BoxBody, pool: DbPool, - none_response: Response>, ) -> Result>, HandlerError> { - tracing::trace!("Post fallback"); + let none_response = Response::builder().status(StatusCode::OK).body(empty())?; + trace!("post_subdir"); + let id = shorten_string(id); let req = body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); @@ -351,27 +344,19 @@ async fn post_fallback( return Err(HandlerError::PayloadTooLarge); } - match pool.push_req(&id, req.into()).await { - Ok(_) => (), - Err(e) => return Err(HandlerError::BadRequest(e.into())), - }; - - match pool.peek_res(&id).await { - Some(result) => match result { - Ok(buffered_res) => Ok(Response::new(full(buffered_res))), - Err(e) => Err(HandlerError::BadRequest(e.into())), - }, - None => Ok(none_response), + match pool.push_default(&id, req.into()).await { + Ok(_) => Ok(none_response), + Err(e) => Err(HandlerError::BadRequest(e.into())), } } -async fn get_fallback( +async fn get_subdir( id: &str, pool: DbPool, ) -> Result>, HandlerError> { - trace!("GET fallback"); + trace!("get_subdir"); let id = shorten_string(id); - match pool.peek_req(&id).await { + match pool.peek_default(&id).await { Some(result) => match result { Ok(buffered_req) => Ok(Response::new(full(buffered_req))), Err(e) => Err(HandlerError::BadRequest(e.into())), @@ -380,22 +365,6 @@ async fn get_fallback( } } -async fn post_payjoin( - id: &str, - body: BoxBody, - pool: DbPool, -) -> Result>, HandlerError> { - trace!("POST payjoin"); - let id = shorten_string(id); - let res = - body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); - - match pool.push_res(&id, res.into()).await { - Ok(_) => Ok(Response::builder().status(StatusCode::NO_CONTENT).body(empty())?), - Err(e) => Err(HandlerError::BadRequest(e.into())), - } -} - fn not_found() -> Response> { let mut res = Response::default(); *res.status_mut() = StatusCode::NOT_FOUND; @@ -425,32 +394,3 @@ fn empty() -> BoxBody { fn full>(chunk: T) -> BoxBody { Full::new(chunk.into()).map_err(|never| match never {}).boxed() } - -#[cfg(test)] -mod tests { - use hyper::Request; - - use super::*; - - /// Ensure that the POST / endpoint returns a 201 Created with a Location header - /// as is semantically correct when creating a resource. - /// - /// https://datatracker.ietf.org/doc/html/rfc9110#name-post - #[tokio::test] - async fn test_post_session() -> Result<(), Box> { - let base_url = "https://localhost".to_string(); - let body = full("A6z245ZfDfnlk7_HiAp6sPmNaVYwADih-vCGE3eysWp7"); - - let request = Request::builder().method(Method::POST).uri("/").body(body)?; - - let response = post_session(base_url.clone(), request.into_body()) - .await - .map_err(|e| format!("{:?}", e))?; - - assert_eq!(response.status(), StatusCode::CREATED); - assert!(response.headers().contains_key(LOCATION)); - let location_header = response.headers().get(LOCATION).ok_or("Missing LOCATION header")?; - assert!(location_header.to_str()?.starts_with(&base_url)); - Ok(()) - } -} diff --git a/payjoin-directory/src/main.rs b/payjoin-directory/src/main.rs index 13d04cff..39dcd8c6 100644 --- a/payjoin-directory/src/main.rs +++ b/payjoin-directory/src/main.rs @@ -17,9 +17,7 @@ async fn main() -> Result<(), Box> { let db_host = env::var("PJ_DB_HOST").unwrap_or_else(|_| DEFAULT_DB_HOST.to_string()); - let base_url = env::var("PJ_DIR_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()); - - payjoin_directory::listen_tcp(base_url, dir_port, db_host, timeout).await + payjoin_directory::listen_tcp(dir_port, db_host, timeout).await } fn init_logging() { diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs new file mode 100644 index 00000000..c5cfb429 --- /dev/null +++ b/payjoin/src/hpke.rs @@ -0,0 +1,469 @@ +use std::ops::Deref; +use std::{error, fmt}; + +use bitcoin::key::constants::{ELLSWIFT_ENCODING_SIZE, UNCOMPRESSED_PUBLIC_KEY_SIZE}; +use bitcoin::secp256k1::ellswift::ElligatorSwift; +use hpke::aead::ChaCha20Poly1305; +use hpke::kdf::HkdfSha256; +use hpke::kem::SecpK256HkdfSha256; +use hpke::rand_core::OsRng; +use hpke::{Deserializable, OpModeR, OpModeS, Serializable}; +use serde::{Deserialize, Serialize}; + +pub const PADDED_MESSAGE_BYTES: usize = 7168; +pub const PADDED_PLAINTEXT_A_LENGTH: usize = PADDED_MESSAGE_BYTES + - (ELLSWIFT_ENCODING_SIZE + UNCOMPRESSED_PUBLIC_KEY_SIZE + POLY1305_TAG_SIZE); +pub const PADDED_PLAINTEXT_B_LENGTH: usize = + PADDED_MESSAGE_BYTES - (ELLSWIFT_ENCODING_SIZE + POLY1305_TAG_SIZE); +pub const POLY1305_TAG_SIZE: usize = 16; // FIXME there is a U16 defined for poly1305, should bitcoin hpke re-export it? +pub const INFO_A: &[u8; 8] = b"PjV2MsgA"; +pub const INFO_B: &[u8; 8] = b"PjV2MsgB"; + +pub type SecretKey = ::PrivateKey; +pub type PublicKey = ::PublicKey; +pub type EncappedKey = ::EncappedKey; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct HpkeKeyPair(pub HpkeSecretKey, pub HpkePublicKey); + +impl From for (HpkeSecretKey, HpkePublicKey) { + fn from(value: HpkeKeyPair) -> Self { (value.0, value.1) } +} + +impl HpkeKeyPair { + pub fn gen_keypair() -> Self { + let (sk, pk) = ::gen_keypair(&mut OsRng); + Self(HpkeSecretKey(sk), HpkePublicKey(pk)) + } + pub fn secret_key(&self) -> &HpkeSecretKey { &self.0 } + pub fn public_key(&self) -> &HpkePublicKey { &self.1 } +} + +fn encapped_key_from_ellswift_bytes(encoded: &[u8]) -> Result { + let mut buf = [0u8; ELLSWIFT_ENCODING_SIZE]; + buf.copy_from_slice(encoded); + let ellswift = ElligatorSwift::from_array(buf); + let pk = bitcoin::secp256k1::PublicKey::from_ellswift(ellswift); + Ok(EncappedKey::from_bytes(pk.serialize_uncompressed().as_slice())?) +} + +fn ellswift_bytes_from_encapped_key( + enc: &EncappedKey, +) -> Result<[u8; ELLSWIFT_ENCODING_SIZE], HpkeError> { + let uncompressed = enc.to_bytes(); + let pk = bitcoin::secp256k1::PublicKey::from_slice(&uncompressed)?; + let ellswift = ElligatorSwift::from_pubkey(pk); + Ok(ellswift.to_array()) +} + +#[derive(Clone, PartialEq, Eq)] +pub struct HpkeSecretKey(pub SecretKey); + +impl Deref for HpkeSecretKey { + type Target = SecretKey; + + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl core::fmt::Debug for HpkeSecretKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SecpHpkeSecretKey([REDACTED])") + } +} + +impl serde::Serialize for HpkeSecretKey { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(&self.0.to_bytes()) + } +} + +impl<'de> serde::Deserialize<'de> for HpkeSecretKey { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let bytes = Vec::::deserialize(deserializer)?; + Ok(HpkeSecretKey( + SecretKey::from_bytes(&bytes) + .map_err(|_| serde::de::Error::custom("Invalid secret key"))?, + )) + } +} + +#[derive(Clone, PartialEq, Eq)] +pub struct HpkePublicKey(pub PublicKey); + +impl HpkePublicKey { + pub fn to_compressed_bytes(&self) -> [u8; 33] { + let compressed_key = bitcoin::secp256k1::PublicKey::from_slice(&self.0.to_bytes()) + .expect("Invalid public key from known valid bytes"); + compressed_key.serialize() + } + + pub fn from_compressed_bytes(bytes: &[u8]) -> Result { + let compressed_key = bitcoin::secp256k1::PublicKey::from_slice(bytes)?; + Ok(HpkePublicKey(PublicKey::from_bytes( + compressed_key.serialize_uncompressed().as_slice(), + )?)) + } +} + +impl Deref for HpkePublicKey { + type Target = PublicKey; + + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl core::fmt::Debug for HpkePublicKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SecpHpkePublicKey({:?})", self.0) + } +} + +impl serde::Serialize for HpkePublicKey { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(&self.0.to_bytes()) + } +} + +impl<'de> serde::Deserialize<'de> for HpkePublicKey { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let bytes = Vec::::deserialize(deserializer)?; + Ok(HpkePublicKey( + PublicKey::from_bytes(&bytes) + .map_err(|_| serde::de::Error::custom("Invalid public key"))?, + )) + } +} + +/// Message A is sent from the sender to the receiver containing an Original PSBT payload +#[cfg(feature = "send")] +pub fn encrypt_message_a( + body: Vec, + reply_pk: &HpkePublicKey, + receiver_pk: &HpkePublicKey, +) -> Result, HpkeError> { + let (encapsulated_key, mut encryption_context) = + hpke::setup_sender::( + &OpModeS::Base, + &receiver_pk.0, + INFO_A, + &mut OsRng, + )?; + let mut body = body; + pad_plaintext(&mut body, PADDED_PLAINTEXT_A_LENGTH)?; + let mut plaintext = reply_pk.to_bytes().to_vec(); + plaintext.extend(body); + let ciphertext = encryption_context.seal(&plaintext, &[])?; + let mut message_a = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec(); + message_a.extend(&ciphertext); + Ok(message_a.to_vec()) +} + +#[cfg(feature = "receive")] +pub fn decrypt_message_a( + message_a: &[u8], + receiver_sk: HpkeSecretKey, +) -> Result<(Vec, HpkePublicKey), HpkeError> { + use std::io::{Cursor, Read}; + + let mut cursor = Cursor::new(message_a); + + let mut enc_bytes = [0u8; ELLSWIFT_ENCODING_SIZE]; + cursor.read_exact(&mut enc_bytes).map_err(|_| HpkeError::PayloadTooShort)?; + let enc = encapped_key_from_ellswift_bytes(&enc_bytes)?; + + let mut decryption_ctx = hpke::setup_receiver::< + ChaCha20Poly1305, + HkdfSha256, + SecpK256HkdfSha256, + >(&OpModeR::Base, &receiver_sk.0, &enc, INFO_A)?; + + let mut ciphertext = Vec::new(); + cursor.read_to_end(&mut ciphertext).map_err(|_| HpkeError::PayloadTooShort)?; + let plaintext = decryption_ctx.open(&ciphertext, &[])?; + + let reply_pk_bytes = &plaintext[..UNCOMPRESSED_PUBLIC_KEY_SIZE]; + let reply_pk = HpkePublicKey(PublicKey::from_bytes(reply_pk_bytes)?); + + let body = &plaintext[UNCOMPRESSED_PUBLIC_KEY_SIZE..]; + + Ok((body.to_vec(), reply_pk)) +} + +/// Message B is sent from the receiver to the sender containing a Payjoin PSBT payload or an error +#[cfg(feature = "receive")] +pub fn encrypt_message_b( + mut plaintext: Vec, + receiver_keypair: &HpkeKeyPair, + sender_pk: &HpkePublicKey, +) -> Result, HpkeError> { + let (encapsulated_key, mut encryption_context) = + hpke::setup_sender::( + &OpModeS::Auth(( + receiver_keypair.secret_key().0.clone(), + receiver_keypair.public_key().0.clone(), + )), + &sender_pk.0, + INFO_B, + &mut OsRng, + )?; + let plaintext: &[u8] = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_B_LENGTH)?; + let ciphertext = encryption_context.seal(plaintext, &[])?; + let mut message_b = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec(); + message_b.extend(&ciphertext); + Ok(message_b.to_vec()) +} + +#[cfg(feature = "send")] +pub fn decrypt_message_b( + message_b: &[u8], + receiver_pk: HpkePublicKey, + sender_sk: HpkeSecretKey, +) -> Result, HpkeError> { + let enc = message_b.get(..ELLSWIFT_ENCODING_SIZE).ok_or(HpkeError::PayloadTooShort)?; + let enc = encapped_key_from_ellswift_bytes(enc)?; + let mut decryption_ctx = hpke::setup_receiver::< + ChaCha20Poly1305, + HkdfSha256, + SecpK256HkdfSha256, + >(&OpModeR::Auth(receiver_pk.0), &sender_sk.0, &enc, INFO_B)?; + let plaintext = decryption_ctx + .open(message_b.get(ELLSWIFT_ENCODING_SIZE..).ok_or(HpkeError::PayloadTooShort)?, &[])?; + Ok(plaintext) +} + +fn pad_plaintext(msg: &mut Vec, padded_length: usize) -> Result<&[u8], HpkeError> { + if msg.len() > padded_length { + return Err(HpkeError::PayloadTooLarge { actual: msg.len(), max: padded_length }); + } + msg.resize(padded_length, 0); + Ok(msg) +} + +/// Error from de/encrypting a v2 Hybrid Public Key Encryption payload. +#[derive(Debug, PartialEq)] +pub enum HpkeError { + Secp256k1(bitcoin::secp256k1::Error), + Hpke(hpke::HpkeError), + InvalidKeyLength, + PayloadTooLarge { actual: usize, max: usize }, + PayloadTooShort, +} + +impl From for HpkeError { + fn from(value: hpke::HpkeError) -> Self { Self::Hpke(value) } +} + +impl From for HpkeError { + fn from(value: bitcoin::secp256k1::Error) -> Self { Self::Secp256k1(value) } +} + +impl fmt::Display for HpkeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use HpkeError::*; + + match &self { + Hpke(e) => e.fmt(f), + InvalidKeyLength => write!(f, "Invalid Length"), + PayloadTooLarge { actual, max } => { + write!( + f, + "Plaintext too large, max size is {} bytes, actual size is {} bytes", + max, actual + ) + } + PayloadTooShort => write!(f, "Payload too small"), + Secp256k1(e) => e.fmt(f), + } + } +} + +impl error::Error for HpkeError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + use HpkeError::*; + + match &self { + Hpke(e) => Some(e), + PayloadTooLarge { .. } => None, + InvalidKeyLength | PayloadTooShort => None, + Secp256k1(e) => Some(e), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn message_a_round_trip() { + let mut plaintext = "foo".as_bytes().to_vec(); + + let reply_keypair = HpkeKeyPair::gen_keypair(); + let receiver_keypair = HpkeKeyPair::gen_keypair(); + + let message_a = encrypt_message_a( + plaintext.clone(), + reply_keypair.public_key(), + receiver_keypair.public_key(), + ) + .expect("encryption should work"); + assert_eq!(message_a.len(), PADDED_MESSAGE_BYTES); + + let decrypted = decrypt_message_a(&message_a, receiver_keypair.secret_key().clone()) + .expect("decryption should work"); + + assert_eq!(decrypted.0.len(), PADDED_PLAINTEXT_A_LENGTH); + + // decrypted plaintext is padded, so pad the expected plaintext + plaintext.resize(PADDED_PLAINTEXT_A_LENGTH, 0); + assert_eq!(decrypted, (plaintext.to_vec(), reply_keypair.public_key().clone())); + + // ensure full plaintext round trips + plaintext[PADDED_PLAINTEXT_A_LENGTH - 1] = 42; + let message_a = encrypt_message_a( + plaintext.clone(), + reply_keypair.public_key(), + receiver_keypair.public_key(), + ) + .expect("encryption should work"); + + let decrypted = decrypt_message_a(&message_a, receiver_keypair.secret_key().clone()) + .expect("decryption should work"); + + assert_eq!(decrypted.0.len(), plaintext.len()); + assert_eq!(decrypted, (plaintext.to_vec(), reply_keypair.public_key().clone())); + + let unrelated_keypair = HpkeKeyPair::gen_keypair(); + assert_eq!( + decrypt_message_a(&message_a, unrelated_keypair.secret_key().clone()), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + + let mut corrupted_message_a = message_a.clone(); + corrupted_message_a[3] ^= 1; // corrupt dhkem + assert_eq!( + decrypt_message_a(&corrupted_message_a, receiver_keypair.secret_key().clone()), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + let mut corrupted_message_a = message_a.clone(); + corrupted_message_a[PADDED_MESSAGE_BYTES - 3] ^= 1; // corrupt aead ciphertext + assert_eq!( + decrypt_message_a(&corrupted_message_a, receiver_keypair.secret_key().clone()), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + + plaintext.resize(PADDED_PLAINTEXT_A_LENGTH + 1, 0); + assert_eq!( + encrypt_message_a( + plaintext.clone(), + reply_keypair.public_key(), + receiver_keypair.public_key(), + ), + Err(HpkeError::PayloadTooLarge { + actual: PADDED_PLAINTEXT_A_LENGTH + 1, + max: PADDED_PLAINTEXT_A_LENGTH, + }) + ); + } + + #[test] + fn message_b_round_trip() { + let mut plaintext = "foo".as_bytes().to_vec(); + + let reply_keypair = HpkeKeyPair::gen_keypair(); + let receiver_keypair = HpkeKeyPair::gen_keypair(); + + let message_b = + encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()) + .expect("encryption should work"); + + assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES); + + let decrypted = decrypt_message_b( + &message_b, + receiver_keypair.public_key().clone(), + reply_keypair.secret_key().clone(), + ) + .expect("decryption should work"); + + assert_eq!(decrypted.len(), PADDED_PLAINTEXT_B_LENGTH); + // decrypted plaintext is padded, so pad the expected plaintext + plaintext.resize(PADDED_PLAINTEXT_B_LENGTH, 0); + assert_eq!(decrypted, plaintext.to_vec()); + + plaintext[PADDED_PLAINTEXT_B_LENGTH - 1] = 42; + let message_b = + encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()) + .expect("encryption should work"); + + assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES); + + let decrypted = decrypt_message_b( + &message_b, + receiver_keypair.public_key().clone(), + reply_keypair.secret_key().clone(), + ) + .expect("decryption should work"); + assert_eq!(decrypted.len(), plaintext.len()); + assert_eq!(decrypted, plaintext.to_vec()); + + let unrelated_keypair = HpkeKeyPair::gen_keypair(); + assert_eq!( + decrypt_message_b( + &message_b, + receiver_keypair.public_key().clone(), + unrelated_keypair.secret_key().clone() // wrong decryption key + ), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + assert_eq!( + decrypt_message_b( + &message_b, + unrelated_keypair.public_key().clone(), // wrong auth key + reply_keypair.secret_key().clone() + ), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + + let mut corrupted_message_b = message_b.clone(); + corrupted_message_b[3] ^= 1; // corrupt dhkem + assert_eq!( + decrypt_message_b( + &corrupted_message_b, + receiver_keypair.public_key().clone(), + reply_keypair.secret_key().clone() + ), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + let mut corrupted_message_b = message_b.clone(); + corrupted_message_b[PADDED_MESSAGE_BYTES - 3] ^= 1; // corrupt aead ciphertext + assert_eq!( + decrypt_message_b( + &corrupted_message_b, + receiver_keypair.public_key().clone(), + reply_keypair.secret_key().clone() + ), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + + plaintext.resize(PADDED_PLAINTEXT_B_LENGTH + 1, 0); + assert_eq!( + encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()), + Err(HpkeError::PayloadTooLarge { + actual: PADDED_PLAINTEXT_B_LENGTH + 1, + max: PADDED_PLAINTEXT_B_LENGTH + }) + ); + } +} diff --git a/payjoin/src/lib.rs b/payjoin/src/lib.rs index b1bb0345..8ebf3762 100644 --- a/payjoin/src/lib.rs +++ b/payjoin/src/lib.rs @@ -28,9 +28,11 @@ pub use crate::receive::Error; pub mod send; #[cfg(feature = "v2")] -pub(crate) mod v2; +pub(crate) mod hpke; #[cfg(feature = "v2")] -pub use v2::OhttpKeys; +pub(crate) mod ohttp; +#[cfg(feature = "v2")] +pub use crate::ohttp::OhttpKeys; #[cfg(feature = "io")] pub mod io; diff --git a/payjoin/src/ohttp.rs b/payjoin/src/ohttp.rs new file mode 100644 index 00000000..9bd7d147 --- /dev/null +++ b/payjoin/src/ohttp.rs @@ -0,0 +1,255 @@ +use std::ops::{Deref, DerefMut}; +use std::{error, fmt}; + +use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; +use bitcoin::base64::Engine; + +pub fn ohttp_encapsulate( + ohttp_keys: &mut ohttp::KeyConfig, + method: &str, + target_resource: &str, + body: Option<&[u8]>, +) -> Result<(Vec, ohttp::ClientResponse), OhttpEncapsulationError> { + use std::fmt::Write; + + let ctx = ohttp::ClientRequest::from_config(ohttp_keys)?; + let url = url::Url::parse(target_resource)?; + let authority_bytes = url.host().map_or_else(Vec::new, |host| { + let mut authority = host.to_string(); + if let Some(port) = url.port() { + write!(authority, ":{}", port).unwrap(); + } + authority.into_bytes() + }); + let mut bhttp_message = bhttp::Message::request( + method.as_bytes().to_vec(), + url.scheme().as_bytes().to_vec(), + authority_bytes, + url.path().as_bytes().to_vec(), + ); + // None of our messages include headers, so we don't add them + if let Some(body) = body { + bhttp_message.write_content(body); + } + let mut bhttp_req = Vec::new(); + let _ = bhttp_message.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_req); + let encapsulated = ctx.encapsulate(&bhttp_req)?; + Ok(encapsulated) +} + +/// decapsulate ohttp, bhttp response and return http response body and status code +pub fn ohttp_decapsulate( + res_ctx: ohttp::ClientResponse, + ohttp_body: &[u8], +) -> Result>, OhttpEncapsulationError> { + let bhttp_body = res_ctx.decapsulate(ohttp_body)?; + let mut r = std::io::Cursor::new(bhttp_body); + let m: bhttp::Message = bhttp::Message::read_bhttp(&mut r)?; + let mut builder = http::Response::builder(); + for field in m.header().iter() { + builder = builder.header(field.name(), field.value()); + } + builder + .status(m.control().status().unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR.into())) + .body(m.content().to_vec()) + .map_err(OhttpEncapsulationError::Http) +} + +/// Error from de/encapsulating an Oblivious HTTP request or response. +#[derive(Debug)] +pub enum OhttpEncapsulationError { + Http(http::Error), + Ohttp(ohttp::Error), + Bhttp(bhttp::Error), + ParseUrl(url::ParseError), +} + +impl From for OhttpEncapsulationError { + fn from(value: http::Error) -> Self { Self::Http(value) } +} + +impl From for OhttpEncapsulationError { + fn from(value: ohttp::Error) -> Self { Self::Ohttp(value) } +} + +impl From for OhttpEncapsulationError { + fn from(value: bhttp::Error) -> Self { Self::Bhttp(value) } +} + +impl From for OhttpEncapsulationError { + fn from(value: url::ParseError) -> Self { Self::ParseUrl(value) } +} + +impl fmt::Display for OhttpEncapsulationError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use OhttpEncapsulationError::*; + + match &self { + Http(e) => e.fmt(f), + Ohttp(e) => e.fmt(f), + Bhttp(e) => e.fmt(f), + ParseUrl(e) => e.fmt(f), + } + } +} + +impl error::Error for OhttpEncapsulationError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + use OhttpEncapsulationError::*; + + match &self { + Http(e) => Some(e), + Ohttp(e) => Some(e), + Bhttp(e) => Some(e), + ParseUrl(e) => Some(e), + } + } +} + +#[derive(Debug, Clone)] +pub struct OhttpKeys(pub ohttp::KeyConfig); + +impl OhttpKeys { + /// Decode an OHTTP KeyConfig + pub fn decode(bytes: &[u8]) -> Result { + ohttp::KeyConfig::decode(bytes).map(Self) + } +} + +const KEM_ID: &[u8] = b"\x00\x16"; // DHKEM(secp256k1, HKDF-SHA256) +const SYMMETRIC_LEN: &[u8] = b"\x00\x04"; // 4 bytes +const SYMMETRIC_KDF_AEAD: &[u8] = b"\x00\x01\x00\x03"; // KDF(HKDF-SHA256), AEAD(ChaCha20Poly1305) + +impl fmt::Display for OhttpKeys { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let bytes = self.encode().map_err(|_| fmt::Error)?; + let key_id = bytes[0]; + let pubkey = &bytes[3..68]; + + let compressed_pubkey = + bitcoin::secp256k1::PublicKey::from_slice(pubkey).map_err(|_| fmt::Error)?.serialize(); + + let mut buf = vec![key_id]; + buf.extend_from_slice(&compressed_pubkey); + + let encoded = BASE64_URL_SAFE_NO_PAD.encode(buf); + write!(f, "{}", encoded) + } +} + +impl std::str::FromStr for OhttpKeys { + type Err = ParseOhttpKeysError; + + /// Parses a base64URL-encoded string into OhttpKeys. + /// The string format is: key_id || compressed_public_key + fn from_str(s: &str) -> Result { + let bytes = BASE64_URL_SAFE_NO_PAD.decode(s).map_err(ParseOhttpKeysError::DecodeBase64)?; + + let key_id = *bytes.first().ok_or(ParseOhttpKeysError::InvalidFormat)?; + let compressed_pk = bytes.get(1..34).ok_or(ParseOhttpKeysError::InvalidFormat)?; + + let pubkey = bitcoin::secp256k1::PublicKey::from_slice(compressed_pk) + .map_err(|_| ParseOhttpKeysError::InvalidPublicKey)?; + + let mut buf = vec![key_id]; + buf.extend_from_slice(KEM_ID); + buf.extend_from_slice(&pubkey.serialize_uncompressed()); + buf.extend_from_slice(SYMMETRIC_LEN); + buf.extend_from_slice(SYMMETRIC_KDF_AEAD); + + ohttp::KeyConfig::decode(&buf).map(Self).map_err(ParseOhttpKeysError::DecodeKeyConfig) + } +} + +impl PartialEq for OhttpKeys { + fn eq(&self, other: &Self) -> bool { + match (self.encode(), other.encode()) { + (Ok(self_encoded), Ok(other_encoded)) => self_encoded == other_encoded, + // If OhttpKeys::encode(&self) is Err, return false + _ => false, + } + } +} + +impl Eq for OhttpKeys {} + +impl Deref for OhttpKeys { + type Target = ohttp::KeyConfig; + + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl DerefMut for OhttpKeys { + fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } +} + +impl<'de> serde::Deserialize<'de> for OhttpKeys { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let bytes = Vec::::deserialize(deserializer)?; + OhttpKeys::decode(&bytes).map_err(serde::de::Error::custom) + } +} + +impl serde::Serialize for OhttpKeys { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let bytes = self.encode().map_err(serde::ser::Error::custom)?; + bytes.serialize(serializer) + } +} + +#[derive(Debug)] +pub enum ParseOhttpKeysError { + InvalidFormat, + InvalidPublicKey, + DecodeBase64(bitcoin::base64::DecodeError), + DecodeKeyConfig(ohttp::Error), +} + +impl std::fmt::Display for ParseOhttpKeysError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ParseOhttpKeysError::InvalidFormat => write!(f, "Invalid format"), + ParseOhttpKeysError::InvalidPublicKey => write!(f, "Invalid public key"), + ParseOhttpKeysError::DecodeBase64(e) => write!(f, "Failed to decode base64: {}", e), + ParseOhttpKeysError::DecodeKeyConfig(e) => + write!(f, "Failed to decode KeyConfig: {}", e), + } + } +} + +impl std::error::Error for ParseOhttpKeysError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + ParseOhttpKeysError::DecodeBase64(e) => Some(e), + ParseOhttpKeysError::DecodeKeyConfig(e) => Some(e), + ParseOhttpKeysError::InvalidFormat | ParseOhttpKeysError::InvalidPublicKey => None, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_ohttp_keys_roundtrip() { + use std::str::FromStr; + + use ohttp::hpke::{Aead, Kdf, Kem}; + use ohttp::{KeyId, SymmetricSuite}; + const KEY_ID: KeyId = 1; + const KEM: Kem = Kem::K256Sha256; + const SYMMETRIC: &[SymmetricSuite] = + &[ohttp::SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)]; + let keys = OhttpKeys(ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap()); + let serialized = &keys.to_string(); + let deserialized = OhttpKeys::from_str(serialized).unwrap(); + assert_eq!(keys.encode().unwrap(), deserialized.encode().unwrap()); + } +} diff --git a/payjoin/src/receive/error.rs b/payjoin/src/receive/error.rs index 82bcbb77..a8479db6 100644 --- a/payjoin/src/receive/error.rs +++ b/payjoin/src/receive/error.rs @@ -36,13 +36,13 @@ impl From for Error { } #[cfg(feature = "v2")] -impl From for Error { - fn from(e: crate::v2::HpkeError) -> Self { Error::Server(Box::new(e)) } +impl From for Error { + fn from(e: crate::hpke::HpkeError) -> Self { Error::Server(Box::new(e)) } } #[cfg(feature = "v2")] -impl From for Error { - fn from(e: crate::v2::OhttpEncapsulationError) -> Self { Error::Server(Box::new(e)) } +impl From for Error { + fn from(e: crate::ohttp::OhttpEncapsulationError) -> Self { Error::Server(Box::new(e)) } } /// Error that may occur when the request from sender is malformed. diff --git a/payjoin/src/receive/v2/error.rs b/payjoin/src/receive/v2/error.rs index c6d7daf2..1a934dd3 100644 --- a/payjoin/src/receive/v2/error.rs +++ b/payjoin/src/receive/v2/error.rs @@ -1,7 +1,7 @@ use core::fmt; use std::error; -use crate::v2::OhttpEncapsulationError; +use crate::ohttp::OhttpEncapsulationError; #[derive(Debug)] pub struct SessionError(InternalSessionError); @@ -11,14 +11,14 @@ pub(crate) enum InternalSessionError { /// The session has expired Expired(std::time::SystemTime), /// OHTTP Encapsulation failed - OhttpEncapsulationError(OhttpEncapsulationError), + OhttpEncapsulation(OhttpEncapsulationError), } impl fmt::Display for SessionError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match &self.0 { InternalSessionError::Expired(expiry) => write!(f, "Session expired at {:?}", expiry), - InternalSessionError::OhttpEncapsulationError(e) => + InternalSessionError::OhttpEncapsulation(e) => write!(f, "OHTTP Encapsulation Error: {}", e), } } @@ -28,7 +28,7 @@ impl error::Error for SessionError { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match &self.0 { InternalSessionError::Expired(_) => None, - InternalSessionError::OhttpEncapsulationError(e) => Some(e), + InternalSessionError::OhttpEncapsulation(e) => Some(e), } } } @@ -39,6 +39,6 @@ impl From for SessionError { impl From for SessionError { fn from(e: OhttpEncapsulationError) -> Self { - SessionError(InternalSessionError::OhttpEncapsulationError(e)) + SessionError(InternalSessionError::OhttpEncapsulation(e)) } } diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 2c250bfa..cb19ffbd 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -14,10 +14,11 @@ use super::{ Error, InputContributionError, InternalRequestError, OutputSubstitutionError, RequestError, SelectionError, }; +use crate::hpke::{decrypt_message_a, encrypt_message_b, HpkeKeyPair, HpkePublicKey}; +use crate::ohttp::{ohttp_decapsulate, ohttp_encapsulate, OhttpEncapsulationError, OhttpKeys}; use crate::psbt::PsbtExt; use crate::receive::optional_parameters::Params; -use crate::v2::{HpkeKeyPair, HpkePublicKey, OhttpEncapsulationError}; -use crate::{OhttpKeys, PjUriBuilder, Request}; +use crate::{PjUriBuilder, Request}; pub(crate) mod error; @@ -45,16 +46,19 @@ where Ok(address.assume_checked()) } -/// Initializes a new payjoin session, including necessary context -/// information for communication and cryptographic operations. -#[derive(Debug, Clone)] -pub struct SessionInitializer { +fn subdir_path_from_pubkey(pubkey: &HpkePublicKey) -> String { + BASE64_URL_SAFE_NO_PAD.encode(pubkey.to_compressed_bytes()) +} + +/// A payjoin V2 receiver, allowing for polled requests to the +/// payjoin directory and response processing. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Receiver { context: SessionContext, } -#[cfg(feature = "v2")] -impl SessionInitializer { - /// Creates a new `SessionInitializer` with the provided parameters. +impl Receiver { + /// Creates a new `Receiver` with the provided parameters. /// /// # Parameters /// - `address`: The Bitcoin address for the payjoin session. @@ -64,7 +68,7 @@ impl SessionInitializer { /// - `expire_after`: The duration after which the session expires. /// /// # Returns - /// A new instance of `SessionInitializer`. + /// A new instance of `Receiver`. /// /// # References /// - [BIP 77: Payjoin Version 2: Serverless Payjoin](https://github.com/bitcoin/bips/pull/1483) @@ -90,62 +94,13 @@ impl SessionInitializer { } } - pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), Error> { - let url = self.context.ohttp_relay.clone(); - let subdirectory = subdir_path_from_pubkey(self.context.s.public_key()); - let (body, ctx) = crate::v2::ohttp_encapsulate( - &mut self.context.ohttp_keys, - "POST", - self.context.directory.as_str(), - Some(subdirectory.as_bytes()), - )?; - let req = Request::new_v2(url, body); - Ok((req, ctx)) - } - - pub fn process_res( - mut self, - mut res: impl std::io::Read, - ctx: ohttp::ClientResponse, - ) -> Result { - let mut buf = Vec::new(); - let _ = res.read_to_end(&mut buf); - let response = crate::v2::ohttp_decapsulate(ctx, &buf)?; - if !response.status().is_success() { - return Err(Error::Server("Enrollment failed, expected success status".into())); - } - log::debug!("Received response headers: {:?}", response.headers()); - let location = response - .headers() - .get("location") - .ok_or(Error::Server("Missing location header".into()))? - .to_str() - .map_err(|e| Error::Server(format!("Invalid location header: {}", e).into()))?; - self.context.subdirectory = - Some(url::Url::parse(location).map_err(|e| Error::Server(e.into()))?); - - Ok(ActiveSession { context: self.context.clone() }) - } -} - -fn subdir_path_from_pubkey(pubkey: &HpkePublicKey) -> String { - BASE64_URL_SAFE_NO_PAD.encode(pubkey.to_compressed_bytes()) -} - -/// An active payjoin V2 session, allowing for polled requests to the -/// payjoin directory and response processing. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct ActiveSession { - context: SessionContext, -} - -impl ActiveSession { + /// Extratct an OHTTP Encapsulated HTTP GET request for the Original PSBT pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), SessionError> { if SystemTime::now() > self.context.expiry { return Err(InternalSessionError::Expired(self.context.expiry).into()); } let (body, ohttp_ctx) = - self.fallback_req_body().map_err(InternalSessionError::OhttpEncapsulationError)?; + self.fallback_req_body().map_err(InternalSessionError::OhttpEncapsulation)?; let url = self.context.ohttp_relay.clone(); let req = Request::new_v2(url, body); Ok((req, ohttp_ctx)) @@ -161,7 +116,7 @@ impl ActiveSession { let mut buf = Vec::new(); let _ = body.read_to_end(&mut buf); log::trace!("decapsulating directory response"); - let response = crate::v2::ohttp_decapsulate(context, &buf)?; + let response = ohttp_decapsulate(context, &buf)?; if response.body().is_empty() { log::debug!("response is empty"); return Ok(None); @@ -178,12 +133,7 @@ impl ActiveSession { &mut self, ) -> Result<(Vec, ohttp::ClientResponse), OhttpEncapsulationError> { let fallback_target = self.pj_url(); - crate::v2::ohttp_encapsulate( - &mut self.context.ohttp_keys, - "GET", - fallback_target.as_str(), - None, - ) + ohttp_encapsulate(&mut self.context.ohttp_keys, "GET", fallback_target.as_str(), None) } fn extract_proposal_from_v1(&mut self, response: String) -> Result { @@ -191,8 +141,7 @@ impl ActiveSession { } fn extract_proposal_from_v2(&mut self, response: Vec) -> Result { - let (payload_bytes, e) = - crate::v2::decrypt_message_a(&response, self.context.s.secret_key().clone())?; + let (payload_bytes, e) = decrypt_message_a(&response, self.context.s.secret_key().clone())?; self.context.e = Some(e); let payload = String::from_utf8(payload_bytes).map_err(InternalRequestError::Utf8)?; Ok(self.unchecked_from_payload(payload)?) @@ -507,22 +456,34 @@ impl PayjoinProposal { #[cfg(feature = "v2")] pub fn extract_v2_req(&mut self) -> Result<(Request, ohttp::ClientResponse), Error> { - let body = match &self.context.e { - Some(e) => { - let payjoin_bytes = self.inner.payjoin_psbt.serialize(); - log::debug!("THERE IS AN e: {:?}", e); - crate::v2::encrypt_message_b(payjoin_bytes, &self.context.s, e) - } - None => Ok(self.extract_v1_req().as_bytes().to_vec()), - }?; - let subdir_path = subdir_path_from_pubkey(self.context.s.public_key()); - let post_payjoin_target = - self.context.directory.join(&subdir_path).map_err(|e| Error::Server(e.into()))?; - log::debug!("Payjoin post target: {}", post_payjoin_target.as_str()); - let (body, ctx) = crate::v2::ohttp_encapsulate( + let target_resource: Url; + let body: Vec; + let method: &str; + + if let Some(e) = &self.context.e { + // Prepare v2 payload + let payjoin_bytes = self.inner.payjoin_psbt.serialize(); + let sender_subdir = subdir_path_from_pubkey(e); + target_resource = + self.context.directory.join(&sender_subdir).map_err(|e| Error::Server(e.into()))?; + body = encrypt_message_b(payjoin_bytes, &self.context.s, e)?; + method = "POST"; + } else { + // Prepare v2 wrapped and backwards-compatible v1 payload + body = self.extract_v1_req().as_bytes().to_vec(); + let receiver_subdir = subdir_path_from_pubkey(self.context.s.public_key()); + target_resource = self + .context + .directory + .join(&receiver_subdir) + .map_err(|e| Error::Server(e.into()))?; + method = "PUT"; + } + log::debug!("Payjoin PSBT target: {}", target_resource.as_str()); + let (body, ctx) = ohttp_encapsulate( &mut self.context.ohttp_keys, - "PUT", - post_payjoin_target.as_str(), + method, + target_resource.as_str(), Some(&body), )?; let url = self.context.ohttp_relay.clone(); @@ -543,7 +504,7 @@ impl PayjoinProposal { res: Vec, ohttp_context: ohttp::ClientResponse, ) -> Result<(), Error> { - let res = crate::v2::ohttp_decapsulate(ohttp_context, &res)?; + let res = ohttp_decapsulate(ohttp_context, &res)?; if res.status().is_success() { Ok(()) } else { @@ -561,7 +522,7 @@ mod test { #[test] #[cfg(feature = "v2")] - fn active_session_ser_de_roundtrip() { + fn receiver_ser_de_roundtrip() { use ohttp::hpke::{Aead, Kdf, Kem}; use ohttp::{KeyId, SymmetricSuite}; const KEY_ID: KeyId = 1; @@ -569,7 +530,7 @@ mod test { const SYMMETRIC: &[SymmetricSuite] = &[ohttp::SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)]; - let session = ActiveSession { + let session = Receiver { context: SessionContext { address: Address::from_str("tb1q6d3a2w975yny0asuvd9a67ner4nks58ff0q8g4") .unwrap() @@ -586,7 +547,7 @@ mod test { }, }; let serialized = serde_json::to_string(&session).unwrap(); - let deserialized: ActiveSession = serde_json::from_str(&serialized).unwrap(); + let deserialized: Receiver = serde_json::from_str(&serialized).unwrap(); assert_eq!(session, deserialized); } } diff --git a/payjoin/src/send/error.rs b/payjoin/src/send/error.rs index f453d866..94c13737 100644 --- a/payjoin/src/send/error.rs +++ b/payjoin/src/send/error.rs @@ -58,9 +58,9 @@ pub(crate) enum InternalValidationError { FeeRateBelowMinimum, Psbt(bitcoin::psbt::Error), #[cfg(feature = "v2")] - Hpke(crate::v2::HpkeError), + Hpke(crate::hpke::HpkeError), #[cfg(feature = "v2")] - OhttpEncapsulation(crate::v2::OhttpEncapsulationError), + OhttpEncapsulation(crate::ohttp::OhttpEncapsulationError), #[cfg(feature = "v2")] UnexpectedStatusCode, } @@ -190,9 +190,9 @@ pub(crate) enum InternalCreateRequestError { AddressType(crate::psbt::AddressTypeError), InputWeight(crate::psbt::InputWeightError), #[cfg(feature = "v2")] - Hpke(crate::v2::HpkeError), + Hpke(crate::hpke::HpkeError), #[cfg(feature = "v2")] - OhttpEncapsulation(crate::v2::OhttpEncapsulationError), + OhttpEncapsulation(crate::ohttp::OhttpEncapsulationError), #[cfg(feature = "v2")] ParseSubdirectory(ParseSubdirectoryError), #[cfg(feature = "v2")] @@ -289,7 +289,7 @@ impl From for CreateRequestError { pub(crate) enum ParseSubdirectoryError { MissingSubdirectory, SubdirectoryNotBase64(bitcoin::base64::DecodeError), - SubdirectoryInvalidPubkey(crate::v2::HpkeError), + SubdirectoryInvalidPubkey(crate::hpke::HpkeError), } #[cfg(feature = "v2")] diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 774104df..f9db3547 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -9,11 +9,10 @@ //! 2. Construct URI request parameters, a finalized “Original PSBT” paying .amount to .address //! 3. (optional) Spawn a thread or async task that will broadcast the original PSBT fallback after //! delay (e.g. 1 minute) unless canceled -//! 4. Construct the request using [`RequestBuilder`] with the PSBT and payjoin uri -//! 5. Send the request and receive response -//! 6. Process the response with [`ContextV1::process_response`] -//! 7. Sign and finalize the Payjoin Proposal PSBT -//! 8. Broadcast the Payjoin Transaction (and cancel the optional fallback broadcast) +//! 4. Construct the [`Sender`] using [`SenderBuilder`] with the PSBT and payjoin uri +//! 5. Send the request(s) and receive response(s) by following on the extracted [`Context`] +//! 6. Sign and finalize the Payjoin Proposal PSBT +//! 7. Broadcast the Payjoin Transaction (and cancel the optional fallback broadcast) //! //! This crate is runtime-agnostic. Data persistence, chain interactions, and networking may be //! provided by custom implementations or copy the reference @@ -24,6 +23,8 @@ use std::str::FromStr; +#[cfg(feature = "v2")] +use bitcoin::base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; use bitcoin::psbt::Psbt; use bitcoin::{Amount, FeeRate, Script, ScriptBuf, TxOut, Weight}; pub use error::{CreateRequestError, ResponseError, ValidationError}; @@ -32,10 +33,12 @@ pub(crate) use error::{InternalCreateRequestError, InternalValidationError}; use serde::{Deserialize, Serialize}; use url::Url; +#[cfg(feature = "v2")] +use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeKeyPair, HpkePublicKey}; +#[cfg(feature = "v2")] +use crate::ohttp::{ohttp_decapsulate, ohttp_encapsulate}; use crate::psbt::PsbtExt; use crate::request::Request; -#[cfg(feature = "v2")] -use crate::v2::{HpkePublicKey, HpkeSecretKey}; use crate::PjUri; // See usize casts @@ -47,7 +50,7 @@ mod error; type InternalResult = Result; #[derive(Clone)] -pub struct RequestBuilder<'a> { +pub struct SenderBuilder<'a> { psbt: Psbt, uri: PjUri<'a>, disable_output_substitution: bool, @@ -61,7 +64,7 @@ pub struct RequestBuilder<'a> { min_fee_rate: FeeRate, } -impl<'a> RequestBuilder<'a> { +impl<'a> SenderBuilder<'a> { /// Prepare an HTTP request and request context to process the response /// /// An HTTP client will own the Request data while Context sticks around so @@ -96,10 +99,7 @@ impl<'a> RequestBuilder<'a> { // The minfeerate parameter is set if the contribution is available in change. // // This method fails if no recommendation can be made or if the PSBT is malformed. - pub fn build_recommended( - self, - min_fee_rate: FeeRate, - ) -> Result { + pub fn build_recommended(self, min_fee_rate: FeeRate) -> Result { // TODO support optional batched payout scripts. This would require a change to // build() which now checks for a single payee. let mut payout_scripts = std::iter::once(self.uri.address.script_pubkey()); @@ -177,7 +177,7 @@ impl<'a> RequestBuilder<'a> { change_index: Option, min_fee_rate: FeeRate, clamp_fee_contribution: bool, - ) -> Result { + ) -> Result { self.fee_contribution = Some((max_fee_contribution, change_index)); self.clamp_fee_contribution = clamp_fee_contribution; self.min_fee_rate = min_fee_rate; @@ -191,7 +191,7 @@ impl<'a> RequestBuilder<'a> { pub fn build_non_incentivizing( mut self, min_fee_rate: FeeRate, - ) -> Result { + ) -> Result { // since this is a builder, these should already be cleared // but we'll reset them to be sure self.fee_contribution = None; @@ -200,7 +200,7 @@ impl<'a> RequestBuilder<'a> { self.build() } - fn build(self) -> Result { + fn build(self) -> Result { let mut psbt = self.psbt.validate().map_err(InternalCreateRequestError::InconsistentOriginalPsbt)?; psbt.validate_input_utxos(true) @@ -219,35 +219,31 @@ impl<'a> RequestBuilder<'a> { )?; clear_unneeded_fields(&mut psbt); - Ok(RequestContext { + Ok(Sender { psbt, endpoint, disable_output_substitution, fee_contribution, payee, min_fee_rate: self.min_fee_rate, - #[cfg(feature = "v2")] - e: crate::v2::HpkeKeyPair::gen_keypair().secret_key().clone(), }) } } #[derive(Clone, PartialEq, Eq)] #[cfg_attr(feature = "v2", derive(Serialize, Deserialize))] -pub struct RequestContext { +pub struct Sender { psbt: Psbt, endpoint: Url, disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, payee: ScriptBuf, - #[cfg(feature = "v2")] - e: crate::v2::HpkeSecretKey, } -impl RequestContext { +impl Sender { /// Extract serialized V1 Request and Context from a Payjoin Proposal - pub fn extract_v1(&self) -> Result<(Request, ContextV1), CreateRequestError> { + pub fn extract_v1(&self) -> Result<(Request, V1Context), CreateRequestError> { let url = serialize_url( self.endpoint.clone(), self.disable_output_substitution, @@ -259,13 +255,15 @@ impl RequestContext { let body = self.psbt.to_string().as_bytes().to_vec(); Ok(( Request::new_v1(url, body), - ContextV1 { - original_psbt: self.psbt.clone(), - disable_output_substitution: self.disable_output_substitution, - fee_contribution: self.fee_contribution, - payee: self.payee.clone(), - min_fee_rate: self.min_fee_rate, - allow_mixed_input_scripts: false, + V1Context { + psbt_context: PsbtContext { + original_psbt: self.psbt.clone(), + disable_output_substitution: self.disable_output_substitution, + fee_contribution: self.fee_contribution, + payee: self.payee.clone(), + min_fee_rate: self.min_fee_rate, + allow_mixed_input_scripts: false, + }, }, )) } @@ -277,10 +275,10 @@ impl RequestContext { /// /// The `ohttp_relay` merely passes the encrypted payload to the ohttp gateway of the receiver #[cfg(feature = "v2")] - pub fn extract_v2( + pub fn extract_highest_version( &mut self, ohttp_relay: Url, - ) -> Result<(Request, ContextV2), CreateRequestError> { + ) -> Result<(Request, Context), CreateRequestError> { use crate::uri::UrlExt; if let Some(expiry) = self.endpoint.exp() { @@ -290,11 +288,11 @@ impl RequestContext { } match self.extract_rs_pubkey() { - Ok(rs) => self.extract_v2_strict(ohttp_relay, rs), + Ok(rs) => self.extract_v2(ohttp_relay, rs), Err(e) => { log::warn!("Failed to extract `rs` pubkey, falling back to v1: {}", e); let (req, context_v1) = self.extract_v1()?; - Ok((req, ContextV2 { context_v1, rs: None, e: None, ohttp_res: None })) + Ok((req, Context::V1(context_v1))) } } } @@ -304,11 +302,11 @@ impl RequestContext { /// This method requires the `rs` pubkey to be extracted from the endpoint /// and has no fallback to v1. #[cfg(feature = "v2")] - fn extract_v2_strict( + fn extract_v2( &mut self, ohttp_relay: Url, rs: HpkePublicKey, - ) -> Result<(Request, ContextV2), CreateRequestError> { + ) -> Result<(Request, Context), CreateRequestError> { use crate::uri::UrlExt; let url = self.endpoint.clone(); let body = serialize_v2_body( @@ -317,18 +315,23 @@ impl RequestContext { self.fee_contribution, self.min_fee_rate, )?; - let body = crate::v2::encrypt_message_a(body, &self.e.clone(), &rs) - .map_err(InternalCreateRequestError::Hpke)?; + let hpke_ctx = HpkeContext::new(rs); + let body = encrypt_message_a( + body, + &hpke_ctx.reply_pair.public_key().clone(), + &hpke_ctx.receiver.clone(), + ) + .map_err(InternalCreateRequestError::Hpke)?; let mut ohttp = self.endpoint.ohttp().ok_or(InternalCreateRequestError::MissingOhttpConfig)?; - let (body, ohttp_res) = - crate::v2::ohttp_encapsulate(&mut ohttp, "POST", url.as_str(), Some(&body)) - .map_err(InternalCreateRequestError::OhttpEncapsulation)?; + let (body, ohttp_ctx) = ohttp_encapsulate(&mut ohttp, "POST", url.as_str(), Some(&body)) + .map_err(InternalCreateRequestError::OhttpEncapsulation)?; log::debug!("ohttp_relay_url: {:?}", ohttp_relay); Ok(( Request::new_v2(ohttp_relay, body), - ContextV2 { - context_v1: ContextV1 { + Context::V2(V2PostContext { + endpoint: self.endpoint.clone(), + psbt_ctx: PsbtContext { original_psbt: self.psbt.clone(), disable_output_substitution: self.disable_output_substitution, fee_contribution: self.fee_contribution, @@ -336,17 +339,14 @@ impl RequestContext { min_fee_rate: self.min_fee_rate, allow_mixed_input_scripts: true, }, - rs: Some(self.extract_rs_pubkey()?), - e: Some(self.e.clone()), - ohttp_res: Some(ohttp_res), - }, + hpke_ctx, + ohttp_ctx, + }), )) } #[cfg(feature = "v2")] fn extract_rs_pubkey(&self) -> Result { - use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; - use bitcoin::base64::Engine; use error::ParseSubdirectoryError; let subdirectory = self @@ -366,12 +366,122 @@ impl RequestContext { pub fn endpoint(&self) -> &Url { &self.endpoint } } +pub enum Context { + V1(V1Context), + #[cfg(feature = "v2")] + V2(V2PostContext), +} + +pub struct V1Context { + psbt_context: PsbtContext, +} + +impl V1Context { + pub fn process_response( + self, + response: &mut impl std::io::Read, + ) -> Result { + self.psbt_context.process_response(response) + } +} + +#[cfg(feature = "v2")] +pub struct V2PostContext { + endpoint: Url, + psbt_ctx: PsbtContext, + hpke_ctx: HpkeContext, + ohttp_ctx: ohttp::ClientResponse, +} + +#[cfg(feature = "v2")] +impl V2PostContext { + pub fn process_response( + self, + response: &mut impl std::io::Read, + ) -> Result { + let mut res_buf = Vec::new(); + response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; + let response = ohttp_decapsulate(self.ohttp_ctx, &res_buf) + .map_err(InternalValidationError::OhttpEncapsulation)?; + match response.status() { + http::StatusCode::OK => { + // return OK with new Typestate + Ok(V2GetContext { + endpoint: self.endpoint, + psbt_ctx: self.psbt_ctx, + hpke_ctx: self.hpke_ctx, + }) + } + _ => Err(InternalValidationError::UnexpectedStatusCode)?, + } + } +} + +#[cfg(feature = "v2")] +pub struct V2GetContext { + endpoint: Url, + psbt_ctx: PsbtContext, + hpke_ctx: HpkeContext, +} + +#[cfg(feature = "v2")] +impl V2GetContext { + pub fn extract_req( + &self, + ohttp_relay: Url, + ) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> { + use crate::uri::UrlExt; + let mut url = self.endpoint.clone(); + let subdir = BASE64_URL_SAFE_NO_PAD + .encode(self.hpke_ctx.reply_pair.public_key().to_compressed_bytes()); + url.set_path(&subdir); + let body = encrypt_message_a( + Vec::new(), + &self.hpke_ctx.reply_pair.public_key().clone(), + &self.hpke_ctx.receiver.clone(), + ) + .map_err(InternalCreateRequestError::Hpke)?; + let mut ohttp = + self.endpoint.ohttp().ok_or(InternalCreateRequestError::MissingOhttpConfig)?; + let (body, ohttp_ctx) = ohttp_encapsulate(&mut ohttp, "GET", url.as_str(), Some(&body)) + .map_err(InternalCreateRequestError::OhttpEncapsulation)?; + + Ok((Request::new_v2(ohttp_relay, body), ohttp_ctx)) + } + + pub fn process_response( + &self, + response: &mut impl std::io::Read, + ohttp_ctx: ohttp::ClientResponse, + ) -> Result, ResponseError> { + let mut res_buf = Vec::new(); + response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; + let response = ohttp_decapsulate(ohttp_ctx, &res_buf) + .map_err(InternalValidationError::OhttpEncapsulation)?; + let body = match response.status() { + http::StatusCode::OK => response.body().to_vec(), + http::StatusCode::ACCEPTED => return Ok(None), + _ => return Err(InternalValidationError::UnexpectedStatusCode)?, + }; + let psbt = decrypt_message_b( + &body, + self.hpke_ctx.receiver.clone(), + self.hpke_ctx.reply_pair.secret_key().clone(), + ) + .map_err(InternalValidationError::Hpke)?; + + let proposal = Psbt::deserialize(&psbt).map_err(InternalValidationError::Psbt)?; + let processed_proposal = self.psbt_ctx.clone().process_proposal(proposal)?; + Ok(Some(processed_proposal)) + } +} + /// Data required for validation of response. /// /// This type is used to process the response. Get it from [`RequestBuilder`]'s build methods. /// Then you only need to call [`Self::process_response`] on it to continue BIP78 flow. #[derive(Debug, Clone)] -pub struct ContextV1 { +pub struct PsbtContext { original_psbt: Psbt, disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, @@ -381,11 +491,16 @@ pub struct ContextV1 { } #[cfg(feature = "v2")] -pub struct ContextV2 { - context_v1: ContextV1, - rs: Option, - e: Option, - ohttp_res: Option, +struct HpkeContext { + receiver: HpkePublicKey, + reply_pair: HpkeKeyPair, +} + +#[cfg(feature = "v2")] +impl HpkeContext { + pub fn new(receiver: HpkePublicKey) -> Self { + Self { receiver, reply_pair: HpkeKeyPair::gen_keypair() } + } } macro_rules! check_eq { @@ -406,43 +521,7 @@ macro_rules! ensure { }; } -#[cfg(feature = "v2")] -impl ContextV2 { - /// Decodes and validates the response. - /// - /// Call this method with response from receiver to continue BIP-??? flow. - /// A successful response can either be None if the directory has not response yet or Some(Psbt). - /// - /// If the response is some valid PSBT you should sign and broadcast. - #[inline] - pub fn process_response( - self, - response: &mut impl std::io::Read, - ) -> Result, ResponseError> { - match (self.ohttp_res, self.rs, self.e) { - (Some(ohttp_res), Some(rs), Some(e)) => { - let mut res_buf = Vec::new(); - response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; - let response = crate::v2::ohttp_decapsulate(ohttp_res, &res_buf) - .map_err(InternalValidationError::OhttpEncapsulation)?; - let body = match response.status() { - http::StatusCode::OK => response.body().to_vec(), - http::StatusCode::ACCEPTED => return Ok(None), - _ => return Err(InternalValidationError::UnexpectedStatusCode)?, - }; - let psbt = crate::v2::decrypt_message_b(&body, rs, e) - .map_err(InternalValidationError::Hpke)?; - - let proposal = Psbt::deserialize(&psbt).map_err(InternalValidationError::Psbt)?; - let processed_proposal = self.context_v1.process_proposal(proposal)?; - Ok(Some(processed_proposal)) - } - _ => self.context_v1.process_response(response).map(Some), - } - } -} - -impl ContextV1 { +impl PsbtContext { /// Decodes and validates the response. /// /// Call this method with response from receiver to continue BIP78 flow. If the response is @@ -850,11 +929,11 @@ mod test { const ORIGINAL_PSBT: &str = "cHNidP8BAHMCAAAAAY8nutGgJdyYGXWiBEb45Hoe9lWGbkxh/6bNiOJdCDuDAAAAAAD+////AtyVuAUAAAAAF6kUHehJ8GnSdBUOOv6ujXLrWmsJRDCHgIQeAAAAAAAXqRR3QJbbz0hnQ8IvQ0fptGn+votneofTAAAAAAEBIKgb1wUAAAAAF6kU3k4ekGHKWRNbA1rV5tR5kEVDVNCHAQcXFgAUx4pFclNVgo1WWAdN1SYNX8tphTABCGsCRzBEAiB8Q+A6dep+Rz92vhy26lT0AjZn4PRLi8Bf9qoB/CMk0wIgP/Rj2PWZ3gEjUkTlhDRNAQ0gXwTO7t9n+V14pZ6oljUBIQMVmsAaoNWHVMS02LfTSe0e388LNitPa1UQZyOihY+FFgABABYAFEb2Giu6c4KO5YW0pfw3lGp9jMUUAAA="; const PAYJOIN_PROPOSAL: &str = "cHNidP8BAJwCAAAAAo8nutGgJdyYGXWiBEb45Hoe9lWGbkxh/6bNiOJdCDuDAAAAAAD+////jye60aAl3JgZdaIERvjkeh72VYZuTGH/ps2I4l0IO4MBAAAAAP7///8CJpW4BQAAAAAXqRQd6EnwadJ0FQ46/q6NcutaawlEMIcACT0AAAAAABepFHdAltvPSGdDwi9DR+m0af6+i2d6h9MAAAAAAQEgqBvXBQAAAAAXqRTeTh6QYcpZE1sDWtXm1HmQRUNU0IcBBBYAFMeKRXJTVYKNVlgHTdUmDV/LaYUwIgYDFZrAGqDVh1TEtNi300ntHt/PCzYrT2tVEGcjooWPhRYYSFzWUDEAAIABAACAAAAAgAEAAAAAAAAAAAEBIICEHgAAAAAAF6kUyPLL+cphRyyI5GTUazV0hF2R2NWHAQcXFgAUX4BmVeWSTJIEwtUb5TlPS/ntohABCGsCRzBEAiBnu3tA3yWlT0WBClsXXS9j69Bt+waCs9JcjWtNjtv7VgIge2VYAaBeLPDB6HGFlpqOENXMldsJezF9Gs5amvDQRDQBIQJl1jz1tBt8hNx2owTm+4Du4isx0pmdKNMNIjjaMHFfrQABABYAFEb2Giu6c4KO5YW0pfw3lGp9jMUUIgICygvBWB5prpfx61y1HDAwo37kYP3YRJBvAjtunBAur3wYSFzWUDEAAIABAACAAAAAgAEAAAABAAAAAAA="; - fn create_v1_context() -> super::ContextV1 { + fn create_v1_context() -> super::PsbtContext { let original_psbt = Psbt::from_str(ORIGINAL_PSBT).unwrap(); eprintln!("original: {:#?}", original_psbt); let payee = original_psbt.unsigned_tx.output[1].script_pubkey.clone(); - let ctx = super::ContextV1 { + let ctx = super::PsbtContext { original_psbt, disable_output_substitution: false, fee_contribution: Some((bitcoin::Amount::from_sat(182), 0)), @@ -906,20 +985,14 @@ mod test { #[test] #[cfg(feature = "v2")] fn req_ctx_ser_de_roundtrip() { - use hpke::Deserializable; - use super::*; - let req_ctx = RequestContext { + let req_ctx = Sender { psbt: Psbt::from_str(ORIGINAL_PSBT).unwrap(), endpoint: Url::parse("http://localhost:1234").unwrap(), disable_output_substitution: false, fee_contribution: None, min_fee_rate: FeeRate::ZERO, payee: ScriptBuf::from(vec![0x00]), - e: HpkeSecretKey( - ::PrivateKey::from_bytes(&[0x01; 32]) - .unwrap(), - ), }; let serialized = serde_json::to_string(&req_ctx).unwrap(); let deserialized = serde_json::from_str(&serialized).unwrap(); diff --git a/payjoin/src/v2.rs b/payjoin/src/v2.rs deleted file mode 100644 index 6c62f3ef..00000000 --- a/payjoin/src/v2.rs +++ /dev/null @@ -1,524 +0,0 @@ -use std::ops::{Deref, DerefMut}; -use std::{error, fmt}; - -use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; -use bitcoin::base64::Engine; -use bitcoin::key::constants::UNCOMPRESSED_PUBLIC_KEY_SIZE; -use hpke::aead::ChaCha20Poly1305; -use hpke::kdf::HkdfSha256; -use hpke::kem::SecpK256HkdfSha256; -use hpke::rand_core::OsRng; -use hpke::{Deserializable, OpModeR, OpModeS, Serializable}; -use serde::{Deserialize, Serialize}; - -pub const PADDED_MESSAGE_BYTES: usize = 7168; -pub const PADDED_PLAINTEXT_A_LENGTH: usize = - PADDED_MESSAGE_BYTES - UNCOMPRESSED_PUBLIC_KEY_SIZE * 2; -pub const PADDED_PLAINTEXT_B_LENGTH: usize = PADDED_MESSAGE_BYTES - UNCOMPRESSED_PUBLIC_KEY_SIZE; -pub const INFO_A: &[u8] = b"PjV2MsgA"; -pub const INFO_B: &[u8] = b"PjV2MsgB"; - -pub type SecretKey = ::PrivateKey; -pub type PublicKey = ::PublicKey; -pub type EncappedKey = ::EncappedKey; - -fn sk_to_pk(sk: &SecretKey) -> PublicKey { ::sk_to_pk(sk) } - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct HpkeKeyPair(pub HpkeSecretKey, pub HpkePublicKey); - -impl From for (HpkeSecretKey, HpkePublicKey) { - fn from(value: HpkeKeyPair) -> Self { (value.0, value.1) } -} - -impl HpkeKeyPair { - pub fn gen_keypair() -> Self { - let (sk, pk) = ::gen_keypair(&mut OsRng); - Self(HpkeSecretKey(sk), HpkePublicKey(pk)) - } - pub fn secret_key(&self) -> &HpkeSecretKey { &self.0 } - pub fn public_key(&self) -> &HpkePublicKey { &self.1 } -} - -#[derive(Clone, PartialEq, Eq)] -pub struct HpkeSecretKey(pub SecretKey); - -impl Deref for HpkeSecretKey { - type Target = SecretKey; - - fn deref(&self) -> &Self::Target { &self.0 } -} - -impl core::fmt::Debug for HpkeSecretKey { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "SecpHpkeSecretKey({:?})", self.0.to_bytes()) - } -} - -impl serde::Serialize for HpkeSecretKey { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_bytes(&self.0.to_bytes()) - } -} - -impl<'de> serde::Deserialize<'de> for HpkeSecretKey { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let bytes = Vec::::deserialize(deserializer)?; - Ok(HpkeSecretKey( - SecretKey::from_bytes(&bytes) - .map_err(|_| serde::de::Error::custom("Invalid secret key"))?, - )) - } -} - -#[derive(Clone, PartialEq, Eq)] -pub struct HpkePublicKey(pub PublicKey); - -impl HpkePublicKey { - pub fn to_compressed_bytes(&self) -> [u8; 33] { - let compressed_key = bitcoin::secp256k1::PublicKey::from_slice(&self.0.to_bytes()) - .expect("Invalid public key from known valid bytes"); - compressed_key.serialize() - } - - pub fn from_compressed_bytes(bytes: &[u8]) -> Result { - let compressed_key = bitcoin::secp256k1::PublicKey::from_slice(bytes)?; - Ok(HpkePublicKey(PublicKey::from_bytes( - compressed_key.serialize_uncompressed().as_slice(), - )?)) - } -} - -impl Deref for HpkePublicKey { - type Target = PublicKey; - - fn deref(&self) -> &Self::Target { &self.0 } -} - -impl core::fmt::Debug for HpkePublicKey { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "SecpHpkePublicKey({:?})", self.0) - } -} - -impl serde::Serialize for HpkePublicKey { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_bytes(&self.0.to_bytes()) - } -} - -impl<'de> serde::Deserialize<'de> for HpkePublicKey { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let bytes = Vec::::deserialize(deserializer)?; - Ok(HpkePublicKey( - PublicKey::from_bytes(&bytes) - .map_err(|_| serde::de::Error::custom("Invalid public key"))?, - )) - } -} - -/// Message A is sent from the sender to the receiver containing an Original PSBT payload -#[cfg(feature = "send")] -pub fn encrypt_message_a( - mut plaintext: Vec, - sender_sk: &HpkeSecretKey, - receiver_pk: &HpkePublicKey, -) -> Result, HpkeError> { - let pk = sk_to_pk(&sender_sk.0); - let (encapsulated_key, mut encryption_context) = - hpke::setup_sender::( - &OpModeS::Auth((sender_sk.0.clone(), pk.clone())), - &receiver_pk.0, - INFO_A, - &mut OsRng, - )?; - let aad = pk.to_bytes().to_vec(); - let plaintext = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_A_LENGTH)?; - let ciphertext = encryption_context.seal(plaintext, &aad)?; - let mut message_a = encapsulated_key.to_bytes().to_vec(); - message_a.extend(&aad); - message_a.extend(&ciphertext); - Ok(message_a.to_vec()) -} - -#[cfg(feature = "receive")] -pub fn decrypt_message_a( - message_a: &[u8], - receiver_sk: HpkeSecretKey, -) -> Result<(Vec, HpkePublicKey), HpkeError> { - let enc = message_a.get(..65).ok_or(HpkeError::PayloadTooShort)?; - let enc = EncappedKey::from_bytes(enc)?; - let aad = message_a.get(65..130).ok_or(HpkeError::PayloadTooShort)?; - let pk_s = PublicKey::from_bytes(aad)?; - let mut decryption_ctx = hpke::setup_receiver::< - ChaCha20Poly1305, - HkdfSha256, - SecpK256HkdfSha256, - >(&OpModeR::Auth(pk_s.clone()), &receiver_sk.0, &enc, INFO_A)?; - let ciphertext = message_a.get(130..).ok_or(HpkeError::PayloadTooShort)?; - let plaintext = decryption_ctx.open(ciphertext, aad)?; - Ok((plaintext, HpkePublicKey(pk_s))) -} - -/// Message B is sent from the receiver to the sender containing a Payjoin PSBT payload or an error -#[cfg(feature = "receive")] -pub fn encrypt_message_b( - mut plaintext: Vec, - receiver_keypair: &HpkeKeyPair, - sender_pk: &HpkePublicKey, -) -> Result, HpkeError> { - let (encapsulated_key, mut encryption_context) = - hpke::setup_sender::( - &OpModeS::Auth(( - receiver_keypair.secret_key().0.clone(), - receiver_keypair.public_key().0.clone(), - )), - &sender_pk.0, - INFO_B, - &mut OsRng, - )?; - let plaintext = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_B_LENGTH)?; - let ciphertext = encryption_context.seal(plaintext, &[])?; - let mut message_b = encapsulated_key.to_bytes().to_vec(); - message_b.extend(&ciphertext); - Ok(message_b.to_vec()) -} - -#[cfg(feature = "send")] -pub fn decrypt_message_b( - message_b: &[u8], - receiver_pk: HpkePublicKey, - sender_sk: HpkeSecretKey, -) -> Result, HpkeError> { - let enc = message_b.get(..65).ok_or(HpkeError::PayloadTooShort)?; - let enc = EncappedKey::from_bytes(enc)?; - let mut decryption_ctx = hpke::setup_receiver::< - ChaCha20Poly1305, - HkdfSha256, - SecpK256HkdfSha256, - >(&OpModeR::Auth(receiver_pk.0), &sender_sk.0, &enc, INFO_B)?; - let plaintext = - decryption_ctx.open(message_b.get(65..).ok_or(HpkeError::PayloadTooShort)?, &[])?; - Ok(plaintext) -} - -fn pad_plaintext(msg: &mut Vec, padded_length: usize) -> Result<&[u8], HpkeError> { - if msg.len() > padded_length { - return Err(HpkeError::PayloadTooLarge { actual: msg.len(), max: padded_length }); - } - msg.resize(padded_length, 0); - Ok(msg) -} - -/// Error from de/encrypting a v2 Hybrid Public Key Encryption payload. -#[derive(Debug)] -pub enum HpkeError { - Secp256k1(bitcoin::secp256k1::Error), - Hpke(hpke::HpkeError), - InvalidKeyLength, - PayloadTooLarge { actual: usize, max: usize }, - PayloadTooShort, -} - -impl From for HpkeError { - fn from(value: hpke::HpkeError) -> Self { Self::Hpke(value) } -} - -impl From for HpkeError { - fn from(value: bitcoin::secp256k1::Error) -> Self { Self::Secp256k1(value) } -} - -impl fmt::Display for HpkeError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - use HpkeError::*; - - match &self { - Hpke(e) => e.fmt(f), - InvalidKeyLength => write!(f, "Invalid Length"), - PayloadTooLarge { actual, max } => { - write!( - f, - "Plaintext too large, max size is {} bytes, actual size is {} bytes", - max, actual - ) - } - PayloadTooShort => write!(f, "Payload too small"), - Secp256k1(e) => e.fmt(f), - } - } -} - -impl error::Error for HpkeError { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - use HpkeError::*; - - match &self { - Hpke(e) => Some(e), - PayloadTooLarge { .. } => None, - InvalidKeyLength | PayloadTooShort => None, - Secp256k1(e) => Some(e), - } - } -} - -pub fn ohttp_encapsulate( - ohttp_keys: &mut ohttp::KeyConfig, - method: &str, - target_resource: &str, - body: Option<&[u8]>, -) -> Result<(Vec, ohttp::ClientResponse), OhttpEncapsulationError> { - use std::fmt::Write; - - let ctx = ohttp::ClientRequest::from_config(ohttp_keys)?; - let url = url::Url::parse(target_resource)?; - let authority_bytes = url.host().map_or_else(Vec::new, |host| { - let mut authority = host.to_string(); - if let Some(port) = url.port() { - write!(authority, ":{}", port).unwrap(); - } - authority.into_bytes() - }); - let mut bhttp_message = bhttp::Message::request( - method.as_bytes().to_vec(), - url.scheme().as_bytes().to_vec(), - authority_bytes, - url.path().as_bytes().to_vec(), - ); - // None of our messages include headers, so we don't add them - if let Some(body) = body { - bhttp_message.write_content(body); - } - let mut bhttp_req = Vec::new(); - let _ = bhttp_message.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_req); - let encapsulated = ctx.encapsulate(&bhttp_req)?; - Ok(encapsulated) -} - -/// decapsulate ohttp, bhttp response and return http response body and status code -pub fn ohttp_decapsulate( - res_ctx: ohttp::ClientResponse, - ohttp_body: &[u8], -) -> Result>, OhttpEncapsulationError> { - let bhttp_body = res_ctx.decapsulate(ohttp_body)?; - let mut r = std::io::Cursor::new(bhttp_body); - let m: bhttp::Message = bhttp::Message::read_bhttp(&mut r)?; - let mut builder = http::Response::builder(); - for field in m.header().iter() { - builder = builder.header(field.name(), field.value()); - } - builder - .status(m.control().status().unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR.into())) - .body(m.content().to_vec()) - .map_err(OhttpEncapsulationError::Http) -} - -/// Error from de/encapsulating an Oblivious HTTP request or response. -#[derive(Debug)] -pub enum OhttpEncapsulationError { - Http(http::Error), - Ohttp(ohttp::Error), - Bhttp(bhttp::Error), - ParseUrl(url::ParseError), -} - -impl From for OhttpEncapsulationError { - fn from(value: http::Error) -> Self { Self::Http(value) } -} - -impl From for OhttpEncapsulationError { - fn from(value: ohttp::Error) -> Self { Self::Ohttp(value) } -} - -impl From for OhttpEncapsulationError { - fn from(value: bhttp::Error) -> Self { Self::Bhttp(value) } -} - -impl From for OhttpEncapsulationError { - fn from(value: url::ParseError) -> Self { Self::ParseUrl(value) } -} - -impl fmt::Display for OhttpEncapsulationError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - use OhttpEncapsulationError::*; - - match &self { - Http(e) => e.fmt(f), - Ohttp(e) => e.fmt(f), - Bhttp(e) => e.fmt(f), - ParseUrl(e) => e.fmt(f), - } - } -} - -impl error::Error for OhttpEncapsulationError { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - use OhttpEncapsulationError::*; - - match &self { - Http(e) => Some(e), - Ohttp(e) => Some(e), - Bhttp(e) => Some(e), - ParseUrl(e) => Some(e), - } - } -} - -#[derive(Debug, Clone)] -pub struct OhttpKeys(pub ohttp::KeyConfig); - -impl OhttpKeys { - /// Decode an OHTTP KeyConfig - pub fn decode(bytes: &[u8]) -> Result { - ohttp::KeyConfig::decode(bytes).map(Self) - } -} - -const KEM_ID: &[u8] = b"\x00\x16"; // DHKEM(secp256k1, HKDF-SHA256) -const SYMMETRIC_LEN: &[u8] = b"\x00\x04"; // 4 bytes -const SYMMETRIC_KDF_AEAD: &[u8] = b"\x00\x01\x00\x03"; // KDF(HKDF-SHA256), AEAD(ChaCha20Poly1305) - -impl fmt::Display for OhttpKeys { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let bytes = self.encode().map_err(|_| fmt::Error)?; - let key_id = bytes[0]; - let pubkey = &bytes[3..68]; - - let compressed_pubkey = - bitcoin::secp256k1::PublicKey::from_slice(pubkey).map_err(|_| fmt::Error)?.serialize(); - - let mut buf = vec![key_id]; - buf.extend_from_slice(&compressed_pubkey); - - let encoded = BASE64_URL_SAFE_NO_PAD.encode(buf); - write!(f, "{}", encoded) - } -} - -impl std::str::FromStr for OhttpKeys { - type Err = ParseOhttpKeysError; - - /// Parses a base64URL-encoded string into OhttpKeys. - /// The string format is: key_id || compressed_public_key - fn from_str(s: &str) -> Result { - let bytes = BASE64_URL_SAFE_NO_PAD.decode(s).map_err(ParseOhttpKeysError::DecodeBase64)?; - - let key_id = *bytes.first().ok_or(ParseOhttpKeysError::InvalidFormat)?; - let compressed_pk = bytes.get(1..34).ok_or(ParseOhttpKeysError::InvalidFormat)?; - - let pubkey = bitcoin::secp256k1::PublicKey::from_slice(compressed_pk) - .map_err(|_| ParseOhttpKeysError::InvalidPublicKey)?; - - let mut buf = vec![key_id]; - buf.extend_from_slice(KEM_ID); - buf.extend_from_slice(&pubkey.serialize_uncompressed()); - buf.extend_from_slice(SYMMETRIC_LEN); - buf.extend_from_slice(SYMMETRIC_KDF_AEAD); - - ohttp::KeyConfig::decode(&buf).map(Self).map_err(ParseOhttpKeysError::DecodeKeyConfig) - } -} - -impl PartialEq for OhttpKeys { - fn eq(&self, other: &Self) -> bool { - match (self.encode(), other.encode()) { - (Ok(self_encoded), Ok(other_encoded)) => self_encoded == other_encoded, - // If OhttpKeys::encode(&self) is Err, return false - _ => false, - } - } -} - -impl Eq for OhttpKeys {} - -impl Deref for OhttpKeys { - type Target = ohttp::KeyConfig; - - fn deref(&self) -> &Self::Target { &self.0 } -} - -impl DerefMut for OhttpKeys { - fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } -} - -impl<'de> serde::Deserialize<'de> for OhttpKeys { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let bytes = Vec::::deserialize(deserializer)?; - OhttpKeys::decode(&bytes).map_err(serde::de::Error::custom) - } -} - -impl serde::Serialize for OhttpKeys { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let bytes = self.encode().map_err(serde::ser::Error::custom)?; - bytes.serialize(serializer) - } -} - -#[derive(Debug)] -pub enum ParseOhttpKeysError { - InvalidFormat, - InvalidPublicKey, - DecodeBase64(bitcoin::base64::DecodeError), - DecodeKeyConfig(ohttp::Error), -} - -impl std::fmt::Display for ParseOhttpKeysError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ParseOhttpKeysError::InvalidFormat => write!(f, "Invalid format"), - ParseOhttpKeysError::InvalidPublicKey => write!(f, "Invalid public key"), - ParseOhttpKeysError::DecodeBase64(e) => write!(f, "Failed to decode base64: {}", e), - ParseOhttpKeysError::DecodeKeyConfig(e) => - write!(f, "Failed to decode KeyConfig: {}", e), - } - } -} - -impl std::error::Error for ParseOhttpKeysError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - ParseOhttpKeysError::DecodeBase64(e) => Some(e), - ParseOhttpKeysError::DecodeKeyConfig(e) => Some(e), - ParseOhttpKeysError::InvalidFormat | ParseOhttpKeysError::InvalidPublicKey => None, - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_ohttp_keys_roundtrip() { - use std::str::FromStr; - - use ohttp::hpke::{Aead, Kdf, Kem}; - use ohttp::{KeyId, SymmetricSuite}; - const KEY_ID: KeyId = 1; - const KEM: Kem = Kem::K256Sha256; - const SYMMETRIC: &[SymmetricSuite] = - &[ohttp::SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)]; - let keys = OhttpKeys(ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap()); - let serialized = &keys.to_string(); - let deserialized = OhttpKeys::from_str(serialized).unwrap(); - assert_eq!(keys.encode().unwrap(), deserialized.encode().unwrap()); - } -} diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index 1b19feff..ce304c00 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -12,7 +12,7 @@ mod integration { use bitcoind::bitcoincore_rpc::{self, RpcApi}; use log::{log_enabled, Level}; use once_cell::sync::{Lazy, OnceCell}; - use payjoin::send::RequestBuilder; + use payjoin::send::SenderBuilder; use payjoin::{PjUri, PjUriBuilder, Request, Uri}; use tracing_subscriber::{EnvFilter, FmtSubscriber}; use url::Url; @@ -92,7 +92,7 @@ mod integration { .unwrap(); let psbt = build_original_psbt(&sender, &uri)?; debug!("Original psbt: {:#?}", psbt); - let (req, ctx) = RequestBuilder::from_psbt_and_uri(psbt, uri)? + let (req, ctx) = SenderBuilder::from_psbt_and_uri(psbt, uri)? .build_with_additional_fee(Amount::from_sat(10000), None, FeeRate::ZERO, false)? .extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); @@ -157,7 +157,7 @@ mod integration { .unwrap(); let psbt = build_original_psbt(&sender, &uri)?; debug!("Original psbt: {:#?}", psbt); - let (req, _ctx) = RequestBuilder::from_psbt_and_uri(psbt, uri)? + let (req, _ctx) = SenderBuilder::from_psbt_and_uri(psbt, uri)? .build_with_additional_fee(Amount::from_sat(10000), None, FeeRate::ZERO, false)? .extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); @@ -178,9 +178,8 @@ mod integration { use bitcoin::Address; use http::StatusCode; - use payjoin::receive::v2::{ - ActiveSession, PayjoinProposal, SessionInitializer, UncheckedProposal, - }; + use payjoin::receive::v2::{PayjoinProposal, Receiver, UncheckedProposal}; + use payjoin::send::Context; use payjoin::{OhttpKeys, PjUri, UriExt}; use reqwest::{Client, ClientBuilder, Error, Response}; use testcontainers_modules::redis::Redis; @@ -202,7 +201,7 @@ mod integration { let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap(); tokio::select!( _ = init_directory(port, (cert.clone(), key)) => assert!(false, "Directory server is long running"), - res = enroll_with_bad_keys(directory, bad_ohttp_keys, cert) => { + res = try_request_with_bad_keys(directory, bad_ohttp_keys, cert) => { assert_eq!( res.unwrap().headers().get("content-type").unwrap(), "application/problem+json" @@ -210,7 +209,7 @@ mod integration { } ); - async fn enroll_with_bad_keys( + async fn try_request_with_bad_keys( directory: Url, bad_ohttp_keys: OhttpKeys, cert_der: Vec, @@ -221,13 +220,8 @@ mod integration { let mock_address = Address::from_str("tb1q6d3a2w975yny0asuvd9a67ner4nks58ff0q8g4") .unwrap() .assume_checked(); - let mut bad_initializer = SessionInitializer::new( - mock_address, - directory, - bad_ohttp_keys, - mock_ohttp_relay, - None, - ); + let mut bad_initializer = + Receiver::new(mock_address, directory, bad_ohttp_keys, mock_ohttp_relay, None); let (req, _ctx) = bad_initializer.extract_req().expect("Failed to extract request"); agent.post(req.url).body(req.body).send().await } @@ -270,10 +264,8 @@ mod integration { address.clone(), directory.clone(), ohttp_keys.clone(), - cert_der, Some(Duration::from_secs(0)), - ) - .await?; + ); match session.extract_req() { // Internal error types are private, so check against a string Err(err) => assert!(err.to_string().contains("expired")), @@ -292,9 +284,9 @@ mod integration { Some(std::time::SystemTime::now()), ) .build(); - let mut expired_req_ctx = RequestBuilder::from_psbt_and_uri(psbt, expired_pj_uri)? + let mut expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)? .build_non_incentivizing(FeeRate::BROADCAST_MIN)?; - match expired_req_ctx.extract_v2(directory.to_owned()) { + match expired_req_ctx.extract_highest_version(directory.to_owned()) { // Internal error types are private, so check against a string Err(err) => assert!(err.to_string().contains("expired")), _ => assert!(false, "Expired send session should error"), @@ -340,10 +332,8 @@ mod integration { address.clone(), directory.clone(), ohttp_keys.clone(), - cert_der.clone(), None, - ) - .await?; + ); println!("session: {:#?}", &session); let pj_uri_string = session.pj_uri_builder().build().to_string(); // Poll receive request @@ -364,10 +354,14 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_sweep_psbt(&sender, &pj_uri)?; - let mut req_ctx = RequestBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; let (Request { url, body, content_type, .. }, send_ctx) = - req_ctx.extract_v2(directory.to_owned())?; + req_ctx.extract_highest_version(directory.to_owned())?; + let send_ctx = match send_ctx { + Context::V2(ctx) => ctx, + _ => panic!("V2 context expected"), + }; let response = agent .post(url.clone()) .header("Content-Type", content_type) @@ -377,10 +371,9 @@ mod integration { .unwrap(); log::info!("Response: {:#?}", &response); assert!(response.status().is_success()); - let response_body = + let send_ctx = send_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?; - // No response body yet since we are async and pushed fallback_psbt to the buffer - assert!(response_body.is_none()); + // POST Original PSBT // ********************** // Inside the Receiver: @@ -394,7 +387,12 @@ mod integration { let mut payjoin_proposal = handle_directory_proposal(&receiver, proposal, None); assert!(!payjoin_proposal.is_output_substitution_disabled()); let (req, ctx) = payjoin_proposal.extract_v2_req()?; - let response = agent.post(req.url).body(req.body).send().await?; + let response = agent + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await?; let res = response.bytes().await?.to_vec(); payjoin_proposal.process_res(res, ctx)?; @@ -402,11 +400,18 @@ mod integration { // Inside the Sender: // Sender checks, signs, finalizes, extracts, and broadcasts // Replay post fallback to get the response - let (Request { url, body, .. }, send_ctx) = - req_ctx.extract_v2(directory.to_owned())?; - let response = agent.post(url).body(body).send().await?; + let (Request { url, body, content_type, .. }, ohttp_ctx) = + send_ctx.extract_req(directory.to_owned())?; + let response = agent + .post(url.clone()) + .header("Content-Type", content_type) + .body(body.clone()) + .send() + .await + .unwrap(); + log::info!("Response: {:#?}", &response); let checked_payjoin_proposal_psbt = send_ctx - .process_response(&mut response.bytes().await?.to_vec().as_slice())? + .process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)? .unwrap(); let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?; @@ -493,10 +498,8 @@ mod integration { address.clone(), directory.clone(), ohttp_keys.clone(), - cert_der.clone(), None, - ) - .await?; + ); println!("session: {:#?}", &session); let pj_uri_string = session.pj_uri_builder().build().to_string(); // Poll receive request @@ -517,10 +520,10 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_sweep_psbt(&sender, &pj_uri)?; - let mut req_ctx = RequestBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; - let (Request { url, body, content_type, .. }, send_ctx) = - req_ctx.extract_v2(directory.to_owned())?; + let (Request { url, body, content_type, .. }, post_ctx) = + req_ctx.extract_highest_version(directory.to_owned())?; let response = agent .post(url.clone()) .header("Content-Type", content_type) @@ -530,10 +533,23 @@ mod integration { .unwrap(); log::info!("Response: {:#?}", &response); assert!(response.status().is_success()); - let response_body = - send_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?; + let get_ctx = match post_ctx { + Context::V2(ctx) => + ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, + _ => panic!("V2 context expected"), + }; + let (Request { url, body, content_type, .. }, ohttp_ctx) = + get_ctx.extract_req(directory.to_owned())?; + let response = agent + .post(url.clone()) + .header("Content-Type", content_type) + .body(body.clone()) + .send() + .await?; // No response body yet since we are async and pushed fallback_psbt to the buffer - assert!(response_body.is_none()); + assert!(get_ctx + .process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)? + .is_none()); // ********************** // Inside the Receiver: @@ -557,11 +573,16 @@ mod integration { // Inside the Sender: // Sender checks, signs, finalizes, extracts, and broadcasts // Replay post fallback to get the response - let (Request { url, body, .. }, send_ctx) = - req_ctx.extract_v2(directory.to_owned())?; - let response = agent.post(url).body(body).send().await?; - let checked_payjoin_proposal_psbt = send_ctx - .process_response(&mut response.bytes().await?.to_vec().as_slice())? + let (Request { url, body, content_type, .. }, ohttp_ctx) = + get_ctx.extract_req(directory.to_owned())?; + let response = agent + .post(url.clone()) + .header("Content-Type", content_type) + .body(body.clone()) + .send() + .await?; + let checked_payjoin_proposal_psbt = get_ctx + .process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)? .unwrap(); let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?; @@ -600,9 +621,9 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_original_psbt(&sender, &pj_uri)?; - let mut req_ctx = RequestBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; - let (req, ctx) = req_ctx.extract_v2(EXAMPLE_URL.to_owned())?; + let (req, ctx) = req_ctx.extract_highest_version(EXAMPLE_URL.to_owned())?; let headers = HeaderMock::new(&req.body, req.content_type); // ********************** @@ -614,8 +635,11 @@ mod integration { // ********************** // Inside the Sender: // Sender checks, signs, finalizes, extracts, and broadcasts - let checked_payjoin_proposal_psbt = - ctx.process_response(&mut response.as_bytes())?.unwrap(); + let ctx = match ctx { + Context::V1(ctx) => ctx, + _ => panic!("V1 context expected"), + }; + let checked_payjoin_proposal_psbt = ctx.process_response(&mut response.as_bytes())?; let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?; @@ -661,14 +685,7 @@ mod integration { .await?; let address = receiver.get_new_address(None, None)?.assume_checked(); - let mut session = initialize_session( - address, - directory, - ohttp_keys.clone(), - cert_der.clone(), - None, - ) - .await?; + let mut session = initialize_session(address, directory, ohttp_keys.clone(), None); let pj_uri_string = session.pj_uri_builder().build().to_string(); @@ -682,7 +699,7 @@ mod integration { .unwrap(); let psbt = build_original_psbt(&sender, &pj_uri)?; let (Request { url, body, content_type, .. }, send_ctx) = - RequestBuilder::from_psbt_and_uri(psbt, pj_uri)? + SenderBuilder::from_psbt_and_uri(psbt, pj_uri)? .build_with_additional_fee( Amount::from_sat(10000), None, @@ -780,14 +797,7 @@ mod integration { let db = docker.run(Redis::default()); let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); println!("Database running on {}", db.get_host_port_ipv4(6379)); - payjoin_directory::listen_tcp_with_tls( - format!("http://localhost:{}", port), - port, - db_host, - timeout, - local_cert_key, - ) - .await + payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await } // generates or gets a DER encoded localhost cert and key. @@ -802,27 +812,20 @@ mod integration { (cert_der, key_der) } - async fn initialize_session( + fn initialize_session( address: Address, directory: Url, ohttp_keys: OhttpKeys, - cert_der: Vec, custom_expire_after: Option, - ) -> Result { + ) -> Receiver { let mock_ohttp_relay = directory.clone(); // pass through to directory - let mut initializer = SessionInitializer::new( + Receiver::new( address, directory.clone(), ohttp_keys, mock_ohttp_relay.clone(), custom_expire_after, - ); - let (req, ctx) = initializer.extract_req()?; - println!("enroll req: {:#?}", &req); - let response = - http_agent(cert_der).unwrap().post(req.url).body(req.body).send().await?; - assert!(response.status().is_success()); - Ok(initializer.process_res(response.bytes().await?.to_vec().as_slice(), ctx)?) + ) } fn handle_directory_proposal( @@ -1024,7 +1027,7 @@ mod integration { let psbt = build_original_psbt(&sender, &uri)?; log::debug!("Original psbt: {:#?}", psbt); let max_additional_fee = Amount::from_sat(1000); - let (req, ctx) = RequestBuilder::from_psbt_and_uri(psbt.clone(), uri)? + let (req, ctx) = SenderBuilder::from_psbt_and_uri(psbt.clone(), uri)? .build_with_additional_fee(max_additional_fee, None, FeeRate::ZERO, false)? .extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); @@ -1101,7 +1104,7 @@ mod integration { .unwrap(); let psbt = build_original_psbt(&sender, &uri)?; log::debug!("Original psbt: {:#?}", psbt); - let (req, ctx) = RequestBuilder::from_psbt_and_uri(psbt.clone(), uri)? + let (req, ctx) = SenderBuilder::from_psbt_and_uri(psbt.clone(), uri)? .build_with_additional_fee(Amount::from_sat(10000), None, FeeRate::ZERO, false)? .extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type);