diff --git a/v2/kyber/kyber.cc b/v2/kyber/kyber.cc index 377e3b2..5ac61bd 100644 --- a/v2/kyber/kyber.cc +++ b/v2/kyber/kyber.cc @@ -1559,6 +1559,7 @@ bool kyber_kem_keygen(int g, kyber_parameters& p, int* kem_ek_len, byte* kem_ek, printf("kyber_kem_keygen crypto_get_random_bytes returne wrong nuber of bytes\n"); return false; } + int dk_PKE_len = 384 * p.k_; byte dk_PKE[dk_PKE_len]; if (! kyber_keygen(g, p, kem_ek_len, kem_ek, &dk_PKE_len, dk_PKE)) { @@ -1571,6 +1572,7 @@ bool kyber_kem_keygen(int g, kyber_parameters& p, int* kem_ek_len, byte* kem_ek, len += dk_PKE_len; memcpy(&kem_dk[len], kem_ek, *kem_ek_len); len += *kem_ek_len; + sha3 h; if (!h.init(256, 256)) { return false; @@ -1578,6 +1580,7 @@ bool kyber_kem_keygen(int g, kyber_parameters& p, int* kem_ek_len, byte* kem_ek, h.add_to_hash(*kem_ek_len, kem_ek); h.finalize(); h.get_digest(32, &kem_dk[len]); + len += 32; memcpy(&kem_dk[len], z, 32); if (*kem_dk_len < len) { @@ -1601,37 +1604,48 @@ bool kyber_kem_encaps(int g, kyber_parameters& p, int kem_ek_len, byte* kem_ek, printf("kyber_kem_encaps: crypto_get_random_bytes return wrong nuber of bytes\n"); return false; } - byte h_to_hash[kem_ek_len + 32]; - memcpy(h_to_hash, m, 32); + + byte ek_hash[32]; + memset(ek_hash, 0, 32); + sha3 h; if (!h.init(256, 256)) { + printf("kyber_kem_encaps: Can't init sha3\n"); return false; } h.add_to_hash(kem_ek_len, kem_ek); h.finalize(); - h.get_digest(32, &h_to_hash[32]); + h.get_digest(32, ek_hash); - // (K, r) := G(H(pk), m) + byte G_input[64]; + memcpy(G_input, m, 32); + memcpy(&G_input[32], ek_hash, 32); + + // (K, r) := G(H(pk), m) byte K_r[64]; - if (!G(kem_ek_len + 32, h_to_hash, 64, K_r)) { + if (!G(64, G_input, 64, K_r)) { + printf("kyber_kem_encaps: Can't compute G\n"); return false; } + byte* pK= K_r; + byte* pr = &K_r[32]; - int len_c = 32*(p.du_* p.k_ + p.dv_); + int len_c = 32 * (p.du_* p.k_ + p.dv_); if (*kem_c_len < len_c) { + printf("kyber_kem_encaps: kem_c_len too small\n"); return false; } - byte c[len_c]; if (!kyber_encrypt(g, p, kem_ek_len, kem_ek, - 32, m, 32, &K_r[32], kem_c_len, kem_c)) { + 32, m, 32, pr, kem_c_len, kem_c)) { + printf("kyber_kem_encaps: kyber_encrypt failed\n"); return false; } - if (*kem_c_len < (len_c + 32)) { + if (*k_len < 32) { + printf("kyber_kem_encaps: k_len too small\n"); return false; } - *kem_c_len = 32 + len_c; - memcpy(kem_c, K_r, 32); - memcpy(&kem_c[32], c, len_c); + + memcpy(k, pK, 32); return true; } @@ -1651,29 +1665,38 @@ bool kyber_kem_encaps(int g, kyber_parameters& p, int kem_ek_len, byte* kem_ek, // return K bool kyber_kem_decaps(int g, kyber_parameters& p, int kem_dk_len, byte* kem_dk, int c_len, byte* c, int* k_len, byte* k) { - byte* dk = kem_dk; - byte* ek = &kem_dk[48]; - byte* h = &kem_dk[128]; - byte* z = &kem_dk[160]; + int ek_PKE_len = 384 * p.k_ + 32; + int dk_PKE_len = 384 * p.k_; + byte* dk_PKE = kem_dk; + byte* ek_PKE = &kem_dk[dk_PKE_len]; + byte* h = &kem_dk[ek_PKE_len + dk_PKE_len]; + byte* z = &kem_dk[ek_PKE_len + dk_PKE_len + 32]; int m_prime_len = 32; byte m_prime[m_prime_len]; - if (!kyber_decrypt(g, p, 48, dk, c_len, c, &m_prime_len, m_prime)) { + if (!kyber_decrypt(g, p, dk_PKE_len, dk_PKE, c_len, c, &m_prime_len, m_prime)) { + printf("kyber_kem_decaps: PKE decrypt failed\n"); return false; } - int ek_len = 80; - byte h_to_hash[ek_len + 32]; - memcpy(h_to_hash, m_prime, 32); - memcpy(&h_to_hash[32], h, 32); + + byte G_to_hash[64]; + memcpy(G_to_hash, m_prime, 32); + memcpy(&G_to_hash[32], h, 32); // (K_prime, r_prime) := G(H(pk), m) byte K_r_prime[64]; - if (!G(ek_len + 32, h_to_hash, 64, K_r_prime)) { + if (!G(64, G_to_hash, 64, K_r_prime)) { + printf("kyber_kem_decaps: G failed\n"); return false; } + byte* pK_prime = K_r_prime; + byte* pr_prime = &K_r_prime[32]; byte K_bar[32]; + + // J (shake-256) sha3 h_o; if (!h_o.init(512, 256)) { + printf("kyber_kem_decaps: shake-init failed\n"); return false; } h_o.add_to_hash(32, z); @@ -1681,13 +1704,15 @@ bool kyber_kem_decaps(int g, kyber_parameters& p, int kem_dk_len, byte* kem_dk, h_o.shake_finalize(); h_o.get_digest(32, K_bar); - int c_prime_len = 48 * p.k_; + int c_prime_len = 32 * (p.du_ * p.k_ + p.dv_); byte c_prime[c_prime_len]; - if (!kyber_encrypt(g, p, 80, ek, 32, m_prime, 32, - &K_r_prime[32], &c_prime_len, c_prime)) { + if (!kyber_encrypt(g, p, ek_PKE_len, ek_PKE, 32, m_prime, 32, + pr_prime, &c_prime_len, c_prime)) { + printf("kyber_kem_decaps: kyber_encrypt failed\n"); return false; } - if (memcmp(c, c_prime, c_prime_len) == 0) { + if (memcmp(c, c_prime, c_prime_len) != 0) { + printf("kyber_kem_decaps: c != c_prime\n"); return false; } memcpy(k, K_bar, 32); diff --git a/v2/kyber/test_kyber.cc b/v2/kyber/test_kyber.cc index 85debe2..46c97d4 100644 --- a/v2/kyber/test_kyber.cc +++ b/v2/kyber/test_kyber.cc @@ -98,37 +98,72 @@ bool test_kyber1() { printf("message and recovered message dont match\n"); return false; } -return true; int kem_ek_len = 384 * p.k_ + 32; byte kem_ek[kem_ek_len]; memset(kem_ek, 0, kem_ek_len); int kem_dk_len = 768 * p.k_ + 96; + byte kem_dk[kem_dk_len]; memset(kem_dk, 0, kem_dk_len); if (!kyber_kem_keygen(g, p, &kem_ek_len, kem_ek, &kem_dk_len, kem_dk)) { printf("Could not init kem_keygen\n"); return false; } + int kem_c_len = 32 * (p.du_ * p.k_ + p.dv_); byte kem_c[kem_dk_len]; memset(kem_c, 0, kem_c_len); + int kem_k_len = 32; byte kem_k[kem_k_len]; memset(kem_k, 0, kem_k_len); + + if (FLAGS_print_all) { + printf("\n\nken_keygen\n\n"); + printf("kem_ek (%d):\n", kem_ek_len); + print_bytes(kem_ek_len, kem_ek); + printf("\n"); + printf("kem_dk (%d):\n", kem_dk_len); + print_bytes(kem_dk_len, kem_dk); + printf("\n"); + } + if (!kyber_kem_encaps(g, p, kem_ek_len, kem_ek, &kem_k_len, kem_k, &kem_c_len, kem_c)) { printf("Could not init kem_encaps\n"); return false; } + + if (FLAGS_print_all) { + printf("\n\nken_encaps\n\n"); + printf("k (%d):\n", kem_k_len); + print_bytes(kem_ek_len, kem_ek); + printf("\n"); + printf("c (%d):\n", kem_c_len); + print_bytes(kem_c_len, kem_c); + printf("\n"); + } + int recovered_k_len = 32; byte recovered_k[recovered_k_len]; memset(recovered_k, 0, recovered_k_len); + if (!kyber_kem_decaps(g, p, kem_dk_len, kem_dk, kem_c_len, kem_c, &recovered_k_len, recovered_k)) { printf("Could not init kem_decaps\n"); return false; } + + if (FLAGS_print_all) { + printf("\n\nkem_decaps\n\n"); + printf("key (%d): ", kem_k_len); + print_bytes(kem_k_len, kem_k); + printf("recovered key (%d): ", recovered_k_len); + print_bytes(recovered_k_len, recovered_k); + } + return true; + if (memcmp(kem_k, recovered_k, recovered_k_len) != 0) { printf("Generated and encapsulated keys don't match\n"); return false;