diff --git a/src/lib.rs b/src/lib.rs index 0b68b05..6b23047 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -822,17 +822,18 @@ impl ConstantTimeCmp for T { fn ct_cmp(&self, other: &Self) -> Ordering { // The compiler should be forced to run *both* of these constant-time checks in order to get // a value for `combined_result` through the `.unwrap_u8()` calls. - let is_eq = self.ct_eq(other); - let is_gt = self.ct_gt(other); + let is_eq = self.ct_eq(other).unwrap_u8(); + let is_gt = self.ct_gt(other).unwrap_u8(); // If `ConstantTime{Eq,Greater}` are implemented correctly, *at most one* of "equal" or // "greater" will be true, and if neither are true, then the result must be "less". - let combined_result: u8 = is_eq.unwrap_u8() * 2 + is_gt.unwrap_u8(); - match combined_result { - 2 => Ordering::Equal, - 1 => Ordering::Greater, - 0 => Ordering::Less, - x => unreachable!(".unwrap_u8() should never produce a value above 1: {}", x), - } + const ORDERS: [Ordering; 3] = [Ordering::Less, Ordering::Greater, Ordering::Equal]; + // Since both of these may be 0 or 1, we do a one-hot encoding into two bits, + // which can take the consecutive values 0, 1, or 2. This is used to index into + // a lookup table to avoid branching. + let combined_result: u8 = (is_eq + is_eq) + is_gt; + *ORDERS + .get(combined_result as usize) + .expect(".unwrap_u8() should never produce a value above 1") } }