Skip to content

Commit

Permalink
implement ConstantTimePartialOrd
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmicexplorer committed Jul 1, 2022
1 parent 9e258d9 commit 99b2b29
Showing 1 changed file with 91 additions and 28 deletions.
119 changes: 91 additions & 28 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ use core::option::Option;
pub struct Choice(u8);

impl Choice {
/// Create an instance in `const` context.
pub const fn of_bool(of: bool) -> Self {
Self(of as u8)
}

/// Unwrap the `Choice` wrapper to reveal the underlying `u8`.
///
/// # Note
Expand Down Expand Up @@ -740,7 +745,7 @@ generate_unsigned_integer_greater!(u128, 128);

/// A type which can be compared in some manner and be determined to be less
/// than another of the same type.
pub trait ConstantTimeLess: ConstantTimeEq + ConstantTimeGreater {
pub trait ConstantTimeLess: ConstantTimeGreater {
/// Determine whether `self < other`.
///
/// The bitwise-NOT of the return value of this function should be usable to
Expand Down Expand Up @@ -778,7 +783,7 @@ pub trait ConstantTimeLess: ConstantTimeEq + ConstantTimeGreater {
/// ```
#[inline]
fn ct_lt(&self, other: &Self) -> Choice {
!self.ct_gt(other) & !self.ct_eq(other)
other.ct_gt(self)
}
}

Expand All @@ -789,51 +794,109 @@ impl ConstantTimeLess for u64 {}
#[cfg(feature = "i128")]
impl ConstantTimeLess for u128 {}

/// A `Cmp`-like trait for constant-time comparisons.
/// A [`PartialOrd`][core::cmp::PartialOrd]-like trait for constant-time comparisons.
///
/// This trait is automatically implemented for types supporting the "equals", "less", and
/// "greater" comparisons.
///
/// # Example
///
/// ```
/// use std::cmp::Ordering;
/// use subtle_ng::{ConstantTimePartialOrd, CtOption};
/// let x: u8 = 5;
/// let y: u8 = 13;
///
/// assert_eq!(x.ct_partial_cmp(&x).unwrap(), Ordering::Equal);
/// assert_eq!(x.ct_partial_cmp(&y).unwrap(), Ordering::Less);
/// assert_eq!(y.ct_partial_cmp(&x).unwrap(), Ordering::Greater);
/// ```
pub trait ConstantTimePartialOrd {
/// This method returns an ordering between `self` and `other`, if it exists.
///
/// This method should execute in constant time.
fn ct_partial_cmp(&self, other: &Self) -> CtOption<Ordering>;
}

/// Select among `N + 1` results given `N` logical values, of which at most one should be true.
///
/// This trait is automatically implemented for types implementing both `ConstantTimeEq` and
/// `ConstantTimeGreater`, and makes it easy to implement [`core::cmp::Ord`] for types which can be
/// compared in constant time.
/// This method requires a whole set of logical checks to be performed before evaluating their
/// result, and uses a lookup table to avoid branching in a `match` expression.
fn index_mutually_exclusive_logical_results<T, const N: usize>(
results: &[T],
logicals: [Choice; N],
) -> &T {
assert_eq!(results.len(), N + 1);
let combined_result: u8 = logicals.iter().enumerate().fold(0u8, |x, (i, choice)| {
x + ((i as u8) + 1) * choice.unwrap_u8()
});
results
.get(combined_result as usize)
.expect("multiple inconsistent mutually exclusive logical operations returned true")
}

impl<T: ConstantTimeGreater + ConstantTimeLess + ConstantTimeEq> ConstantTimePartialOrd for T {
/// We do not assume a total ordering for `T`, so we have to individually check "less than",
/// "equal", and "greater". This also respects non-default implementations of `ct_lt()`.
fn ct_partial_cmp(&self, other: &Self) -> CtOption<Ordering> {
let is_eq = self.ct_eq(other);
let is_lt = self.ct_lt(other);
let is_gt = self.ct_gt(other);

const PARTIAL_ORDERS: [CtOption<Ordering>; 4] = [
CtOption {
value: Ordering::Equal,
is_some: Choice::of_bool(false),
},
CtOption {
value: Ordering::Equal,
is_some: Choice::of_bool(true),
},
CtOption {
value: Ordering::Less,
is_some: Choice::of_bool(true),
},
CtOption {
value: Ordering::Greater,
is_some: Choice::of_bool(true),
},
];
*index_mutually_exclusive_logical_results(&PARTIAL_ORDERS, [is_eq, is_lt, is_gt])
}
}

/// An [`Ord`][core::cmp::Ord]-like trait for constant-time comparisons.
///
/// This trait is automatically implemented for types supporting the "equals" and
/// "greater" comparisons.
///
/// # Example
///
/// ```
/// use std::cmp::Ordering;
/// use subtle_ng::ConstantTimeCmp;
/// use subtle_ng::ConstantTimeOrd;
/// let x: u8 = 5;
/// let y: u8 = 13;
///
/// assert_eq!(x.ct_cmp(&x), Ordering::Equal);
/// assert_eq!(x.ct_cmp(&y), Ordering::Less);
/// assert_eq!(y.ct_cmp(&x), Ordering::Greater);
/// ```
pub trait ConstantTimeCmp {
/// Determine if two items are equal, less, or greater than one another.
pub trait ConstantTimeOrd {
/// This method returns an ordering between `self` and other`.
///
/// The `ct_cmp` function should execute in constant time. A default implementation is provided
/// for implementers of both `ConstantTimeEq` and `ConstantTimeGreater`.
/// This method should execute in constant time.
fn ct_cmp(&self, other: &Self) -> Ordering;
}

impl<T: ConstantTimeEq + ConstantTimeGreater> ConstantTimeCmp for T {
/// Default implementation for whether `self <=> other`.
///
/// This implementation should execute in constant time.
impl<T: ConstantTimeEq + ConstantTimeGreater> ConstantTimeOrd for T {
/// We assume a total ordering for `T`, so we need to check only "equal" and "greater", and can
/// assume "less" if both `ct_eq()` and `ct_gt()` are false.
fn ct_cmp(&self, other: &Self) -> Ordering {
// The compiler should be forced to run *both* of these constant-time checks in order to get
// a value for `combined_result` through the `.unwrap_u8()` calls.
let is_eq = self.ct_eq(other).unwrap_u8();
let is_gt = self.ct_gt(other).unwrap_u8();
let is_gt = self.ct_gt(other);
let is_eq = self.ct_eq(other);

// If `ConstantTime{Eq,Greater}` are implemented correctly, *at most one* of "equal" or
// "greater" will be true, and if neither are true, then the result must be "less".
const ORDERS: [Ordering; 3] = [Ordering::Less, Ordering::Greater, Ordering::Equal];
// Since both of these may be 0 or 1, we do a one-hot encoding into two bits,
// which can take the consecutive values 0, 1, or 2. This is used to index into
// a lookup table to avoid branching.
let combined_result: u8 = (is_eq + is_eq) + is_gt;
*ORDERS
.get(combined_result as usize)
.expect("inconsistent implementations of ConstantTimeEq and ConstantTimeGreater returned both equal and greater for a comparison")
*index_mutually_exclusive_logical_results(&ORDERS, [is_gt, is_eq])
}
}

0 comments on commit 99b2b29

Please sign in to comment.