From 1e95f581223aa8f6ee27fe4dfd48378fdbe90f65 Mon Sep 17 00:00:00 2001 From: porcuquine Date: Thu, 20 Jan 2022 15:23:53 -0800 Subject: [PATCH] Add authenticated encryption. --- rust-toolchain | 2 +- src/error.rs | 2 + src/hash_type.rs | 17 +++- src/poseidon.rs | 197 ++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 212 insertions(+), 6 deletions(-) diff --git a/rust-toolchain b/rust-toolchain index ba0a7191..cbc4bf7f 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.51.0 +nightly-2021-12-12 diff --git a/src/error.rs b/src/error.rs index 49d598a2..9593aa1b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -55,6 +55,7 @@ pub enum Error { #[cfg(feature = "futhark")] TritonError(String), DecodingError, + TagMismatch, Other(String), } @@ -95,6 +96,7 @@ impl fmt::Display for Error { #[cfg(feature = "futhark")] Error::TritonError(e) => write!(f, "Neptune-triton Error: {}", e), Error::DecodingError => write!(f, "PrimeFieldDecodingError"), + Error::TagMismatch => write!(f, "Tag mismatch"), Error::Other(s) => write!(f, "{}", s), } } diff --git a/src/hash_type.rs b/src/hash_type.rs index 3eb559a3..a8e9c1e2 100644 --- a/src/hash_type.rs +++ b/src/hash_type.rs @@ -39,7 +39,7 @@ impl> HashType { // bitmask HashType::MerkleTreeSparse(bitmask) => F::from(*bitmask), // 2^64 - HashType::VariableLength => pow2::(64), + HashType::VariableLength => self.encryption_domain_tag(0, 0), // length * 2^64 // length must be greater than 0 and <= arity HashType::ConstantLength(length) => { @@ -58,6 +58,21 @@ impl> HashType { }) } + pub fn encryption_domain_tag(&self, key_length: usize, message_length: usize) -> F { + match self { + Self::Encryption => { + assert!(key_length <= u64::MAX as usize); + assert!(message_length <= u64::MAX as usize); + + let mut tag: F = pow2::(32); + tag += x_pow2::(key_length as u64, 33); + tag += x_pow2::(message_length as u64, 97); + tag + } + _ => panic!("cannot set encryption domain tag"), + } + } + fn strength_tag_component(strength: &Strength) -> F { let id = match strength { // Standard strength doesn't affect the base tag. diff --git a/src/poseidon.rs b/src/poseidon.rs index 448accc8..7ac7b70d 100644 --- a/src/poseidon.rs +++ b/src/poseidon.rs @@ -321,12 +321,16 @@ where /// Restore the initial state pub fn reset(&mut self) { - self.constants_offset = 0; - self.current_round = 0; + self.partial_reset(); self.elements[1..] .iter_mut() .for_each(|l| *l = F::from(0u64)); self.elements[0] = self.constants.domain_tag; + } + + fn partial_reset(&mut self) { + self.constants_offset = 0; + self.current_round = 0; self.pos = 1; } @@ -343,7 +347,6 @@ where Ok(self.pos - 1) } - pub fn hash_in_mode(&mut self, mode: HashMode) -> F { self.apply_padding(); match mode { @@ -367,11 +370,22 @@ where // There is nothing to do here, but only because the state elements were // initialized to zero, and that is what we need to pad with. } + HashType::Encryption => { + // We don't need to add padding because we will always know the exact length, + // since plaintext and ciphertext are the same size. + // Actual key and message lengths must be stored in the domain tag + // so no state is ever shared between (key, message) pairs. + } HashType::VariableLength => todo!(), _ => (), } } + #[inline] + pub fn extract_output(&self) -> F { + self.elements[1] + } + pub fn hash_optimized_static(&mut self) -> F { // The first full round should use the initial constants. self.add_round_constants(); @@ -398,7 +412,7 @@ where self.constants.compressed_round_constants.len() ); - self.elements[1] + self.extract_output() } fn full_round(&mut self, last_round: bool) { @@ -545,6 +559,98 @@ where } } +/// Encryption +/// https://link.springer.com/content/pdf/10.1007%2F978-3-642-28496-0_19.pdf +impl<'a, F, A> Poseidon<'a, F, A> +where + F: PrimeField, + A: Arity, +{ + fn set_encryption_domain_tag(&mut self, key_length: usize, message_length: usize) { + self.elements[0] = self + .constants + .hash_type + .encryption_domain_tag(key_length, message_length); + } + + /// Initialize state with key, shared by encryption and decryption. + fn shared_initialize(&mut self, key: &[F], message_length: usize) -> Result<(), Error> { + self.reset(); + self.set_encryption_domain_tag(key.len(), message_length); + + for chunk in key.chunks(self.constants.arity()) { + self.duplex(chunk)?; + } + + Ok(()) + } + + /// Incorporate input (by element-wise field addition), then permute the state. + /// Output is left in state (self.elements). + pub fn duplex(&mut self, input: &[F]) -> Result<(), Error> { + assert!(input.len() <= self.constants.arity()); + self.partial_reset(); + + for elt in input.iter() { + self.elements[self.pos] += elt; + self.pos += 1; + } + + // FIXME: WHy does this break? + self.hash(); + //self.hash_in_mode(Correct); + + Ok(()) + } + + pub fn encrypt(&mut self, key: &[F], plaintext: &[F]) -> Result<(Vec, F), Error> { + // https://link.springer.com/content/pdf/10.1007%2F978-3-642-28496-0_19.pdf + let arity = A::to_usize(); + assert!(key.len() > 0); + + self.shared_initialize(key, plaintext.len())?; + + let mut ciphertext = Vec::with_capacity(plaintext.len()); + + for plaintext_chunk in plaintext.chunks(arity) { + for (elt, resp) in plaintext_chunk.iter().zip(self.elements.iter().skip(1)) { + ciphertext.push(*elt + *resp); + } + self.duplex(plaintext_chunk)?; + } + + let tag = self.extract_output(); + Ok((ciphertext, tag)) + } + + pub fn decrypt(&mut self, key: &[F], ciphertext: &[F], tag: &F) -> Result, Error> { + let arity = A::to_usize(); + assert!(key.len() > 0); + + self.shared_initialize(key, ciphertext.len())?; + let mut plaintext = Vec::with_capacity(ciphertext.len()); + + let mut last_chunk_start = 0; + for ciphertext_chunk in ciphertext.chunks(arity) { + for (elt, resp) in ciphertext_chunk.iter().zip(self.elements.iter().skip(1)) { + plaintext.push(*elt - *resp); + } + let plaintext_chunk = &plaintext[last_chunk_start..]; + self.duplex(plaintext_chunk)?; + + last_chunk_start += arity; + } + + let computed_tag = self.extract_output(); + + if *tag != computed_tag { + return Err(Error::TagMismatch); + }; + + Ok(plaintext) + } +} + #[derive(Debug)] pub struct SimplePoseidonBatchHasher where @@ -821,4 +927,87 @@ mod tests { default_constants.partial_rounds ); } + + #[test] + fn encrypt_decrypt() { + let constants = PoseidonConstants::::new_with_strength_and_type( + Strength::Standard, + HashType::Encryption, + ); + let mut p = Poseidon::::new(&constants); + + let plaintext = [1, 2, 3, 4, 5, 6, 7, 8, 9] + .iter() + .map(|n| Fr::from(*n as u64)) + .collect::>(); + + let key = [987, 234] + .iter() + .map(|n| Fr::from(*n as u64)) + .collect::>(); + + let (ciphertext, tag) = p.encrypt(&key, &plaintext).unwrap(); + let decrypted = p.decrypt(&key, &ciphertext, &tag).unwrap(); + + assert_eq!(plaintext, decrypted); + assert_eq!( + ciphertext, + [ + scalar_from_u64s([ + 0xaec64216978527ac, + 0xdf5f10f7a1a9a8b7, + 0xe34ddf5197d75feb, + 0x1be88365866ae3d6 + ]), + scalar_from_u64s([ + 0xda0d5ec9eff654da, + 0x7017055a0a081c34, + 0x1bce42bb6937ab48, + 0x35a2e74eeaa97f6c + ]), + scalar_from_u64s([ + 0x298936f51cf3aa12, + 0x906cf40d00e4411c, + 0xc195c1ed48a6c223, + 0x4598c18291315dbc + ]), + scalar_from_u64s([ + 0xd59a3a87f0dec416, + 0x0d9fd7b5282925d8, + 0x0ea1b98d0b00d561, + 0x023704693c4abf1b + ]), + scalar_from_u64s([ + 0x211b61f66285bd55, + 0xbf26070055e78d4a, + 0x3682aa0ce38835cf, + 0x4e6a9d5424f77ac5 + ]), + scalar_from_u64s([ + 0xa1b8442758bec43b, + 0xaf3248c718643bf9, + 0x66ad9b69d73bc44a, + 0x243e604b5138226a + ]), + scalar_from_u64s([ + 0xf92fd3ed19af0733, + 0x6b96bc196f6c2d5b, + 0xefe6d3b5c1dc730a, + 0x0dabad8c3dbd4147 + ]), + scalar_from_u64s([ + 0x556595727f046c2a, + 0xaecc434fb16c8631, + 0xd5da55ffc78a420f, + 0x081a166a1909cbed + ]), + scalar_from_u64s([ + 0x40d5a2d5052cb583, + 0x5c0b5265c006a5cb, + 0xfd936f0a297114f8, + 0x1191f085dc4d2286 + ]), + ] + ) + } }