diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/gen_matrix.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/gen_matrix.jinc index 3fbcfa20..59e3b518 100644 --- a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/gen_matrix.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/gen_matrix.jinc @@ -9,54 +9,46 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] reg u16 val1 val2; reg u16 t; reg u64 pos ctr; - reg u64 cnd0 cnd1 exit; ctr = offset; pos = 0; - exit = 0; - while(exit == 0) - { - val1 = (16u)buf[(int)pos]; - pos += 1; - t = (16u)buf[(int)pos]; - val2 = t; - val2 >>= 4; - t &= 0x0F; - t <<= 8; - val1 |= t; - pos += 1; - - t = (16u)buf[(int)pos]; - t <<= 4; - val2 |= t; - pos += 1; - - if(val1 < MLKEM_Q) - { - rp[(int)ctr] = val1; - ctr += 1; - } - - if(val2 < MLKEM_Q) - { - if(ctr < MLKEM_N) - { - rp[(int)ctr] = val2; + while (pos < SHAKE128_RATE - 2) { + if ctr < MLKEM_N { + val1 = (16u)buf[pos]; + t = (16u)buf[pos + 1]; + val2 = t; + val2 >>= 4; + t &= 0x0F; + t <<= 8; + val1 |= t; + + t = (16u)buf[pos + 2]; + t <<= 4; + val2 |= t; + pos += 3; + + reg bool cond; + #[declassify] + cond = val1 < MLKEM_Q; + if cond { + rp[ctr] = val1; ctr += 1; } - } - // Check if we should exit the loop - cnd0 = MLKEM_N; - cnd0 -= ctr; - cnd0 -= 1; - cnd1 = SHAKE128_RATE; - cnd1 -= pos; - cnd1 -= 3; //TODO: (potentially) wasting 2 'good' bytes - exit = cnd0 | cnd1; - exit >>= 63; + #[declassify] + cond = val2 < MLKEM_Q; + if cond { + if(ctr < MLKEM_N) + { + rp[ctr] = val2; + ctr += 1; + } + } + } else { + pos = SHAKE128_RATE; + } } return ctr, rp; diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/ref/gen_matrix.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/ref/gen_matrix.jinc index 5fa706ca..f261b711 100644 --- a/src/crypto_kem/mlkem/mlkem768/amd64/ref/gen_matrix.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/ref/gen_matrix.jinc @@ -7,54 +7,46 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] reg u16 val1 val2; reg u16 t; reg u64 pos ctr; - reg u64 cnd0 cnd1 exit; ctr = offset; pos = 0; - exit = 0; - while(exit == 0) - { - val1 = (16u)buf[(int)pos]; - pos += 1; - t = (16u)buf[(int)pos]; - val2 = t; - val2 >>= 4; - t &= 0x0F; - t <<= 8; - val1 |= t; - pos += 1; - - t = (16u)buf[(int)pos]; - t <<= 4; - val2 |= t; - pos += 1; - - if(val1 < MLKEM_Q) - { - rp[(int)ctr] = val1; - ctr += 1; - } - - if(val2 < MLKEM_Q) - { - if(ctr < MLKEM_N) - { - rp[(int)ctr] = val2; + while (pos < SHAKE128_RATE - 2) { + if ctr < MLKEM_N { + val1 = (16u)buf[pos]; + t = (16u)buf[pos + 1]; + val2 = t; + val2 >>= 4; + t &= 0x0F; + t <<= 8; + val1 |= t; + + t = (16u)buf[pos + 2]; + t <<= 4; + val2 |= t; + pos += 3; + + reg bool cond; + #[declassify] + cond = val1 < MLKEM_Q; + if cond { + rp[ctr] = val1; ctr += 1; } - } - // Check if we should exit the loop - cnd0 = MLKEM_N; - cnd0 -= ctr; - cnd0 -= 1; - cnd1 = SHAKE128_RATE; - cnd1 -= pos; - cnd1 -= 3; //TODO: (potentially) wasting 2 'good' bytes - exit = cnd0 | cnd1; - exit >>= 63; + #[declassify] + cond = val2 < MLKEM_Q; + if cond { + if(ctr < MLKEM_N) + { + rp[ctr] = val2; + ctr += 1; + } + } + } else { + pos = SHAKE128_RATE; + } } return ctr, rp;