From cf432401c87401029b88f90edbe45f1ffae661aa Mon Sep 17 00:00:00 2001 From: DanGould Date: Tue, 22 Oct 2024 14:17:33 -0400 Subject: [PATCH] Test that the encrypted payloads are uniform. This randomized test will generate a false negative with negligible probability if all encrypted messages share an identical byte at a given position by chance. It should fail deterministically if any bit position has a fixed value. --- payjoin/src/hpke.rs | 71 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index c5cfb429..82719b84 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -466,4 +466,75 @@ mod test { }) ); } + + /// Test that the encrypted payloads are uniform. + /// + /// This randomized test will generate a false negative with negligible probability + /// if all encrypted messages share an identical bit at a given position by chance. + /// It should fail deterministically if any bit position has a fixed value. + #[test] + fn test_encrypted_payload_bit_uniformity() { + fn generate_messages(count: usize) -> (Vec>, Vec>) { + let mut messages_a = Vec::with_capacity(count); + let mut messages_b = Vec::with_capacity(count); + + for _ in 0..count { + let sender_keypair = HpkeKeyPair::gen_keypair(); + let receiver_keypair = HpkeKeyPair::gen_keypair(); + let reply_keypair = HpkeKeyPair::gen_keypair(); + + let plaintext_a = vec![0u8; PADDED_PLAINTEXT_A_LENGTH]; + let message_a = encrypt_message_a( + plaintext_a, + reply_keypair.public_key(), + receiver_keypair.public_key(), + ) + .expect("encryption should work"); + + let plaintext_b = vec![0u8; PADDED_PLAINTEXT_B_LENGTH]; + let message_b = + encrypt_message_b(plaintext_b, &receiver_keypair, sender_keypair.public_key()) + .expect("encryption should work"); + + messages_a.push(message_a); + messages_b.push(message_b); + } + + (messages_a, messages_b) + } + + /// For each of the (n choose 2) = n*(n-1)/2 combinations, ensure their lengths + /// are equal, XOR the two messages together, and OR this into an accumulator + /// that starts as all 0x00s. + fn check_uniformity(messages: Vec>) { + let mut accumulator = vec![0u8; PADDED_MESSAGE_BYTES]; + + for (i, msg1) in messages.iter().enumerate() { + for msg2 in messages.iter().skip(i + 1) { + assert_eq!(msg1.len(), msg2.len(), "Message lengths should be equal"); + for (acc, (&b1, &b2)) in + accumulator.iter_mut().zip(msg1.iter().zip(msg2.iter())) + { + *acc |= b1 ^ b2; + } + } + } + + assert!( + accumulator.iter().any(|&b| b != 0xFF), + "All bits in the accumulator should be non-zero" + ); + } + + let (messages_a, messages_b) = generate_messages(80); + let mut combined_messages = messages_a; + combined_messages.extend(messages_b); + check_uniformity(combined_messages); + + let (messages_a, _) = generate_messages(40); + check_uniformity(messages_a); + + let (_, messages_b) = generate_messages(40); + check_uniformity(messages_b); + } }