From 6ded699524b109b6f62aa27451b3a91a77577646 Mon Sep 17 00:00:00 2001 From: Tiago Oliveira Date: Fri, 9 Feb 2024 22:42:05 +0100 Subject: [PATCH] Revert "mlkem: avx2: update keccakf1600 implementation" This reverts commit 2ac8b2e1ee71030345b9b1b772f848a7f25b49d2. --- .../mlkem/mlkem768/amd64/avx2/fips202.jinc | 30 +-- .../mlkem768/amd64/avx2/keccakf1600.jinc | 230 ++++++++---------- 2 files changed, 112 insertions(+), 148 deletions(-) diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/fips202.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/fips202.jinc index 178c5c02..8fce754e 100644 --- a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/fips202.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/fips202.jinc @@ -162,7 +162,7 @@ fn __keccak1600_scalar( s_inlen = inlen; s_rate = rate; - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); inlen = s_inlen; in = s_in; @@ -180,7 +180,7 @@ fn __keccak1600_scalar( s_outlen = outlen; s_rate = rate; - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); out = s_out; outlen = s_outlen; @@ -191,7 +191,7 @@ fn __keccak1600_scalar( s_out = out; } - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); out = s_out; outlen = s_outlen; @@ -267,7 +267,7 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] s_ilen = ilen; s_r8 = r8; - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); in = s_in; ilen = s_ilen; @@ -277,7 +277,7 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] t8 = 0x06; state = __add_final_block(state, in, ilen, t8, r8); - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); out = s_out; @@ -315,7 +315,7 @@ fn __isha3_512(reg ptr u8[64] out, reg u64 in, inline int inlen) -> stack u8[64] s_ilen = ilen; s_r8 = r8; - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); in = s_in; ilen = s_ilen; @@ -325,7 +325,7 @@ fn __isha3_512(reg ptr u8[64] out, reg u64 in, inline int inlen) -> stack u8[64] t8 = 0x06; state = __add_final_block(state, in, ilen, t8, r8); - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); out = s_out; @@ -361,7 +361,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { s_in = in1; - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); r8 = SHAKE256_RATE; ilen = MLKEM_INDCPA_CIPHERTEXTBYTES - (SHAKE256_RATE - MLKEM_SYMBYTES); @@ -376,7 +376,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { s_ilen = ilen; s_r8 = r8; - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); in = s_in; ilen = s_ilen; @@ -386,7 +386,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { t8 = 0x1f; state = __add_final_block(state, in, ilen, t8, r8); - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); out = s_out; @@ -422,7 +422,7 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12 state[u8 33] ^= 0x1f; state[u8 SHAKE256_RATE-1] ^= 0x80; - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); out = sout; @@ -455,7 +455,7 @@ fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u state[u8 MLKEM_SYMBYTES] ^= 0x06; state[u8 SHA3_256_RATE - 1] = 0x80; - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); out = s_out; @@ -489,7 +489,7 @@ fn _sha3_512_64(reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64] out_s = out; - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); out = out_s; @@ -523,7 +523,7 @@ fn _sha3_512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64] out_s = out; - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); out = out_s; @@ -567,7 +567,7 @@ fn _shake128_squeezeblock(reg ptr u64[25] state, reg ptr u8[SHAKE128_RATE] out) inline int i; out_s = out; - state = _keccakf1600_(state); + state = _keccakf1600_scalar(state); out = out_s; for i = 0 to SHAKE128_RATE/8 diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccakf1600.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccakf1600.jinc index 757a493b..02996b6a 100644 --- a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccakf1600.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccakf1600.jinc @@ -1,6 +1,4 @@ -param int KECCAK_ROUNDS = 24; - -u64[24] KECCAK1600_RC = +u64[24] KECCAK_RC = { 0x0000000000000001 ,0x0000000000008082 ,0x800000000000808a @@ -27,14 +25,14 @@ u64[24] KECCAK1600_RC = ,0x8000000080008008 }; -inline fn keccakf1600_index(inline int x y) -> inline int +inline fn __index(inline int x y) -> inline int { inline int r; r = (x % 5) + 5 * (y % 5); return r; } -inline fn keccakf1600_rho_offsets(inline int i) -> inline int +inline fn __keccak_rho_offsets(inline int i) -> inline int { inline int r x y z t; @@ -42,9 +40,10 @@ inline fn keccakf1600_rho_offsets(inline int i) -> inline int x = 1; y = 0; - for t = 0 to 24 - { if (i == x + 5 * y) - { r = ((t + 1) * (t + 2) / 2) % 64; } + for t = 0 to 24 { + if (i == x + 5 * y) { + r = ((t + 1) * (t + 2) / 2) % 64; + } z = (2 * x + 3 * y) % 5; x = y; y = z; @@ -53,178 +52,143 @@ inline fn keccakf1600_rho_offsets(inline int i) -> inline int return r; } -inline fn keccakf1600_rhotates(inline int x y) -> inline int +inline fn __rhotates(inline int x y) -> inline int { inline int i r; - i = keccakf1600_index(x, y); - r = keccakf1600_rho_offsets(i); + i = __index(x, y); + r = __keccak_rho_offsets(i); return r; } -// C[x] = A[x,0] ^ A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] -inline fn keccakf1600_theta_sum(reg ptr u64[25] a) -> reg u64[5] +inline fn __theta_sum_scalar(reg ptr u64[25] a) -> reg u64[5] { - inline int x y; + inline int i j ti; reg u64[5] c; - // C[x] = A[x, 0] - for x=0 to 5 - { c[x] = a[x + 0]; } + for i=0 to 5 + { + ti = __index(i, 0); + c[i] = a[ti]; + } - // C[x] ^= A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] - for y=1 to 5 - { for x=0 to 5 - { c[x] ^= a[x + y*5]; } + for j=1 to 5 + { for i=0 to 5 + { + ti = __index(i, j); + c[i] ^= a[ti]; + } } return c; } -// D[x] = C[x-1] ^ ROT(C[x+1], 1) -inline fn keccakf1600_theta_rol(reg u64[5] c) -> reg u64[5] +inline fn __theta_rol_scalar(reg u64[5] c) -> reg u64[5] { - inline int x; + inline int i; reg u64[5] d; - for x = 0 to 5 - { // D[x] = C[x + 1] - d[x] = c[(x + 1) % 5]; - - // D[x] = ROT(D[x], 1) - _, _, d[x] = #ROL_64(d[x], 1); - - // D[x] ^= C[x-1] - d[x] ^= c[(x - 1 + 5) % 5]; + for i = 0 to 5 + { d[i] = c[(i+1)%5]; + _, _, d[i] = #ROL_64(d[i], 1); + d[i] ^= c[(i+4)%5]; } return d; } -// B[x] = ROT( (A[x',y'] ^ D[x']), r[x',y'] ) with (x',y') = M^-1 (x,y) -// -// M = (0 1) M^-1 = (1 3) x' = 1x + 3y -// (2 3) (1 0) y' = 1x + 0y -// -inline fn keccakf1600_rol_sum( - reg ptr u64[25] a, +inline fn __rol_sum_scalar( reg u64[5] d, - inline int y) - -> - reg u64[5] + reg ptr u64[25] a, + inline int offset +) -> reg u64[5] { - inline int r x x_ y_; - reg u64[5] b; + inline int j j1 k ti; + reg u64[5] c; - for x = 0 to 5 + for j = 0 to 5 { - x_ = (x + 3*y) % 5; - y_ = x; - r = keccakf1600_rhotates(x_, y_); - - // B[x] = A[x',y'] - b[x] = a[x_ + y_*5]; - - // B[x] ^= D[x']; - b[x] ^= d[x_]; - - // B[x] = ROT( B[x], r[x',y'] ); - if(r != 0) - { _, _, b[x] = #ROL_64(b[x], r); } - + j1 = (j+offset) % 5; + k = __rhotates(j1, j); + ti = __index(j1, j); + c[j] = a[ti]; + c[j] ^= d[j1]; + _, _, c[j] = #ROL_64(c[j], k); } - return b; + return c; } -// E[x, y] = B[x] ^ ( (!B[x+1]) & B[x+2] ) -// -- when x and y are 0: E[0,0] ^= RC[i]; -inline fn keccakf1600_set_row( - reg ptr u64[25] e, - reg u64[5] b, - inline int y, - stack u64 s_rc) - -> - reg ptr u64[25] +inline fn __set_row_scalar( + reg ptr u64[25] r, + inline int row, + reg u64[5] c, + reg u64 iota +) -> reg ptr u64[25] { - inline int x x1 x2; + inline int j j1 j2 ti; reg u64 t; - for x=0 to 5 + for j= 0 to 5 { - x1 = (x + 1) % 5; - x2 = (x + 2) % 5; - - // t = !b[x1] & b[x2]; // bmi1 - t = b[x1]; t = !t; t &= b[x2]; - - t ^= b[x]; - if( x==0 && y==0 ){ t ^= s_rc; } - e[x + y*5] = t; + j1 = (j+1) % 5; + j2 = (j+2) % 5; + t = !c[j1] & c[j2]; + if row==0 && j==0 { t ^= iota; } + t ^= c[j]; + ti = __index(j, row); + r[ti] = t; } - return e; + return r; } -inline fn keccakf1600_round( - reg ptr u64[25] e, - reg ptr u64[25] a, - reg u64 rc) - -> - reg ptr u64[25] +inline fn __round2x_scalar(reg ptr u64[25] a r, reg u64 iota) -> reg ptr u64[25], reg ptr u64[25] { - inline int y; - reg u64[5] b c d; - stack u64 s_rc; - - s_rc = rc; - - c = keccakf1600_theta_sum(a); - d = keccakf1600_theta_rol(c); - - for y = 0 to 5 - { b = keccakf1600_rol_sum(a, d, y); - e = keccakf1600_set_row(e, b, y, s_rc); - } - - return e; + reg u64[5] c d; + + c = __theta_sum_scalar(a); + d = __theta_rol_scalar(c); + c = __rol_sum_scalar(d, a, 0); + r = __set_row_scalar(r, 0, c, iota); + c = __rol_sum_scalar(d, a, 3); + r = __set_row_scalar(r, 1, c, iota); + c = __rol_sum_scalar(d, a, 1); + r = __set_row_scalar(r, 2, c, iota); + c = __rol_sum_scalar(d, a, 4); + r = __set_row_scalar(r, 3, c, iota); + c = __rol_sum_scalar(d, a, 2); + r = __set_row_scalar(r, 4, c, iota); + + return a, r; } -inline fn __keccakf1600(reg ptr u64[25] a) -> reg ptr u64[25] +#[returnaddress="stack"] +fn _keccakf1600_scalar(reg ptr u64[25] a) -> reg ptr u64[25] { - reg ptr u64[24] RC; - stack u64[25] s_e; - reg ptr u64[25] e; + stack u64[25] r; + reg ptr u64[24] iotas_p; + reg u64 iota; + reg u64 round; + stack u64 round_s; - reg u64 c rc; + iotas_p = KECCAK_RC; - RC = KECCAK1600_RC; - e = s_e; + round = 0; - c = 0; - while (c < KECCAK_ROUNDS) + while(round < 24) { - rc = RC[(int) c]; - e = keccakf1600_round(e, a, rc); - - rc = RC[(int) c + 1]; - a = keccakf1600_round(a, e, rc); - - c += 2; + iota = iotas_p[(int) round]; + round_s = round; + a, r = __round2x_scalar(a, r, iota); + round = round_s; + round += 1; + + iota = iotas_p[(int) round]; + round_s = round; + r, a = __round2x_scalar(r, a, iotas_p[(int) round]); + round = round_s; + round += 1; } return a; } - -fn _keccakf1600(reg ptr u64[25] a) -> reg ptr u64[25] -{ - a = __keccakf1600(a); - return a; -} - -inline fn _keccakf1600_(reg ptr u64[25] a) -> reg ptr u64[25] -{ - a = a; - a = _keccakf1600(a); - a = a; - return a; -}