Skip to content

Commit

Permalink
Add authenticated encryption.
Browse files Browse the repository at this point in the history
  • Loading branch information
porcuquine committed Jan 21, 2022
1 parent 76e328f commit 1e95f58
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 6 deletions.
2 changes: 1 addition & 1 deletion rust-toolchain
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.51.0
nightly-2021-12-12
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ pub enum Error {
#[cfg(feature = "futhark")]
TritonError(String),
DecodingError,
TagMismatch,
Other(String),
}

Expand Down Expand Up @@ -95,6 +96,7 @@ impl fmt::Display for Error {
#[cfg(feature = "futhark")]
Error::TritonError(e) => write!(f, "Neptune-triton Error: {}", e),
Error::DecodingError => write!(f, "PrimeFieldDecodingError"),
Error::TagMismatch => write!(f, "Tag mismatch"),
Error::Other(s) => write!(f, "{}", s),
}
}
Expand Down
17 changes: 16 additions & 1 deletion src/hash_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl<F: PrimeField, A: Arity<F>> HashType<F, A> {
// bitmask
HashType::MerkleTreeSparse(bitmask) => F::from(*bitmask),
// 2^64
HashType::VariableLength => pow2::<F>(64),
HashType::VariableLength => self.encryption_domain_tag(0, 0),
// length * 2^64
// length must be greater than 0 and <= arity
HashType::ConstantLength(length) => {
Expand All @@ -58,6 +58,21 @@ impl<F: PrimeField, A: Arity<F>> HashType<F, A> {
})
}

pub fn encryption_domain_tag(&self, key_length: usize, message_length: usize) -> F {
match self {
Self::Encryption => {
assert!(key_length <= u64::MAX as usize);
assert!(message_length <= u64::MAX as usize);

let mut tag: F = pow2::<F>(32);
tag += x_pow2::<F>(key_length as u64, 33);
tag += x_pow2::<F>(message_length as u64, 97);
tag
}
_ => panic!("cannot set encryption domain tag"),
}
}

fn strength_tag_component(strength: &Strength) -> F {
let id = match strength {
// Standard strength doesn't affect the base tag.
Expand Down
197 changes: 193 additions & 4 deletions src/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,16 @@ where

/// Restore the initial state
pub fn reset(&mut self) {
self.constants_offset = 0;
self.current_round = 0;
self.partial_reset();
self.elements[1..]
.iter_mut()
.for_each(|l| *l = F::from(0u64));
self.elements[0] = self.constants.domain_tag;
}

fn partial_reset(&mut self) {
self.constants_offset = 0;
self.current_round = 0;
self.pos = 1;
}

Expand All @@ -343,7 +347,6 @@ where

Ok(self.pos - 1)
}

pub fn hash_in_mode(&mut self, mode: HashMode) -> F {
self.apply_padding();
match mode {
Expand All @@ -367,11 +370,22 @@ where
// There is nothing to do here, but only because the state elements were
// initialized to zero, and that is what we need to pad with.
}
HashType::Encryption => {
// We don't need to add padding because we will always know the exact length,
// since plaintext and ciphertext are the same size.
// Actual key and message lengths must be stored in the domain tag
// so no state is ever shared between (key, message) pairs.
}
HashType::VariableLength => todo!(),
_ => (),
}
}

#[inline]
pub fn extract_output(&self) -> F {
self.elements[1]
}

pub fn hash_optimized_static(&mut self) -> F {
// The first full round should use the initial constants.
self.add_round_constants();
Expand All @@ -398,7 +412,7 @@ where
self.constants.compressed_round_constants.len()
);

self.elements[1]
self.extract_output()
}

fn full_round(&mut self, last_round: bool) {
Expand Down Expand Up @@ -545,6 +559,98 @@ where
}
}

/// Encryption
/// https://link.springer.com/content/pdf/10.1007%2F978-3-642-28496-0_19.pdf
impl<'a, F, A> Poseidon<'a, F, A>
where
F: PrimeField,
A: Arity<F>,
{
fn set_encryption_domain_tag(&mut self, key_length: usize, message_length: usize) {
self.elements[0] = self
.constants
.hash_type
.encryption_domain_tag(key_length, message_length);
}

/// Initialize state with key, shared by encryption and decryption.
fn shared_initialize(&mut self, key: &[F], message_length: usize) -> Result<(), Error> {
self.reset();
self.set_encryption_domain_tag(key.len(), message_length);

for chunk in key.chunks(self.constants.arity()) {
self.duplex(chunk)?;
}

Ok(())
}

/// Incorporate input (by element-wise field addition), then permute the state.
/// Output is left in state (self.elements).
pub fn duplex(&mut self, input: &[F]) -> Result<(), Error> {
assert!(input.len() <= self.constants.arity());
self.partial_reset();

for elt in input.iter() {
self.elements[self.pos] += elt;
self.pos += 1;
}

// FIXME: WHy does this break?
self.hash();
//self.hash_in_mode(Correct);

Ok(())
}

pub fn encrypt(&mut self, key: &[F], plaintext: &[F]) -> Result<(Vec<F>, F), Error> {
// https://link.springer.com/content/pdf/10.1007%2F978-3-642-28496-0_19.pdf
let arity = A::to_usize();
assert!(key.len() > 0);

self.shared_initialize(key, plaintext.len())?;

let mut ciphertext = Vec::with_capacity(plaintext.len());

for plaintext_chunk in plaintext.chunks(arity) {
for (elt, resp) in plaintext_chunk.iter().zip(self.elements.iter().skip(1)) {
ciphertext.push(*elt + *resp);
}
self.duplex(plaintext_chunk)?;
}

let tag = self.extract_output();
Ok((ciphertext, tag))
}

pub fn decrypt(&mut self, key: &[F], ciphertext: &[F], tag: &F) -> Result<Vec<F>, Error> {
let arity = A::to_usize();
assert!(key.len() > 0);

self.shared_initialize(key, ciphertext.len())?;
let mut plaintext = Vec::with_capacity(ciphertext.len());

let mut last_chunk_start = 0;
for ciphertext_chunk in ciphertext.chunks(arity) {
for (elt, resp) in ciphertext_chunk.iter().zip(self.elements.iter().skip(1)) {
plaintext.push(*elt - *resp);
}
let plaintext_chunk = &plaintext[last_chunk_start..];
self.duplex(plaintext_chunk)?;

last_chunk_start += arity;
}

let computed_tag = self.extract_output();

if *tag != computed_tag {
return Err(Error::TagMismatch);
};

Ok(plaintext)
}
}

#[derive(Debug)]
pub struct SimplePoseidonBatchHasher<A>
where
Expand Down Expand Up @@ -821,4 +927,87 @@ mod tests {
default_constants.partial_rounds
);
}

#[test]
fn encrypt_decrypt() {
let constants = PoseidonConstants::<Fr, U2>::new_with_strength_and_type(
Strength::Standard,
HashType::Encryption,
);
let mut p = Poseidon::<Fr, U2>::new(&constants);

let plaintext = [1, 2, 3, 4, 5, 6, 7, 8, 9]
.iter()
.map(|n| Fr::from(*n as u64))
.collect::<Vec<_>>();

let key = [987, 234]
.iter()
.map(|n| Fr::from(*n as u64))
.collect::<Vec<_>>();

let (ciphertext, tag) = p.encrypt(&key, &plaintext).unwrap();
let decrypted = p.decrypt(&key, &ciphertext, &tag).unwrap();

assert_eq!(plaintext, decrypted);
assert_eq!(
ciphertext,
[
scalar_from_u64s([
0xaec64216978527ac,
0xdf5f10f7a1a9a8b7,
0xe34ddf5197d75feb,
0x1be88365866ae3d6
]),
scalar_from_u64s([
0xda0d5ec9eff654da,
0x7017055a0a081c34,
0x1bce42bb6937ab48,
0x35a2e74eeaa97f6c
]),
scalar_from_u64s([
0x298936f51cf3aa12,
0x906cf40d00e4411c,
0xc195c1ed48a6c223,
0x4598c18291315dbc
]),
scalar_from_u64s([
0xd59a3a87f0dec416,
0x0d9fd7b5282925d8,
0x0ea1b98d0b00d561,
0x023704693c4abf1b
]),
scalar_from_u64s([
0x211b61f66285bd55,
0xbf26070055e78d4a,
0x3682aa0ce38835cf,
0x4e6a9d5424f77ac5
]),
scalar_from_u64s([
0xa1b8442758bec43b,
0xaf3248c718643bf9,
0x66ad9b69d73bc44a,
0x243e604b5138226a
]),
scalar_from_u64s([
0xf92fd3ed19af0733,
0x6b96bc196f6c2d5b,
0xefe6d3b5c1dc730a,
0x0dabad8c3dbd4147
]),
scalar_from_u64s([
0x556595727f046c2a,
0xaecc434fb16c8631,
0xd5da55ffc78a420f,
0x081a166a1909cbed
]),
scalar_from_u64s([
0x40d5a2d5052cb583,
0x5c0b5265c006a5cb,
0xfd936f0a297114f8,
0x1191f085dc4d2286
]),
]
)
}
}

0 comments on commit 1e95f58

Please sign in to comment.