diff --git a/src/zkproofs/correct_key_ni.rs b/src/zkproofs/correct_key_ni.rs index b1c8540..04879ef 100644 --- a/src/zkproofs/correct_key_ni.rs +++ b/src/zkproofs/correct_key_ni.rs @@ -34,14 +34,24 @@ pub struct CorrectKeyProofError; pub struct NICorrectKeyProof { #[serde(with = "crate::serialize::vecbigint")] pub sigma_vec: Vec, + salt: Option<&'static [u8]>, } impl NICorrectKeyProof { - pub fn proof(dk: &DecryptionKey) -> NICorrectKeyProof { + pub fn get_salt(&self) -> Option<&'static [u8]> { + self.salt + } + + pub fn proof(dk: &DecryptionKey, salt_str: Option<&'static [u8]>) -> NICorrectKeyProof { let dk_n = &dk.q * &dk.p; let key_length = dk_n.bit_length(); - let salt_bn = super::compute_digest(iter::once(BigInt::from(SALT_STRING))); + let salt = match salt_str { + Some(salt) => salt, + None => SALT_STRING, + }; + + let salt_bn = super::compute_digest(iter::once(BigInt::from(salt))); // TODO: use flatten (Morten?) let rho_vec = (0..M2) @@ -60,12 +70,15 @@ impl NICorrectKeyProof { .iter() .map(|i| extract_nroot(dk, i)) .collect::>(); - NICorrectKeyProof { sigma_vec } + NICorrectKeyProof { + sigma_vec, + salt: Some(salt), + } } pub fn verify(&self, ek: &EncryptionKey) -> Result<(), CorrectKeyProofError> { let key_length = ek.n.bit_length() as usize; - let salt_bn = super::compute_digest(iter::once(BigInt::from(SALT_STRING))); + let salt_bn = super::compute_digest(iter::once(BigInt::from(self.salt.unwrap()))); let rho_vec = (0..M2) .map(|i| { @@ -116,9 +129,19 @@ mod tests { use paillier::Paillier; #[test] - fn test_correct_zk_proof() { + fn test_correct_zk_proof_no_salt_str() { + let (ek, dk) = Paillier::keypair().keys(); + let proof = NICorrectKeyProof::proof(&dk, None); + assert!(proof.verify(&ek).is_ok()); + assert_eq!(proof.get_salt().unwrap(), SALT_STRING); + } + + #[test] + fn test_correct_zk_proof_with_salt_str() { + let salt_str: &[u8] = &[90, 101, 110, 32, 71, 111, 32, 88]; let (ek, dk) = Paillier::keypair().keys(); - let proof = NICorrectKeyProof::proof(&dk); + let proof = NICorrectKeyProof::proof(&dk, Some(salt_str)); assert!(proof.verify(&ek).is_ok()); + assert_eq!(proof.get_salt().unwrap(), salt_str); } }