Skip to content

Commit

Permalink
Remove ohttp_relay from SessionContext
Browse files Browse the repository at this point in the history
The context may use multiple OHTTP Relays, so pass it in the extract
method instead.
  • Loading branch information
DanGould committed Jan 8, 2025
1 parent eaf2398 commit 994f2db
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 65 deletions.
13 changes: 4 additions & 9 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,8 @@ impl AppTrait for App {
let address = self.bitcoind()?.get_new_address(None, None)?.assume_checked();
let amount = Amount::from_sat(amount_arg.parse()?);
let ohttp_keys = unwrap_ohttp_keys_or_else_fetch(&self.config).await?;
let session = Receiver::new(
address,
self.config.pj_directory.clone(),
ohttp_keys.clone(),
self.config.ohttp_relay.clone(),
None,
);
let session =
Receiver::new(address, self.config.pj_directory.clone(), ohttp_keys.clone(), None);
self.db.insert_recv_session(session.clone())?;
self.spawn_payjoin_receiver(session, Some(amount)).await
}
Expand Down Expand Up @@ -138,7 +133,7 @@ impl App {
.process_v2_proposal(res)
.map_err(|e| anyhow!("Failed to process proposal {}", e))?;
let (req, ohttp_ctx) = payjoin_proposal
.extract_v2_req()
.extract_v2_req(&self.config.ohttp_relay)
.map_err(|e| anyhow!("v2 req extraction failed {}", e))?;
println!("Got a request from the sender. Responding with a Payjoin proposal.");
let res = post_request(req).await?;
Expand Down Expand Up @@ -239,7 +234,7 @@ impl App {
session: &mut payjoin::receive::v2::Receiver,
) -> Result<payjoin::receive::v2::UncheckedProposal> {
loop {
let (req, context) = session.extract_req()?;
let (req, context) = session.extract_req(&self.config.ohttp_relay)?;
println!("Polling receive request...");
let ohttp_response = post_request(req).await?;
let proposal = session
Expand Down
21 changes: 10 additions & 11 deletions payjoin/src/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ struct SessionContext {
subdirectory: Option<url::Url>,
ohttp_keys: OhttpKeys,
expiry: SystemTime,
ohttp_relay: url::Url,
s: HpkeKeyPair,
e: Option<HpkePublicKey>,
}
Expand Down Expand Up @@ -67,7 +66,6 @@ impl Receiver {
/// - `address`: The Bitcoin address for the payjoin session.
/// - `directory`: The URL of the store-and-forward payjoin directory.
/// - `ohttp_keys`: The OHTTP keys used for encrypting and decrypting HTTP requests and responses.
/// - `ohttp_relay`: The URL of the OHTTP relay, used to keep client IP address confidential.
/// - `expire_after`: The duration after which the session expires.
///
/// # Returns
Expand All @@ -79,7 +77,6 @@ impl Receiver {
address: Address,
directory: Url,
ohttp_keys: OhttpKeys,
ohttp_relay: Url,
expire_after: Option<Duration>,
) -> Self {
Self {
Expand All @@ -88,7 +85,6 @@ impl Receiver {
directory,
subdirectory: None,
ohttp_keys,
ohttp_relay,
expiry: SystemTime::now()
+ expire_after.unwrap_or(TWENTY_FOUR_HOURS_DEFAULT_EXPIRY),
s: HpkeKeyPair::gen_keypair(),
Expand All @@ -98,13 +94,16 @@ impl Receiver {
}

/// Extract an OHTTP Encapsulated HTTP GET request for the Original PSBT
pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), SessionError> {
pub fn extract_req(
&mut self,
ohttp_relay: &Url,
) -> Result<(Request, ohttp::ClientResponse), SessionError> {
if SystemTime::now() > self.context.expiry {
return Err(InternalSessionError::Expired(self.context.expiry).into());
}
let (body, ohttp_ctx) =
self.fallback_req_body().map_err(InternalSessionError::OhttpEncapsulation)?;
let url = self.context.ohttp_relay.clone();
let url = ohttp_relay.clone();
let req = Request::new_v2(url, body);
Ok((req, ohttp_ctx))
}
Expand Down Expand Up @@ -468,7 +467,10 @@ impl PayjoinProposal {
pub fn psbt(&self) -> &Psbt { self.v1.psbt() }

#[cfg(feature = "v2")]
pub fn extract_v2_req(&mut self) -> Result<(Request, ohttp::ClientResponse), Error> {
pub fn extract_v2_req(
&mut self,
ohttp_relay: &Url,
) -> Result<(Request, ohttp::ClientResponse), Error> {
let target_resource: Url;
let body: Vec<u8>;
let method: &str;
Expand Down Expand Up @@ -502,8 +504,7 @@ impl PayjoinProposal {
target_resource.as_str(),
Some(&body),
)?;
let url = self.context.ohttp_relay.clone();
let req = Request::new_v2(url, body);
let req = Request::new_v2(ohttp_relay.clone(), body);
Ok((req, ctx))
}

Expand Down Expand Up @@ -562,7 +563,6 @@ mod test {
ohttp_keys: OhttpKeys(
ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(),
),
ohttp_relay: url::Url::parse("https://relay.com").unwrap(),
expiry: SystemTime::now() + Duration::from_secs(60),
s: HpkeKeyPair::gen_keypair(),
e: None,
Expand All @@ -589,7 +589,6 @@ mod test {
directory: arbitrary_url.clone(),
subdirectory: None,
ohttp_keys,
ohttp_relay: arbitrary_url.clone(),
expiry: SystemTime::now() + Duration::from_secs(60),
s: receiver_keys,
e: None,
Expand Down
69 changes: 24 additions & 45 deletions payjoin/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,10 @@ mod integration {
.unwrap()
.assume_checked();
let mut bad_initializer =
Receiver::new(mock_address, directory, bad_ohttp_keys, mock_ohttp_relay, None);
let (req, _ctx) = bad_initializer.extract_req().expect("Failed to extract request");
Receiver::new(mock_address, directory, bad_ohttp_keys, None);
let (req, _ctx) = bad_initializer
.extract_req(&mock_ohttp_relay)
.expect("Failed to extract request");
agent.post(req.url).body(req.body).send().await
}
}
Expand Down Expand Up @@ -269,7 +271,7 @@ mod integration {
wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap();
wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap();
let ohttp_keys = payjoin::io::fetch_ohttp_keys_with_cert(
ohttp_relay,
ohttp_relay.clone(),
directory.clone(),
cert_der,
)
Expand All @@ -279,13 +281,13 @@ mod integration {
// Inside the Receiver:
let address = receiver.get_new_address(None, None)?.assume_checked();
// test session with expiry in the past
let mut expired_receiver = initialize_session(
let mut expired_receiver = Receiver::new(
address.clone(),
directory.clone(),
ohttp_keys.clone(),
Some(Duration::from_secs(0)),
);
match expired_receiver.extract_req() {
match expired_receiver.extract_req(&ohttp_relay) {
// Internal error types are private, so check against a string
Err(err) => assert!(err.to_string().contains("expired")),
_ => panic!("Expired receive session should error"),
Expand Down Expand Up @@ -341,7 +343,7 @@ mod integration {
wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap();
wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap();
let ohttp_keys = payjoin::io::fetch_ohttp_keys_with_cert(
ohttp_relay,
ohttp_relay.clone(),
directory.clone(),
cert_der.clone(),
)
Expand All @@ -351,16 +353,12 @@ mod integration {
let address = receiver.get_new_address(None, None)?.assume_checked();

// test session with expiry in the future
let mut session = initialize_session(
address.clone(),
directory.clone(),
ohttp_keys.clone(),
None,
);
let mut session =
Receiver::new(address.clone(), directory.clone(), ohttp_keys.clone(), None);
println!("session: {:#?}", &session);
let pj_uri_string = session.pj_uri().to_string();
// Poll receive request
let (req, ctx) = session.extract_req()?;
let (req, ctx) = session.extract_req(&ohttp_relay)?;
let response = agent.post(req.url).body(req.body).send().await?;
assert!(response.status().is_success());
let response_body =
Expand Down Expand Up @@ -397,14 +395,14 @@ mod integration {
// Inside the Receiver:

// GET fallback psbt
let (req, ctx) = session.extract_req()?;
let (req, ctx) = session.extract_req(&ohttp_relay)?;
let response = agent.post(req.url).body(req.body).send().await?;
// POST payjoin
let proposal =
session.process_res(response.bytes().await?.to_vec().as_slice(), ctx)?.unwrap();
let mut payjoin_proposal = handle_directory_proposal(&receiver, proposal, None);
assert!(!payjoin_proposal.is_output_substitution_disabled());
let (req, ctx) = payjoin_proposal.extract_v2_req()?;
let (req, ctx) = payjoin_proposal.extract_v2_req(&ohttp_relay)?;
let response = agent
.post(req.url)
.header("Content-Type", req.content_type)
Expand Down Expand Up @@ -482,7 +480,7 @@ mod integration {
wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap();
wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap();
let ohttp_keys = payjoin::io::fetch_ohttp_keys_with_cert(
ohttp_relay,
ohttp_relay.clone(),
directory.clone(),
cert_der,
)
Expand Down Expand Up @@ -522,16 +520,12 @@ mod integration {
let address = receiver.get_new_address(None, None)?.assume_checked();

// test session with expiry in the future
let mut session = initialize_session(
address.clone(),
directory.clone(),
ohttp_keys.clone(),
None,
);
let mut session =
Receiver::new(address.clone(), directory.clone(), ohttp_keys.clone(), None);
println!("session: {:#?}", &session);
let pj_uri_string = session.pj_uri().to_string();
// Poll receive request
let (req, ctx) = session.extract_req()?;
let (req, ctx) = session.extract_req(&ohttp_relay)?;
let response = agent.post(req.url).body(req.body).send().await?;
assert!(response.status().is_success());
let response_body = session.process_res(&response.bytes().await?, ctx).unwrap();
Expand Down Expand Up @@ -576,7 +570,7 @@ mod integration {
// Inside the Receiver:

// GET fallback psbt
let (req, ctx) = session.extract_req()?;
let (req, ctx) = session.extract_req(&ohttp_relay)?;
let response = agent.post(req.url).body(req.body).send().await?;
// POST payjoin
let proposal =
Expand All @@ -585,7 +579,7 @@ mod integration {
let mut payjoin_proposal =
handle_directory_proposal(&receiver, proposal, Some(inputs));
assert!(!payjoin_proposal.is_output_substitution_disabled());
let (req, ctx) = payjoin_proposal.extract_v2_req()?;
let (req, ctx) = payjoin_proposal.extract_v2_req(&ohttp_relay)?;
let response = agent.post(req.url).body(req.body).send().await?;
payjoin_proposal.process_res(&response.bytes().await?, ctx)?;

Expand Down Expand Up @@ -703,14 +697,14 @@ mod integration {
wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await?;
wait_for_service_ready(directory.clone(), agent.clone()).await?;
let ohttp_keys = payjoin::io::fetch_ohttp_keys_with_cert(
ohttp_relay,
ohttp_relay.clone(),
directory.clone(),
cert_der.clone(),
)
.await?;
let address = receiver.get_new_address(None, None)?.assume_checked();

let mut session = initialize_session(address, directory, ohttp_keys.clone(), None);
let mut session = Receiver::new(address, directory, ohttp_keys.clone(), None);

let pj_uri_string = session.pj_uri().to_string();

Expand Down Expand Up @@ -746,10 +740,11 @@ mod integration {
let agent_clone: Arc<Client> = agent.clone();
let receiver: Arc<bitcoincore_rpc::Client> = Arc::new(receiver);
let receiver_clone = receiver.clone();
let ohttp_relay_clone = ohttp_relay.clone();
let receiver_loop = tokio::task::spawn(async move {
let agent_clone = agent_clone.clone();
let (response, ctx) = loop {
let (req, ctx) = session.extract_req().unwrap();
let (req, ctx) = session.extract_req(&ohttp_relay_clone).unwrap();
let response = agent_clone.post(req.url).body(req.body).send().await?;

if response.status() == 200 {
Expand All @@ -770,7 +765,7 @@ mod integration {
assert!(payjoin_proposal.is_output_substitution_disabled());
// Respond with payjoin psbt within the time window the sender is willing to wait
// this response would be returned as http response to the sender
let (req, ctx) = payjoin_proposal.extract_v2_req().unwrap();
let (req, ctx) = payjoin_proposal.extract_v2_req(&ohttp_relay).unwrap();
let response = agent_clone.post(req.url).body(req.body).send().await?;
payjoin_proposal
.process_res(&response.bytes().await?, ctx)
Expand Down Expand Up @@ -836,22 +831,6 @@ mod integration {
(cert_der, key_der)
}

fn initialize_session(
address: Address,
directory: Url,
ohttp_keys: OhttpKeys,
custom_expire_after: Option<Duration>,
) -> Receiver {
let mock_ohttp_relay = directory.clone(); // pass through to directory
Receiver::new(
address,
directory.clone(),
ohttp_keys,
mock_ohttp_relay.clone(),
custom_expire_after,
)
}

fn handle_directory_proposal(
receiver: &bitcoincore_rpc::Client,
proposal: UncheckedProposal,
Expand Down

0 comments on commit 994f2db

Please sign in to comment.