diff --git a/v2/kyber/kyber.cc b/v2/kyber/kyber.cc index 5ac61bd..7846e52 100644 --- a/v2/kyber/kyber.cc +++ b/v2/kyber/kyber.cc @@ -1553,18 +1553,21 @@ bool kyber_decrypt(int g, kyber_parameters& p, int dk_len, byte* dk, // full bool kyber_kem_keygen(int g, kyber_parameters& p, int* kem_ek_len, byte* kem_ek, int* kem_dk_len, byte* kem_dk) { + byte z[32]; int n_b = crypto_get_random_bytes(32, z); if (n_b != 32) { - printf("kyber_kem_keygen crypto_get_random_bytes returne wrong nuber of bytes\n"); + printf("kyber_kem_keygen: crypto_get_random_bytes returne wrong nuber of bytes\n"); return false; } int dk_PKE_len = 384 * p.k_; byte dk_PKE[dk_PKE_len]; - if (! kyber_keygen(g, p, kem_ek_len, kem_ek, &dk_PKE_len, dk_PKE)) { + if (!kyber_keygen(g, p, kem_ek_len, kem_ek, &dk_PKE_len, dk_PKE)) { + printf("kyber_kem_keygen: kyber_keygen failed\n"); return false; } + int ek_PKE_len = 384 * p.k_ + 32; // kem_dk = dk_PKE || ek || H(ek) || z int len = 0; @@ -1580,13 +1583,32 @@ bool kyber_kem_keygen(int g, kyber_parameters& p, int* kem_ek_len, byte* kem_ek, h.add_to_hash(*kem_ek_len, kem_ek); h.finalize(); h.get_digest(32, &kem_dk[len]); - + byte* ph = &kem_dk[len]; len += 32; + + byte* pz = &kem_dk[len]; memcpy(&kem_dk[len], z, 32); + len += 32; + if (*kem_dk_len < len) { + printf("kyber_kem_keygen: dk_kem buffer too small\n"); return false; } *kem_dk_len = len; +#if 1 + printf("ek (%d) :\n", ek_PKE_len); + print_bytes(ek_PKE_len, kem_ek); + printf("\n"); + printf("dk_PKE (%d) :\n", dk_PKE_len); + print_bytes(dk_PKE_len, dk_PKE); + printf("\n"); + printf("h (%d) :\n", 32); + print_bytes(32, ph); + printf("\n"); + printf("z (%d) :\n", 32); + print_bytes(32, pz); + printf("\n"); +#endif return true; } @@ -1605,6 +1627,11 @@ bool kyber_kem_encaps(int g, kyber_parameters& p, int kem_ek_len, byte* kem_ek, return false; } +#ifdef LONG_DEBUG + printf("m: "); + print_bytes(32, m); +#endif + byte ek_hash[32]; memset(ek_hash, 0, 32); @@ -1623,7 +1650,7 @@ bool kyber_kem_encaps(int g, kyber_parameters& p, int kem_ek_len, byte* kem_ek, // (K, r) := G(H(pk), m) byte K_r[64]; - if (!G(64, G_input, 64, K_r)) { + if (!G(64, G_input, 512, K_r)) { printf("kyber_kem_encaps: Can't compute G\n"); return false; } @@ -1644,17 +1671,17 @@ bool kyber_kem_encaps(int g, kyber_parameters& p, int kem_ek_len, byte* kem_ek, printf("kyber_kem_encaps: k_len too small\n"); return false; } - memcpy(k, pK, 32); + return true; } // Kem.Decapsulate -// dk := dk[0:384k] = dk[0:48] (bytes) -// ek := dk[384k:768k + 32] = dk[48:128] (bytes) -// h := dk[768k + 32: 768k+64 = dk[128:160] (bytes)] -// z := dk[768k+64: 768k+96 = dk[160:192] (bytes) -// m' := Kyber.Dec(dk, c) +// dk_PKE := dk[0:384k] +// ek_PKE := dk[384k:768k + 32] +// h := dk[768k + 32: 768k+64] +// z := dk[768k+64: 768k+96] +// m' := Kyber.Dec(dk_PKE, c) // (K', r') := G(m'|| h) // K-bar = J(z||c, 32) // c' := Kyber.Enc(m',r') @@ -1672,6 +1699,18 @@ bool kyber_kem_decaps(int g, kyber_parameters& p, int kem_dk_len, byte* kem_dk, byte* h = &kem_dk[ek_PKE_len + dk_PKE_len]; byte* z = &kem_dk[ek_PKE_len + dk_PKE_len + 32]; +#ifdef LONG_DEBUG + printf("\nkey_decaps\n"); + printf("ek_PKE(%d):\n", ek_PKE_len); + print_bytes(ek_PKE_len, ek_PKE); + printf("dk_PKE(%d):\n", dk_PKE_len); + print_bytes(dk_PKE_len, dk_PKE); + printf("h (%d): ", 32); + print_bytes(32, h); + printf("z (%d): ", 32); + print_bytes(32, z); +#endif + int m_prime_len = 32; byte m_prime[m_prime_len]; if (!kyber_decrypt(g, p, dk_PKE_len, dk_PKE, c_len, c, &m_prime_len, m_prime)) { @@ -1679,13 +1718,18 @@ bool kyber_kem_decaps(int g, kyber_parameters& p, int kem_dk_len, byte* kem_dk, return false; } +#ifdef LONG_DEBUG + printf("recovered m: "); + print_bytes(32, m_prime); +#endif + byte G_to_hash[64]; memcpy(G_to_hash, m_prime, 32); memcpy(&G_to_hash[32], h, 32); // (K_prime, r_prime) := G(H(pk), m) byte K_r_prime[64]; - if (!G(64, G_to_hash, 64, K_r_prime)) { + if (!G(64, G_to_hash, 512, K_r_prime)) { printf("kyber_kem_decaps: G failed\n"); return false; } @@ -1693,6 +1737,11 @@ bool kyber_kem_decaps(int g, kyber_parameters& p, int kem_dk_len, byte* kem_dk, byte* pr_prime = &K_r_prime[32]; byte K_bar[32]; +#ifdef LONG_DEBUG + printf("K_prime: "); + print_bytes(32, pK_prime); +#endif + // J (shake-256) sha3 h_o; if (!h_o.init(512, 256)) { @@ -1715,7 +1764,7 @@ bool kyber_kem_decaps(int g, kyber_parameters& p, int kem_dk_len, byte* kem_dk, printf("kyber_kem_decaps: c != c_prime\n"); return false; } - memcpy(k, K_bar, 32); + memcpy(k, pK_prime, 32); *k_len = 32; return true; } diff --git a/v2/kyber/test_kyber.cc b/v2/kyber/test_kyber.cc index 46c97d4..5054a55 100644 --- a/v2/kyber/test_kyber.cc +++ b/v2/kyber/test_kyber.cc @@ -93,12 +93,17 @@ bool test_kyber1() { if (FLAGS_print_all) { printf("recovered m: "); print_bytes(recovered_m_len, recovered_m); + printf("\n\nkyber complete\n\n"); } if (memcmp(m, recovered_m, m_len) != 0) { printf("message and recovered message dont match\n"); return false; } + if (FLAGS_print_all) { + printf("\n\nkem\n\n"); + } + int kem_ek_len = 384 * p.k_ + 32; byte kem_ek[kem_ek_len]; memset(kem_ek, 0, kem_ek_len); @@ -120,7 +125,7 @@ bool test_kyber1() { memset(kem_k, 0, kem_k_len); if (FLAGS_print_all) { - printf("\n\nken_keygen\n\n"); + printf("\n\nkem_keygen\n\n"); printf("kem_ek (%d):\n", kem_ek_len); print_bytes(kem_ek_len, kem_ek); printf("\n"); @@ -136,9 +141,9 @@ bool test_kyber1() { } if (FLAGS_print_all) { - printf("\n\nken_encaps\n\n"); - printf("k (%d):\n", kem_k_len); - print_bytes(kem_ek_len, kem_ek); + printf("\n\nkem_encaps\n\n"); + printf("k (%d): ", kem_k_len); + print_bytes(kem_k_len, kem_k); printf("\n"); printf("c (%d):\n", kem_c_len); print_bytes(kem_c_len, kem_c); @@ -162,7 +167,6 @@ bool test_kyber1() { printf("recovered key (%d): ", recovered_k_len); print_bytes(recovered_k_len, recovered_k); } - return true; if (memcmp(kem_k, recovered_k, recovered_k_len) != 0) { printf("Generated and encapsulated keys don't match\n");