From 2ac8b2e1ee71030345b9b1b772f848a7f25b49d2 Mon Sep 17 00:00:00 2001 From: Tiago Oliveira Date: Fri, 9 Feb 2024 22:36:11 +0100 Subject: [PATCH] mlkem: avx2: update keccakf1600 implementation --- .../mlkem/mlkem768/amd64/avx2/fips202.jinc | 30 +-- .../mlkem768/amd64/avx2/keccakf1600.jinc | 230 ++++++++++-------- 2 files changed, 148 insertions(+), 112 deletions(-) diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/fips202.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/fips202.jinc index 8fce754e..178c5c02 100644 --- a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/fips202.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/fips202.jinc @@ -162,7 +162,7 @@ fn __keccak1600_scalar( s_inlen = inlen; s_rate = rate; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); inlen = s_inlen; in = s_in; @@ -180,7 +180,7 @@ fn __keccak1600_scalar( s_outlen = outlen; s_rate = rate; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = s_out; outlen = s_outlen; @@ -191,7 +191,7 @@ fn __keccak1600_scalar( s_out = out; } - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = s_out; outlen = s_outlen; @@ -267,7 +267,7 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] s_ilen = ilen; s_r8 = r8; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); in = s_in; ilen = s_ilen; @@ -277,7 +277,7 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] t8 = 0x06; state = __add_final_block(state, in, ilen, t8, r8); - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = s_out; @@ -315,7 +315,7 @@ fn __isha3_512(reg ptr u8[64] out, reg u64 in, inline int inlen) -> stack u8[64] s_ilen = ilen; s_r8 = r8; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); in = s_in; ilen = s_ilen; @@ -325,7 +325,7 @@ fn __isha3_512(reg ptr u8[64] out, reg u64 in, inline int inlen) -> stack u8[64] t8 = 0x06; state = __add_final_block(state, in, ilen, t8, r8); - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = s_out; @@ -361,7 +361,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { s_in = in1; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); r8 = SHAKE256_RATE; ilen = MLKEM_INDCPA_CIPHERTEXTBYTES - (SHAKE256_RATE - MLKEM_SYMBYTES); @@ -376,7 +376,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { s_ilen = ilen; s_r8 = r8; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); in = s_in; ilen = s_ilen; @@ -386,7 +386,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { t8 = 0x1f; state = __add_final_block(state, in, ilen, t8, r8); - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = s_out; @@ -422,7 +422,7 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12 state[u8 33] ^= 0x1f; state[u8 SHAKE256_RATE-1] ^= 0x80; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = sout; @@ -455,7 +455,7 @@ fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u state[u8 MLKEM_SYMBYTES] ^= 0x06; state[u8 SHA3_256_RATE - 1] = 0x80; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = s_out; @@ -489,7 +489,7 @@ fn _sha3_512_64(reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64] out_s = out; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = out_s; @@ -523,7 +523,7 @@ fn _sha3_512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64] out_s = out; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = out_s; @@ -567,7 +567,7 @@ fn _shake128_squeezeblock(reg ptr u64[25] state, reg ptr u8[SHAKE128_RATE] out) inline int i; out_s = out; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = out_s; for i = 0 to SHAKE128_RATE/8 diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccakf1600.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccakf1600.jinc index 02996b6a..757a493b 100644 --- a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccakf1600.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccakf1600.jinc @@ -1,4 +1,6 @@ -u64[24] KECCAK_RC = +param int KECCAK_ROUNDS = 24; + +u64[24] KECCAK1600_RC = { 0x0000000000000001 ,0x0000000000008082 ,0x800000000000808a @@ -25,14 +27,14 @@ u64[24] KECCAK_RC = ,0x8000000080008008 }; -inline fn __index(inline int x y) -> inline int +inline fn keccakf1600_index(inline int x y) -> inline int { inline int r; r = (x % 5) + 5 * (y % 5); return r; } -inline fn __keccak_rho_offsets(inline int i) -> inline int +inline fn keccakf1600_rho_offsets(inline int i) -> inline int { inline int r x y z t; @@ -40,10 +42,9 @@ inline fn __keccak_rho_offsets(inline int i) -> inline int x = 1; y = 0; - for t = 0 to 24 { - if (i == x + 5 * y) { - r = ((t + 1) * (t + 2) / 2) % 64; - } + for t = 0 to 24 + { if (i == x + 5 * y) + { r = ((t + 1) * (t + 2) / 2) % 64; } z = (2 * x + 3 * y) % 5; x = y; y = z; @@ -52,143 +53,178 @@ inline fn __keccak_rho_offsets(inline int i) -> inline int return r; } -inline fn __rhotates(inline int x y) -> inline int +inline fn keccakf1600_rhotates(inline int x y) -> inline int { inline int i r; - i = __index(x, y); - r = __keccak_rho_offsets(i); + i = keccakf1600_index(x, y); + r = keccakf1600_rho_offsets(i); return r; } -inline fn __theta_sum_scalar(reg ptr u64[25] a) -> reg u64[5] +// C[x] = A[x,0] ^ A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] +inline fn keccakf1600_theta_sum(reg ptr u64[25] a) -> reg u64[5] { - inline int i j ti; + inline int x y; reg u64[5] c; - for i=0 to 5 - { - ti = __index(i, 0); - c[i] = a[ti]; - } + // C[x] = A[x, 0] + for x=0 to 5 + { c[x] = a[x + 0]; } - for j=1 to 5 - { for i=0 to 5 - { - ti = __index(i, j); - c[i] ^= a[ti]; - } + // C[x] ^= A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] + for y=1 to 5 + { for x=0 to 5 + { c[x] ^= a[x + y*5]; } } return c; } -inline fn __theta_rol_scalar(reg u64[5] c) -> reg u64[5] +// D[x] = C[x-1] ^ ROT(C[x+1], 1) +inline fn keccakf1600_theta_rol(reg u64[5] c) -> reg u64[5] { - inline int i; + inline int x; reg u64[5] d; - for i = 0 to 5 - { d[i] = c[(i+1)%5]; - _, _, d[i] = #ROL_64(d[i], 1); - d[i] ^= c[(i+4)%5]; + for x = 0 to 5 + { // D[x] = C[x + 1] + d[x] = c[(x + 1) % 5]; + + // D[x] = ROT(D[x], 1) + _, _, d[x] = #ROL_64(d[x], 1); + + // D[x] ^= C[x-1] + d[x] ^= c[(x - 1 + 5) % 5]; } return d; } -inline fn __rol_sum_scalar( - reg u64[5] d, +// B[x] = ROT( (A[x',y'] ^ D[x']), r[x',y'] ) with (x',y') = M^-1 (x,y) +// +// M = (0 1) M^-1 = (1 3) x' = 1x + 3y +// (2 3) (1 0) y' = 1x + 0y +// +inline fn keccakf1600_rol_sum( reg ptr u64[25] a, - inline int offset -) -> reg u64[5] + reg u64[5] d, + inline int y) + -> + reg u64[5] { - inline int j j1 k ti; - reg u64[5] c; + inline int r x x_ y_; + reg u64[5] b; - for j = 0 to 5 + for x = 0 to 5 { - j1 = (j+offset) % 5; - k = __rhotates(j1, j); - ti = __index(j1, j); - c[j] = a[ti]; - c[j] ^= d[j1]; - _, _, c[j] = #ROL_64(c[j], k); + x_ = (x + 3*y) % 5; + y_ = x; + r = keccakf1600_rhotates(x_, y_); + + // B[x] = A[x',y'] + b[x] = a[x_ + y_*5]; + + // B[x] ^= D[x']; + b[x] ^= d[x_]; + + // B[x] = ROT( B[x], r[x',y'] ); + if(r != 0) + { _, _, b[x] = #ROL_64(b[x], r); } + } - return c; + return b; } -inline fn __set_row_scalar( - reg ptr u64[25] r, - inline int row, - reg u64[5] c, - reg u64 iota -) -> reg ptr u64[25] +// E[x, y] = B[x] ^ ( (!B[x+1]) & B[x+2] ) +// -- when x and y are 0: E[0,0] ^= RC[i]; +inline fn keccakf1600_set_row( + reg ptr u64[25] e, + reg u64[5] b, + inline int y, + stack u64 s_rc) + -> + reg ptr u64[25] { - inline int j j1 j2 ti; + inline int x x1 x2; reg u64 t; - for j= 0 to 5 + for x=0 to 5 { - j1 = (j+1) % 5; - j2 = (j+2) % 5; - t = !c[j1] & c[j2]; - if row==0 && j==0 { t ^= iota; } - t ^= c[j]; - ti = __index(j, row); - r[ti] = t; + x1 = (x + 1) % 5; + x2 = (x + 2) % 5; + + // t = !b[x1] & b[x2]; // bmi1 + t = b[x1]; t = !t; t &= b[x2]; + + t ^= b[x]; + if( x==0 && y==0 ){ t ^= s_rc; } + e[x + y*5] = t; } - return r; + return e; } -inline fn __round2x_scalar(reg ptr u64[25] a r, reg u64 iota) -> reg ptr u64[25], reg ptr u64[25] +inline fn keccakf1600_round( + reg ptr u64[25] e, + reg ptr u64[25] a, + reg u64 rc) + -> + reg ptr u64[25] { - reg u64[5] c d; - - c = __theta_sum_scalar(a); - d = __theta_rol_scalar(c); - c = __rol_sum_scalar(d, a, 0); - r = __set_row_scalar(r, 0, c, iota); - c = __rol_sum_scalar(d, a, 3); - r = __set_row_scalar(r, 1, c, iota); - c = __rol_sum_scalar(d, a, 1); - r = __set_row_scalar(r, 2, c, iota); - c = __rol_sum_scalar(d, a, 4); - r = __set_row_scalar(r, 3, c, iota); - c = __rol_sum_scalar(d, a, 2); - r = __set_row_scalar(r, 4, c, iota); - - return a, r; + inline int y; + reg u64[5] b c d; + stack u64 s_rc; + + s_rc = rc; + + c = keccakf1600_theta_sum(a); + d = keccakf1600_theta_rol(c); + + for y = 0 to 5 + { b = keccakf1600_rol_sum(a, d, y); + e = keccakf1600_set_row(e, b, y, s_rc); + } + + return e; } -#[returnaddress="stack"] -fn _keccakf1600_scalar(reg ptr u64[25] a) -> reg ptr u64[25] +inline fn __keccakf1600(reg ptr u64[25] a) -> reg ptr u64[25] { - stack u64[25] r; - reg ptr u64[24] iotas_p; - reg u64 iota; - reg u64 round; - stack u64 round_s; + reg ptr u64[24] RC; + stack u64[25] s_e; + reg ptr u64[25] e; - iotas_p = KECCAK_RC; + reg u64 c rc; - round = 0; + RC = KECCAK1600_RC; + e = s_e; - while(round < 24) + c = 0; + while (c < KECCAK_ROUNDS) { - iota = iotas_p[(int) round]; - round_s = round; - a, r = __round2x_scalar(a, r, iota); - round = round_s; - round += 1; - - iota = iotas_p[(int) round]; - round_s = round; - r, a = __round2x_scalar(r, a, iotas_p[(int) round]); - round = round_s; - round += 1; + rc = RC[(int) c]; + e = keccakf1600_round(e, a, rc); + + rc = RC[(int) c + 1]; + a = keccakf1600_round(a, e, rc); + + c += 2; } return a; } + +fn _keccakf1600(reg ptr u64[25] a) -> reg ptr u64[25] +{ + a = __keccakf1600(a); + return a; +} + +inline fn _keccakf1600_(reg ptr u64[25] a) -> reg ptr u64[25] +{ + a = a; + a = _keccakf1600(a); + a = a; + return a; +}