From b53d83793fee8303cf2cfbd62171911741f327b0 Mon Sep 17 00:00:00 2001 From: Tiago Oliveira Date: Fri, 9 Feb 2024 22:24:03 +0100 Subject: [PATCH] mlkem: ref: update keccakf1600 implementation in fips202 --- .../mlkem/mlkem768/amd64/ref/fips202.jinc | 338 +++++++++++------- .../mlkem/mlkem768/amd64/ref/indcpa.jinc | 27 ++ .../mlkem/mlkem768/amd64/ref/kem.jinc | 5 +- 3 files changed, 239 insertions(+), 131 deletions(-) diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/ref/fips202.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/ref/fips202.jinc index 0ca1e83a..793fe166 100644 --- a/src/crypto_kem/mlkem/mlkem768/amd64/ref/fips202.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/ref/fips202.jinc @@ -3,167 +3,236 @@ param int SHAKE256_RATE = 136; param int SHA3_256_RATE = 136; param int SHA3_512_RATE = 72; -inline -fn __index(inline int x, inline int y) -> inline int { +param int KECCAK_ROUNDS = 24; + +u64[24] KECCAK1600_RC = +{ 0x0000000000000001 + ,0x0000000000008082 + ,0x800000000000808a + ,0x8000000080008000 + ,0x000000000000808b + ,0x0000000080000001 + ,0x8000000080008081 + ,0x8000000000008009 + ,0x000000000000008a + ,0x0000000000000088 + ,0x0000000080008009 + ,0x000000008000000a + ,0x000000008000808b + ,0x800000000000008b + ,0x8000000000008089 + ,0x8000000000008003 + ,0x8000000000008002 + ,0x8000000000000080 + ,0x000000000000800a + ,0x800000008000000a + ,0x8000000080008081 + ,0x8000000000008080 + ,0x0000000080000001 + ,0x8000000080008008 +}; + +inline fn keccakf1600_index(inline int x y) -> inline int +{ inline int r; r = (x % 5) + 5 * (y % 5); return r; } +inline fn keccakf1600_rho_offsets(inline int i) -> inline int +{ + inline int r x y z t; -inline -fn __ROL64(reg u64 x, inline int c) -> reg u64 { - reg u64 y; - _, _, y = #ROL_64(x, c); - return y; + r = 0; + x = 1; + y = 0; + + for t = 0 to 24 + { if (i == x + 5 * y) + { r = ((t + 1) * (t + 2) / 2) % 64; } + z = (2 * x + 3 * y) % 5; + x = y; + y = z; + } + + return r; } -inline -fn __theta(reg ptr u64[25] a) -> reg ptr u64[25] { - inline int x, y; - reg u64[5] c, d; +inline fn keccakf1600_rhotates(inline int x y) -> inline int +{ + inline int i r; + i = keccakf1600_index(x, y); + r = keccakf1600_rho_offsets(i); + return r; +} - for x = 0 to 5 { - c[x] = 0; - for y = 0 to 5 { - c[x] ^= a[x + 5 * y]; - } +// C[x] = A[x,0] ^ A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] +inline fn keccakf1600_theta_sum(reg ptr u64[25] a) -> reg u64[5] +{ + inline int x y; + reg u64[5] c; + + // C[x] = A[x, 0] + for x=0 to 5 + { c[x] = a[x + 0]; } + + // C[x] ^= A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] + for y=1 to 5 + { for x=0 to 5 + { c[x] ^= a[x + y*5]; } } - for x = 0 to 5 { - /* d[x] = __ROL64(c[(x + 1) % 5], 1); */ - /* extraction fails */ + return c; +} - /* _, _, d[x] = #ROL_64(c[(x + 1) % 5], 1);*/ - /* d[x] ^= c[(x + 4) % 5];*/ - /* does not compile */ +// D[x] = C[x-1] ^ ROT(C[x+1], 1) +inline fn keccakf1600_theta_rol(reg u64[5] c) -> reg u64[5] +{ + inline int x; + reg u64[5] d; + for x = 0 to 5 + { // D[x] = C[x + 1] d[x] = c[(x + 1) % 5]; + + // D[x] = ROT(D[x], 1) _, _, d[x] = #ROL_64(d[x], 1); - d[x] ^= c[(x + 4) % 5]; - } - for x = 0 to 5 { - for y = 0 to 5 { - a[x + 5 * y] ^= d[x]; - } + // D[x] ^= C[x-1] + d[x] ^= c[(x - 1 + 5) % 5]; } - return a; + return d; } +// B[x] = ROT( (A[x',y'] ^ D[x']), r[x',y'] ) with (x',y') = M^-1 (x,y) +// +// M = (0 1) M^-1 = (1 3) x' = 1x + 3y +// (2 3) (1 0) y' = 1x + 0y +// +inline fn keccakf1600_rol_sum( + reg ptr u64[25] a, + reg u64[5] d, + inline int y) + -> + reg u64[5] +{ + inline int r x x_ y_; + reg u64[5] b; -inline -fn __keccakRhoOffsets(inline int i) -> inline int { - inline int r, x, y, z, t; - - r = 0; - x = 1; - y = 0; - for t = 0 to 24 { - if (i == x + 5 * y) { - r = ((t + 1) * (t + 2) / 2) % 64; - } - z = (2 * x + 3 * y) % 5; - x = y; - y = z; - } + for x = 0 to 5 + { + x_ = (x + 3*y) % 5; + y_ = x; + r = keccakf1600_rhotates(x_, y_); - return r; -} + // B[x] = A[x',y'] + b[x] = a[x_ + y_*5]; + // B[x] ^= D[x']; + b[x] ^= d[x_]; -inline -fn __rho(reg ptr u64[25] a) -> reg ptr u64[25] { - inline int x, y, i, z; + // B[x] = ROT( B[x], r[x',y'] ); + if(r != 0) + { _, _, b[x] = #ROL_64(b[x], r); } - for x = 0 to 5 { - for y = 0 to 5 { - i = __index(x, y); - z = __keccakRhoOffsets(i); - _, _, a[i] = #ROL_64(a[i], z); - } } - return a; + return b; } - -inline -fn __pi(reg ptr u64[25] a) -> reg ptr u64[25] { - stack u64[25] b; +// E[x, y] = B[x] ^ ( (!B[x+1]) & B[x+2] ) +// -- when x and y are 0: E[0,0] ^= RC[i]; +inline fn keccakf1600_set_row( + reg ptr u64[25] e, + reg u64[5] b, + inline int y, + stack u64 s_rc) + -> + reg ptr u64[25] +{ + inline int x x1 x2; reg u64 t; - inline int x, y, i; - for i = 0 to 25 { t = a[i]; b[i] = t; } - for x = 0 to 5 { - for y = 0 to 5 { - t = b[x + 5 * y]; - i = __index(y, 2 * x + 3 * y); - a[i] = t; - } - } - return a; -} + for x=0 to 5 + { + x1 = (x + 1) % 5; + x2 = (x + 2) % 5; + // t = !b[x1] & b[x2]; // bmi1 + t = b[x1]; t = !t; t &= b[x2]; -inline -fn __chi(reg ptr u64[25] a) -> reg ptr u64[25] { - inline int x, y, i; - reg u64[5] c; - for y = 0 to 5 { - for x = 0 to 5 { - i = __index(x + 1, y); - c[x] = a[i]; - c[x] = !c[x]; - i = __index(x + 2, y); - c[x] &= a[i]; - i = __index(x, y); - c[x] ^= a[i]; - } - for x = 0 to 5 { - a[x + 5 * y] = c[x]; - } + t ^= b[x]; + if( x==0 && y==0 ){ t ^= s_rc; } + e[x + y*5] = t; } - return a; + + return e; } +inline fn keccakf1600_round( + reg ptr u64[25] e, + reg ptr u64[25] a, + reg u64 rc) + -> + reg ptr u64[25] +{ + inline int y; + reg u64[5] b c d; + stack u64 s_rc; + + s_rc = rc; -inline -fn __iota(reg ptr u64[25] a, reg u64 c) -> reg ptr u64[25] { - a[0] ^= c; - return a; + c = keccakf1600_theta_sum(a); + d = keccakf1600_theta_rol(c); + + for y = 0 to 5 + { b = keccakf1600_rol_sum(a, d, y); + e = keccakf1600_set_row(e, b, y, s_rc); + } + + return e; } -u64[24] roundconstants = {0x0000000000000001, 0x0000000000008082, 0x800000000000808a, 0x8000000080008000, - 0x000000000000808b, 0x0000000080000001, 0x8000000080008081, 0x8000000000008009, - 0x000000000000008a, 0x0000000000000088, 0x0000000080008009, 0x000000008000000a, - 0x000000008000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003, - 0x8000000000008002, 0x8000000000000080, 0x000000000000800a, 0x800000008000000a, - 0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008}; +inline fn __keccakf1600(reg ptr u64[25] a) -> reg ptr u64[25] +{ + reg ptr u64[24] RC; + stack u64[25] s_e; + reg ptr u64[25] e; + reg u64 c rc; -fn __keccakf1600_ref(reg ptr u64[25] state) -> reg ptr u64[25] { - reg ptr u64[24] constptr; + RC = KECCAK1600_RC; + e = s_e; - reg u64 rctr; - - constptr = roundconstants; - rctr = 0; + c = 0; + while (c < KECCAK_ROUNDS) + { + rc = RC[(int) c]; + e = keccakf1600_round(e, a, rc); - while (rctr < 192) { - state = __theta(state); - state = __rho(state); - state = __pi(state); - state = __chi(state); - constptr = roundconstants; - state = __iota(state, constptr.[(int)rctr]); - rctr += 8; + rc = RC[(int) c + 1]; + a = keccakf1600_round(a, e, rc); + + c += 2; } - return state; + return a; } +fn _keccakf1600(reg ptr u64[25] a) -> reg ptr u64[25] +{ + a = __keccakf1600(a); + return a; +} + +inline fn _keccakf1600_(reg ptr u64[25] a) -> reg ptr u64[25] +{ + a = a; + a = _keccakf1600(a); + a = a; + return a; +} inline fn __st0(reg ptr u64[25] state) -> reg ptr u64[25] @@ -325,7 +394,7 @@ fn ____keccak1600_ref( s_inlen = inlen; s_rate = rate; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); inlen = s_inlen; in = s_in; @@ -343,7 +412,7 @@ fn ____keccak1600_ref( s_outlen = outlen; s_rate = rate; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = s_out; outlen = s_outlen; @@ -354,7 +423,7 @@ fn ____keccak1600_ref( s_out = out; } - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = s_out; outlen = s_outlen; @@ -414,7 +483,7 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12 state[u8 33] ^= 0x1f; state[u8 SHAKE256_RATE-1] ^= 0x80; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = sout; @@ -427,13 +496,14 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12 fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { stack u64[25] state; - stack u64 s_out; + stack u64 s_out s_in1; stack u64 s_in s_ilen s_r8; reg u64 ilen r8 t64 in; reg u8 t8; inline int i; s_out = out; + s_in1 = in1; state = __st0(state); for i = 0 to MLKEM_SYMBYTES/8 { @@ -446,11 +516,11 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { state[u64 i] ^= t64; } - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); r8 = SHAKE256_RATE; ilen = MLKEM_CT_LEN - (SHAKE256_RATE - MLKEM_SYMBYTES); - in = in1; + in = s_in1; in += SHAKE256_RATE - MLKEM_SYMBYTES; while(ilen >= r8) @@ -461,7 +531,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { s_ilen = ilen; s_r8 = r8; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); in = s_in; ilen = s_ilen; @@ -471,7 +541,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { t8 = 0x1f; state = __add_final_block(state, in, ilen, t8, r8); - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = s_out; @@ -489,6 +559,9 @@ fn _sha3512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64] stack u64[25] state; reg u8 c; inline int i; + stack ptr u8[64] s_out; + + s_out = out; state = __st0(state); @@ -499,8 +572,9 @@ fn _sha3512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64] state[u8 32] ^= 0x06; state[u8 SHA3_512_RATE-1] ^= 0x80; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); + out = s_out; for i = 0 to 64 { c = state[u8 (int) i]; out[i] = c; @@ -529,12 +603,16 @@ fn _shake128_absorb34(reg ptr u64[25] state, reg const ptr u8[34] in) -> reg ptr fn _shake128_squeezeblock(reg ptr u64[25] state, reg ptr u8[SHAKE128_RATE] out) -> reg ptr u64[25], reg ptr u8[SHAKE128_RATE] { + stack ptr u8[SHAKE128_RATE] s_out; reg u8 c; inline int i; - state = __keccakf1600_ref(state); + s_out = out; + + state = _keccakf1600_(state); - for i = 0 to SHAKE128_RATE { + out = s_out; + for i = 0 to SHAKE128_RATE { // SHAKE128 rate is 168: or 21 u64: TODO: 'compress' this for loop c = state[u8 (int) i]; out[i] = c; } @@ -567,7 +645,7 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] s_ilen = ilen; s_r8 = r8; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); in = s_in; ilen = s_ilen; @@ -577,7 +655,7 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] t8 = 0x06; state = __add_final_block(state, in, ilen, t8, r8); - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = s_out; @@ -611,7 +689,7 @@ fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u state[u8 MLKEM_SYMBYTES] ^= 0x06; state[u8 SHA3_256_RATE - 1] = 0x80; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = s_out; @@ -645,7 +723,7 @@ fn _sha3_512_64(reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64] out_s = out; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = out_s; diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/ref/indcpa.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/ref/indcpa.jinc index 5e959a51..b20da783 100644 --- a/src/crypto_kem/mlkem/mlkem768/amd64/ref/indcpa.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/ref/indcpa.jinc @@ -93,6 +93,9 @@ fn __indcpa_enc(stack u64 sctp, reg ptr u8[32] msgp, reg u64 pkp, reg ptr u8[MLK reg u64 i t64; reg u64 ctp; reg u8 nonce; + stack ptr u8[MLKEM_SYMBYTES] s_noiseseed; + + s_noiseseed = noiseseed; pkpv = __polyvec_frombytes(pkp); @@ -110,20 +113,31 @@ fn __indcpa_enc(stack u64 sctp, reg ptr u8[32] msgp, reg u64 pkp, reg ptr u8[MLK aat = __gen_matrix(publicseed, 1); + noiseseed = s_noiseseed; nonce = 0; sp[0:MLKEM_N] = _poly_getnoise(sp[0:MLKEM_N], noiseseed, nonce); + + noiseseed = s_noiseseed; nonce = 1; sp[MLKEM_N:MLKEM_N] = _poly_getnoise(sp[MLKEM_N:MLKEM_N], noiseseed, nonce); + + noiseseed = s_noiseseed; nonce = 2; sp[2*MLKEM_N:MLKEM_N] = _poly_getnoise(sp[2*MLKEM_N:MLKEM_N], noiseseed, nonce); + noiseseed = s_noiseseed; nonce = 3; ep[0:MLKEM_N] = _poly_getnoise(ep[0:MLKEM_N], noiseseed, nonce); + + noiseseed = s_noiseseed; nonce = 4; ep[MLKEM_N:MLKEM_N] = _poly_getnoise(ep[MLKEM_N:MLKEM_N], noiseseed, nonce); + + noiseseed = s_noiseseed; nonce = 5; ep[2*MLKEM_N:MLKEM_N] = _poly_getnoise(ep[2*MLKEM_N:MLKEM_N], noiseseed, nonce); + noiseseed = s_noiseseed; nonce = 6; epp = _poly_getnoise(epp, noiseseed, nonce); @@ -160,7 +174,9 @@ fn __iindcpa_enc(reg ptr u8[MLKEM_CT_LEN] ctp, reg ptr u8[32] msgp, reg u64 pkp, reg u64 i t64; reg u8 nonce; stack ptr u8[MLKEM_CT_LEN] sctp; + stack ptr u8[MLKEM_SYMBYTES] s_noiseseed; + s_noiseseed = noiseseed; sctp = ctp; pkpv = __polyvec_frombytes(pkp); @@ -179,20 +195,31 @@ fn __iindcpa_enc(reg ptr u8[MLKEM_CT_LEN] ctp, reg ptr u8[32] msgp, reg u64 pkp, aat = __gen_matrix(publicseed, 1); + noiseseed = s_noiseseed; nonce = 0; sp[0:MLKEM_N] = _poly_getnoise(sp[0:MLKEM_N], noiseseed, nonce); + + noiseseed = s_noiseseed; nonce = 1; sp[MLKEM_N:MLKEM_N] = _poly_getnoise(sp[MLKEM_N:MLKEM_N], noiseseed, nonce); + + noiseseed = s_noiseseed; nonce = 2; sp[2*MLKEM_N:MLKEM_N] = _poly_getnoise(sp[2*MLKEM_N:MLKEM_N], noiseseed, nonce); + noiseseed = s_noiseseed; nonce = 3; ep[0:MLKEM_N] = _poly_getnoise(ep[0:MLKEM_N], noiseseed, nonce); + + noiseseed = s_noiseseed; nonce = 4; ep[MLKEM_N:MLKEM_N] = _poly_getnoise(ep[MLKEM_N:MLKEM_N], noiseseed, nonce); + + noiseseed = s_noiseseed; nonce = 5; ep[2*MLKEM_N:MLKEM_N] = _poly_getnoise(ep[2*MLKEM_N:MLKEM_N], noiseseed, nonce); + noiseseed = s_noiseseed; nonce = 6; epp = _poly_getnoise(epp, noiseseed, nonce); diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/ref/kem.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/ref/kem.jinc index 4795a352..ee8c60ea 100644 --- a/src/crypto_kem/mlkem/mlkem768/amd64/ref/kem.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/ref/kem.jinc @@ -98,7 +98,7 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) { stack u8[MLKEM_CT_LEN] ctpc; stack u8[2*MLKEM_SYMBYTES] kr buf; - stack u64 s_skp s_ctp s_shkp; + stack u64 s_skp s_ctp s_shkp s_cnd; reg u64 pkp hp zp t64 cnd; inline int i; @@ -127,14 +127,17 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) ctp = s_ctp; cnd = __verify(ctp, ctpc); + s_cnd = cnd; zp = s_skp; zp += 64; zp += 24 * MLKEM_K * MLKEM_N>>3; /* fixme: should this be done in memory? */ + shkp = s_shkp; _shake256_1120_32(shkp, zp, ctp); shkp = s_shkp; + cnd = s_cnd; __cmov(shkp, kr[0:MLKEM_SYMBYTES], cnd); }