diff --git a/benches/boxed_uint.rs b/benches/boxed_uint.rs index dbdf3dc1..28f88fb9 100644 --- a/benches/boxed_uint.rs +++ b/benches/boxed_uint.rs @@ -11,7 +11,7 @@ fn bench_shifts(c: &mut Criterion) { group.bench_function("shl_vartime", |b| { b.iter_batched( || BoxedUint::random(&mut OsRng, UINT_BITS), - |x| black_box(x.shl_vartime(UINT_BITS / 2 + 10)), + |x| black_box(x.overflowing_shl(UINT_BITS / 2 + 10).0), BatchSize::SmallInput, ) }); @@ -27,7 +27,7 @@ fn bench_shifts(c: &mut Criterion) { group.bench_function("shr_vartime", |b| { b.iter_batched( || BoxedUint::random(&mut OsRng, UINT_BITS), - |x| black_box(x.shr_vartime(UINT_BITS / 2 + 10)), + |x| black_box(x.overflowing_shr(UINT_BITS / 2 + 10).0), BatchSize::SmallInput, ) }); diff --git a/src/const_choice.rs b/src/const_choice.rs index b34a7b65..8a55365c 100644 --- a/src/const_choice.rs +++ b/src/const_choice.rs @@ -1,6 +1,6 @@ use subtle::{Choice, CtOption}; -use crate::{modular::BernsteinYangInverter, NonZero, Uint, Word}; +use crate::{modular::BernsteinYangInverter, Limb, NonZero, Uint, Word}; /// A boolean value returned by constant-time `const fn`s. // TODO: should be replaced by `subtle::Choice` or `CtOption` @@ -305,6 +305,20 @@ impl ConstCtOption>> { } } +impl ConstCtOption> { + /// Returns the contained value, consuming the `self` value. + /// + /// # Panics + /// + /// Panics if the value is none with a custom panic message provided by + /// `msg`. + #[inline] + pub const fn expect(self, msg: &str) -> NonZero { + assert!(self.is_some.is_true_vartime(), "{}", msg); + self.value + } +} + impl ConstCtOption> { diff --git a/src/lib.rs b/src/lib.rs index ab924070..22fcdbde 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -181,7 +181,7 @@ pub use crate::{ limb::{Limb, WideWord, Word}, non_zero::NonZero, traits::*, - uint::div_limb::Reciprocal, + uint::reciprocal::Reciprocal, uint::*, wrapping::Wrapping, }; diff --git a/src/modular/boxed_monty_form.rs b/src/modular/boxed_monty_form.rs index c2b3f2a1..c64c6b7e 100644 --- a/src/modular/boxed_monty_form.rs +++ b/src/modular/boxed_monty_form.rs @@ -9,6 +9,7 @@ mod pow; mod sub; use super::{ + div_by_2, reduction::{montgomery_reduction_boxed, montgomery_reduction_boxed_mut}, Retrieve, }; @@ -242,6 +243,18 @@ impl BoxedMontyForm { debug_assert!(self.montgomery_form < self.params.modulus); self.montgomery_form.clone() } + + /// Performs the modular division by 2, that is for given `x` returns `y` + /// such that `y * 2 = x mod p`. This means: + /// - if `x` is even, returns `x / 2`, + /// - if `x` is odd, returns `(x + p) / 2` + /// (since the modulus `p` in Montgomery form is always odd, this divides entirely). + pub fn div_by_2(&self) -> Self { + Self { + montgomery_form: div_by_2::boxed::div_by_2(&self.montgomery_form, &self.params.modulus), + params: self.params.clone(), // TODO: avoid clone? + } + } } impl Retrieve for BoxedMontyForm { @@ -263,7 +276,7 @@ fn convert_to_montgomery(integer: &mut BoxedUint, params: &BoxedMontyParams) { #[cfg(test)] mod tests { - use super::{BoxedMontyParams, BoxedUint}; + use super::{BoxedMontyForm, BoxedMontyParams, BoxedUint}; #[test] fn new_params_with_invalid_modulus() { @@ -280,4 +293,15 @@ mod tests { fn new_params_with_valid_modulus() { BoxedMontyParams::new(BoxedUint::from(3u8)).unwrap(); } + + #[test] + fn div_by_2() { + let params = BoxedMontyParams::new(BoxedUint::from(9u8)).unwrap(); + let zero = BoxedMontyForm::zero(params.clone()); + let one = BoxedMontyForm::one(params.clone()); + let two = one.add(&one); + + assert_eq!(zero.div_by_2(), zero); + assert_eq!(one.div_by_2().mul(&two), one); + } } diff --git a/src/modular/div_by_2.rs b/src/modular/div_by_2.rs index 1ad53b2a..c988ed05 100644 --- a/src/modular/div_by_2.rs +++ b/src/modular/div_by_2.rs @@ -28,3 +28,23 @@ pub(crate) fn div_by_2(a: &Uint, modulus: &Uint::select(&if_even, &if_odd, is_odd) } + +#[cfg(feature = "alloc")] +pub(crate) mod boxed { + use crate::{BoxedUint, ConstantTimeSelect}; + + pub(crate) fn div_by_2(a: &BoxedUint, modulus: &BoxedUint) -> BoxedUint { + debug_assert_eq!(a.bits_precision(), modulus.bits_precision()); + + let (mut half, is_odd) = a.shr1_with_carry(); + let half_modulus = modulus.shr1(); + + let if_odd = half + .wrapping_add(&half_modulus) + .wrapping_add(&BoxedUint::one_with_precision(a.bits_precision())); + + half.ct_assign(&if_odd, is_odd); + + half + } +} diff --git a/src/modular/monty_form.rs b/src/modular/monty_form.rs index e2c7ba58..f24bd065 100644 --- a/src/modular/monty_form.rs +++ b/src/modular/monty_form.rs @@ -266,4 +266,15 @@ mod test { MontyParams::::new(&Uint::from(2u8)).is_none() )) } + + #[test] + fn div_by_2() { + let params = MontyParams::new(&Uint::<1>::from(9u8)).unwrap(); + let zero = MontyForm::zero(params.clone()); + let one = MontyForm::one(params.clone()); + let two = one.add(&one); + + assert_eq!(zero.div_by_2(), zero); + assert_eq!(one.div_by_2().mul(&two), one); + } } diff --git a/src/uint.rs b/src/uint.rs index b721742e..33930426 100644 --- a/src/uint.rs +++ b/src/uint.rs @@ -24,6 +24,7 @@ pub(crate) mod mul; mod mul_mod; mod neg; mod neg_mod; +pub mod reciprocal; mod resize; mod shl; mod shr; diff --git a/src/uint/boxed.rs b/src/uint/boxed.rs index 7939d956..480dcde8 100644 --- a/src/uint/boxed.rs +++ b/src/uint/boxed.rs @@ -10,6 +10,7 @@ mod bits; mod cmp; mod ct; mod div; +mod div_limb; pub(crate) mod encoding; mod from; mod inv_mod; @@ -19,6 +20,7 @@ mod neg; mod neg_mod; mod shl; mod shr; +mod sqrt; mod sub; mod sub_mod; @@ -92,6 +94,17 @@ impl BoxedUint { .fold(Choice::from(1), |acc, limb| acc & limb.is_zero()) } + /// Returns the truthy value if `self`!=0 or the falsy value otherwise. + pub(crate) fn is_nonzero(&self) -> Choice { + let mut b = 0; + let mut i = 0; + while i < self.limbs.len() { + b |= self.limbs[i].0; + i += 1; + } + Limb(b).is_nonzero().into() + } + /// Is this [`BoxedUint`] equal to one? pub fn is_one(&self) -> Choice { let mut iter = self.limbs.iter(); @@ -245,6 +258,13 @@ impl BoxedUint { } impl NonZero { + /// TODO: this is not really "const", but I need a way to return (value, choice) since + /// BoxedUint is not [`ConditionallySelectable`] so `CtChoice::map` and such does not work + pub fn const_new(n: BoxedUint) -> (Self, Choice) { + let nonzero = n.is_nonzero(); + (Self(n), nonzero) + } + /// Widen this type's precision to the given number of bits. /// /// See [`BoxedUint::widen`] for more information, including panic conditions. diff --git a/src/uint/boxed/bits.rs b/src/uint/boxed/bits.rs index e60ebae0..c3756350 100644 --- a/src/uint/boxed/bits.rs +++ b/src/uint/boxed/bits.rs @@ -1,6 +1,6 @@ //! Bit manipulation functions. -use crate::{BoxedUint, Limb, Zero}; +use crate::{BoxedUint, ConstChoice, Limb, Zero}; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; impl BoxedUint { @@ -24,6 +24,23 @@ impl BoxedUint { Limb::BITS * n - leading_zeros } + /// `floor(log2(self.bits_precision()))`. + pub(crate) fn log2_bits(&self) -> u32 { + u32::BITS - self.bits_precision().leading_zeros() - 1 + } + + /// Returns `true` if the bit at position `index` is set, `false` otherwise. + /// + /// # Remarks + /// This operation is variable time with respect to `index` only. + pub fn bit_vartime(&self, index: u32) -> bool { + if index >= self.bits_precision() { + false + } else { + (self.limbs[(index / Limb::BITS) as usize].0 >> (index % Limb::BITS)) & 1 == 1 + } + } + /// Calculate the number of bits needed to represent this number in variable-time with respect /// to `self`. pub fn bits_vartime(&self) -> u32 { @@ -55,6 +72,45 @@ impl BoxedUint { count } + /// Calculate the number of trailing ones in the binary representation of this number. + pub fn trailing_ones(&self) -> u32 { + let limbs = self.as_limbs(); + + let mut count = 0; + let mut i = 0; + let mut nonmax_limb_not_encountered = ConstChoice::TRUE; + while i < limbs.len() { + let l = limbs[i]; + let z = l.trailing_ones(); + count += nonmax_limb_not_encountered.if_true_u32(z); + nonmax_limb_not_encountered = + nonmax_limb_not_encountered.and(ConstChoice::from_word_eq(l.0, Limb::MAX.0)); + i += 1; + } + + count + } + + /// Calculate the number of trailing ones in the binary representation of this number, + /// variable time in `self`. + pub fn trailing_ones_vartime(&self) -> u32 { + let limbs = self.as_limbs(); + + let mut count = 0; + let mut i = 0; + while i < limbs.len() { + let l = limbs[i]; + let z = l.trailing_ones(); + count += z; + if z != Limb::BITS { + break; + } + i += 1; + } + + count + } + /// Sets the bit at `index` to 0 or 1 depending on the value of `bit_value`. pub(crate) fn set_bit(&mut self, index: u32, bit_value: Choice) { let limb_num = (index / Limb::BITS) as usize; @@ -84,7 +140,7 @@ mod tests { fn uint_with_bits_at(positions: &[u32]) -> BoxedUint { let mut result = BoxedUint::zero_with_precision(256); for &pos in positions { - result |= BoxedUint::one_with_precision(256).shl_vartime(pos).unwrap(); + result |= BoxedUint::one_with_precision(256).overflowing_shl(pos).0; } result } @@ -101,6 +157,54 @@ mod tests { assert_eq!(87, n2.bits()); } + #[test] + fn bit_vartime() { + let u = uint_with_bits_at(&[16, 48, 112, 127, 255]); + assert!(!u.bit_vartime(0)); + assert!(!u.bit_vartime(1)); + assert!(u.bit_vartime(16)); + assert!(u.bit_vartime(127)); + assert!(u.bit_vartime(255)); + assert!(!u.bit_vartime(256)); + assert!(!u.bit_vartime(260)); + } + + #[test] + fn trailing_ones() { + let u = !uint_with_bits_at(&[16, 79, 150]); + assert_eq!(u.trailing_ones(), 16); + + let u = !uint_with_bits_at(&[79, 150]); + assert_eq!(u.trailing_ones(), 79); + + let u = !uint_with_bits_at(&[150, 207]); + assert_eq!(u.trailing_ones(), 150); + + let u = !uint_with_bits_at(&[0, 150, 207]); + assert_eq!(u.trailing_ones(), 0); + + let u = !BoxedUint::zero_with_precision(256); + assert_eq!(u.trailing_ones(), 256); + } + + #[test] + fn trailing_ones_vartime() { + let u = !uint_with_bits_at(&[16, 79, 150]); + assert_eq!(u.trailing_ones_vartime(), 16); + + let u = !uint_with_bits_at(&[79, 150]); + assert_eq!(u.trailing_ones_vartime(), 79); + + let u = !uint_with_bits_at(&[150, 207]); + assert_eq!(u.trailing_ones_vartime(), 150); + + let u = !uint_with_bits_at(&[0, 150, 207]); + assert_eq!(u.trailing_ones_vartime(), 0); + + let u = !BoxedUint::zero_with_precision(256); + assert_eq!(u.trailing_ones_vartime(), 256); + } + #[test] fn set_bit() { let mut u = uint_with_bits_at(&[16, 79, 150]); diff --git a/src/uint/boxed/cmp.rs b/src/uint/boxed/cmp.rs index 14bef8d1..c93a0011 100644 --- a/src/uint/boxed/cmp.rs +++ b/src/uint/boxed/cmp.rs @@ -10,6 +10,28 @@ use subtle::{ Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess, }; +impl BoxedUint { + /// Returns the Ordering between `self` and `rhs` in variable time. + pub fn cmp_vartime(&self, rhs: &Self) -> Ordering { + debug_assert_eq!(self.limbs.len(), rhs.limbs.len()); + let mut i = self.limbs.len() - 1; + loop { + let (val, borrow) = self.limbs[i].sbb(rhs.limbs[i], Limb::ZERO); + if val.0 != 0 { + return if borrow.0 != 0 { + Ordering::Less + } else { + Ordering::Greater + }; + } + if i == 0 { + return Ordering::Equal; + } + i -= 1; + } + } +} + impl ConstantTimeEq for BoxedUint { #[inline] fn ct_eq(&self, other: &Self) -> Choice { diff --git a/src/uint/boxed/div.rs b/src/uint/boxed/div.rs index 1cf5ac8a..cfbc8f97 100644 --- a/src/uint/boxed/div.rs +++ b/src/uint/boxed/div.rs @@ -1,10 +1,24 @@ //! [`BoxedUint`] division operations. -use crate::{BoxedUint, CheckedDiv, ConstantTimeSelect, Limb, NonZero, Wrapping}; +use crate::{ + uint::boxed::div_limb, BoxedUint, CheckedDiv, ConstantTimeSelect, Limb, NonZero, Reciprocal, + Wrapping, +}; use core::ops::{Div, DivAssign, Rem, RemAssign}; use subtle::{Choice, ConstantTimeEq, ConstantTimeLess, CtOption}; impl BoxedUint { + /// Computes `self` / `rhs` using a pre-made reciprocal, + /// returns the quotient (q) and remainder (r). + pub fn div_rem_limb_with_reciprocal(&self, reciprocal: &Reciprocal) -> (Self, Limb) { + div_limb::div_rem_limb_with_reciprocal(self, reciprocal) + } + + /// Computes `self` / `rhs`, returns the quotient (q) and remainder (r). + pub fn div_rem_limb(&self, rhs: NonZero) -> (Self, Limb) { + div_limb::div_rem_limb_with_reciprocal(self, &Reciprocal::new(rhs)) + } + /// Computes self / rhs, returns the quotient, remainder. pub fn div_rem(&self, rhs: &NonZero) -> (Self, Self) { // Since `rhs` is nonzero, this should always hold. @@ -38,7 +52,7 @@ impl BoxedUint { let mut bd = self.bits_precision() - mb; let mut rem = self.clone(); // Will not overflow since `bd < bits_precision` - let mut c = rhs.shl_vartime(bd).expect("shift within range"); + let mut c = rhs.overflowing_shl(bd).0; loop { let (r, borrow) = rem.sbb(&c, Limb::ZERO); @@ -61,6 +75,15 @@ impl BoxedUint { self.div_rem(rhs).0 } + /// Wrapped division is just normal division i.e. `self` / `rhs` + /// + /// There’s no way wrapping could ever happen. + /// This function exists, so that all operations are accounted for in the wrapping operations. + pub fn wrapping_div_vartime(&self, rhs: &NonZero) -> Self { + let (q, _) = self.div_rem_vartime(rhs); + q + } + /// Perform checked division, returning a [`CtOption`] which `is_some` /// only if the rhs != 0 pub fn checked_div(&self, rhs: &Self) -> CtOption { @@ -112,7 +135,7 @@ impl BoxedUint { let mut remainder = self.clone(); let mut quotient = Self::zero_with_precision(self.bits_precision()); // Will not overflow since `bd < bits_precision` - let mut c = rhs.shl_vartime(bd).expect("shift within range"); + let mut c = rhs.overflowing_shl(bd).0; loop { let (mut r, borrow) = remainder.sbb(&c, Limb::ZERO); diff --git a/src/uint/boxed/div_limb.rs b/src/uint/boxed/div_limb.rs new file mode 100644 index 00000000..e108c045 --- /dev/null +++ b/src/uint/boxed/div_limb.rs @@ -0,0 +1,42 @@ +//! Implementation of constant-time division via reciprocal precomputation, as described in +//! "Improved Division by Invariant Integers" by Niels Möller and Torbjorn Granlund +//! (DOI: 10.1109/TC.2010.143, ). +use crate::{uint::reciprocal::div2by1, BoxedUint, Limb, Reciprocal}; + +/// Divides `u` by the divisor encoded in the `reciprocal`, and returns +/// the quotient and the remainder. +#[inline(always)] +pub(crate) fn div_rem_limb_with_reciprocal( + u: &BoxedUint, + reciprocal: &Reciprocal, +) -> (BoxedUint, Limb) { + let (u_shifted, u_hi) = u.shl_limb(reciprocal.shift()); + let mut r = u_hi.0; + let mut q = vec![Limb::ZERO; u.limbs.len()]; + + let mut j = u.limbs.len(); + while j > 0 { + j -= 1; + let (qj, rj) = div2by1(r, u_shifted.as_limbs()[j].0, reciprocal); + q[j] = Limb(qj); + r = rj; + } + (BoxedUint { limbs: q.into() }, Limb(r >> reciprocal.shift())) +} + +#[cfg(test)] +mod tests { + use super::{div2by1, Reciprocal}; + use crate::{Limb, NonZero, Word}; + #[test] + fn div2by1_overflow() { + // A regression test for a situation when in div2by1() an operation (`q1 + 1`) + // that is protected from overflowing by a condition in the original paper (`r >= d`) + // still overflows because we're calculating the results for both branches. + let r = Reciprocal::new(NonZero::new(Limb(Word::MAX - 1)).unwrap()); + assert_eq!( + div2by1(Word::MAX - 2, Word::MAX - 63, &r), + (Word::MAX, Word::MAX - 65) + ); + } +} diff --git a/src/uint/boxed/encoding.rs b/src/uint/boxed/encoding.rs index a5bd5478..b93d28cb 100644 --- a/src/uint/boxed/encoding.rs +++ b/src/uint/boxed/encoding.rs @@ -1,7 +1,7 @@ //! Const-friendly decoding operations for [`BoxedUint`]. use super::BoxedUint; -use crate::Limb; +use crate::{uint::encoding, Limb, Word}; use alloc::boxed::Box; use core::fmt; @@ -131,6 +131,40 @@ impl BoxedUint { out.into() } + + /// Create a new [`BoxedUint`] from the provided big endian hex string. + pub fn from_be_hex(hex: &str, bits_precision: u32) -> Self { + let nlimbs = (bits_precision / Limb::BITS) as usize; + let bytes = hex.as_bytes(); + + assert!( + bytes.len() == Limb::BYTES * nlimbs * 2, + "hex string is not the expected size" + ); + + let mut res = vec![Limb::ZERO; nlimbs]; + let mut buf = [0u8; Limb::BYTES]; + let mut i = 0; + let mut err = 0; + + while i < nlimbs { + let mut j = 0; + while j < Limb::BYTES { + let offset = (i * Limb::BYTES + j) * 2; + let (result, byte_err) = + encoding::decode_hex_byte([bytes[offset], bytes[offset + 1]]); + err |= byte_err; + buf[j] = result; + j += 1; + } + res[nlimbs - i - 1] = Limb(Word::from_be_bytes(buf)); + i += 1; + } + + assert!(err == 0, "invalid hex byte"); + + Self { limbs: res.into() } + } } #[cfg(test)] diff --git a/src/uint/boxed/shl.rs b/src/uint/boxed/shl.rs index 6b7ec99b..db641641 100644 --- a/src/uint/boxed/shl.rs +++ b/src/uint/boxed/shl.rs @@ -1,6 +1,6 @@ //! [`BoxedUint`] bitwise left shift operations. -use crate::{BoxedUint, ConstantTimeSelect, Limb, WrappingShl, Zero}; +use crate::{BoxedUint, ConstChoice, ConstantTimeSelect, Limb, Word, WrappingShl, Zero}; use core::ops::{Shl, ShlAssign}; use subtle::{Choice, ConstantTimeLess}; @@ -111,19 +111,6 @@ impl BoxedUint { Some(()) } - /// Computes `self << shift`. - /// Returns `None` if `shift >= self.bits_precision()`. - /// - /// NOTE: this operation is variable time with respect to `shift` *ONLY*. - /// - /// When used with a fixed `shift`, this function is constant-time with respect to `self`. - #[inline(always)] - pub fn shl_vartime(&self, shift: u32) -> Option { - let mut result = Self::zero_with_precision(self.bits_precision()); - let success = self.shl_vartime_into(&mut result, shift); - success.map(|_| result) - } - /// Computes `self << 1` in constant-time. pub(crate) fn shl1(&self) -> Self { let mut ret = self.clone(); @@ -142,6 +129,38 @@ impl BoxedUint { carry = new_carry } } + + /// Computes `self << shift` where `0 <= shift < Limb::BITS`, + /// returning the result and the carry. + pub(crate) fn shl_limb(&self, shift: u32) -> (Self, Limb) { + let mut limbs = vec![Limb::ZERO; self.limbs.len()]; + + let nz = ConstChoice::from_u32_nonzero(shift); + let lshift = shift; + let rshift = nz.if_true_u32(Limb::BITS - shift); + let carry = nz.if_true_word( + self.limbs[self.limbs.len() - 1] + .0 + .wrapping_shr(Word::BITS - shift), + ); + + limbs[0] = Limb(self.limbs[0].0 << lshift); + let mut i = 1; + while i < self.limbs.len() { + let mut limb = self.limbs[i].0 << lshift; + let hi = self.limbs[i - 1].0 >> rshift; + limb |= nz.if_true_word(hi); + limbs[i] = Limb(limb); + i += 1 + } + + ( + BoxedUint { + limbs: limbs.into(), + }, + Limb(carry), + ) + } } macro_rules! impl_shl { @@ -202,7 +221,7 @@ mod tests { assert_eq!(BoxedUint::from(4u8), &one << 2); assert_eq!( BoxedUint::from(0x80000000000000000u128), - one.shl_vartime(67).unwrap() + one.overflowing_shl(67).0 ); } @@ -210,11 +229,11 @@ mod tests { fn shl_vartime() { let one = BoxedUint::one_with_precision(128); - assert_eq!(BoxedUint::from(2u8), one.shl_vartime(1).unwrap()); - assert_eq!(BoxedUint::from(4u8), one.shl_vartime(2).unwrap()); + assert_eq!(BoxedUint::from(2u8), one.overflowing_shl(1).0); + assert_eq!(BoxedUint::from(4u8), one.overflowing_shl(2).0); assert_eq!( BoxedUint::from(0x80000000000000000u128), - one.shl_vartime(67).unwrap() + one.overflowing_shl(67).0 ); } } diff --git a/src/uint/boxed/shr.rs b/src/uint/boxed/shr.rs index 28397125..e94938f2 100644 --- a/src/uint/boxed/shr.rs +++ b/src/uint/boxed/shr.rs @@ -115,19 +115,6 @@ impl BoxedUint { Some(()) } - /// Computes `self >> shift`. - /// Returns `None` if `shift >= self.bits_precision()`. - /// - /// NOTE: this operation is variable time with respect to `shift` *ONLY*. - /// - /// When used with a fixed `shift`, this function is constant-time with respect to `self`. - #[inline(always)] - pub fn shr_vartime(&self, shift: u32) -> Option { - let mut result = Self::zero_with_precision(self.bits_precision()); - let success = self.shr_vartime_into(&mut result, shift); - success.map(|_| result) - } - /// Computes `self >> 1` in constant-time, returning a true [`Choice`] /// if the least significant bit was set, and a false [`Choice::FALSE`] otherwise. pub(crate) fn shr1_with_carry(&self) -> (Self, Choice) { @@ -216,9 +203,9 @@ mod tests { #[test] fn shr_vartime() { let n = BoxedUint::from(0x80000000000000000u128); - assert_eq!(BoxedUint::zero(), n.shr_vartime(68).unwrap()); - assert_eq!(BoxedUint::one(), n.shr_vartime(67).unwrap()); - assert_eq!(BoxedUint::from(2u8), n.shr_vartime(66).unwrap()); - assert_eq!(BoxedUint::from(4u8), n.shr_vartime(65).unwrap()); + assert_eq!(BoxedUint::zero(), n.overflowing_shr(68).0); + assert_eq!(BoxedUint::one(), n.overflowing_shr(67).0); + assert_eq!(BoxedUint::from(2u8), n.overflowing_shr(66).0); + assert_eq!(BoxedUint::from(4u8), n.overflowing_shr(65).0); } } diff --git a/src/uint/boxed/sqrt.rs b/src/uint/boxed/sqrt.rs new file mode 100644 index 00000000..f2eac2a3 --- /dev/null +++ b/src/uint/boxed/sqrt.rs @@ -0,0 +1,294 @@ +//! [`BoxedUint`] square root operations. + +use subtle::{ConstantTimeEq, ConstantTimeGreater, CtOption}; + +use crate::{BoxedUint, ConstantTimeSelect, NonZero}; + +impl BoxedUint { + /// Computes √(`self`) in constant time. + /// + /// Callers can check if `self` is a square by squaring the result + pub fn sqrt(&self) -> Self { + // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13. + // + // See Hast, "Note on computation of integer square roots" + // for the proof of the sufficiency of the bound on iterations. + // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf + + // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. + // Will not overflow since `b <= BITS`. + let (mut x, _overflow) = + Self::one_with_precision(self.bits_precision()).overflowing_shl((self.bits() + 1) >> 1); // ≥ √(`self`) + + // Repeat enough times to guarantee result has stabilized. + let mut i = 0; + // TODO: avoid this clone + let mut x_prev = x.clone(); // keep the previous iteration in case we need to roll back. + + // TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough. + while i < self.log2_bits() + 2 { + x_prev = x.clone(); // TODO: can we avoid this clone? + + // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)` + + let (nz_x, is_some) = NonZero::::const_new(x.clone()); // TODO: avoid this clone + let (q, _) = self.div_rem(&nz_x); + + // A protection in case `self == 0`, which will make `x == 0` + let q = Self::ct_select( + &Self::zero_with_precision(self.bits_precision()), + &q, + is_some, + ); + + x = x.wrapping_add(&q).shr1(); + i += 1; + } + + // At this point `x_prev == x_{n}` and `x == x_{n+1}` + // where `n == i - 1 == LOG2_BITS + 1 == floor(log2(BITS)) + 1`. + // Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`. + Self::ct_select(&x_prev, &x, Self::ct_gt(&x_prev, &x)) + } + + /// Computes √(`self`) + /// + /// Callers can check if `self` is a square by squaring the result + pub fn sqrt_vartime(&self) -> Self { + // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 + + // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. + // Will not overflow since `b <= BITS`. + let (mut x, _overflow) = + Self::one_with_precision(self.bits_precision()).overflowing_shl((self.bits() + 1) >> 1); // ≥ √(`self`) + + // Stop right away if `x` is zero to avoid divizion by zero. + while !x + .cmp_vartime(&Self::zero_with_precision(self.bits_precision())) + .is_eq() + { + // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)` + let q = + self.wrapping_div_vartime(&NonZero::::new(x.clone()).expect("Division by 0")); + let t = x.wrapping_add(&q); + let next_x = t.shr1(); + + // If `next_x` is the same as `x` or greater, we reached convergence + // (`x` is guaranteed to either go down or oscillate between + // `sqrt(self)` and `sqrt(self) + 1`) + if !x.cmp_vartime(&next_x).is_gt() { + break; + } + + x = next_x; + } + + if self.is_nonzero().into() { + x + } else { + Self::zero_with_precision(self.bits_precision()) + } + } + + /// Wrapped sqrt is just normal √(`self`) + /// There’s no way wrapping could ever happen. + /// This function exists so that all operations are accounted for in the wrapping operations. + pub fn wrapping_sqrt(&self) -> Self { + self.sqrt() + } + + /// Wrapped sqrt is just normal √(`self`) + /// There’s no way wrapping could ever happen. + /// This function exists so that all operations are accounted for in the wrapping operations. + pub fn wrapping_sqrt_vartime(&self) -> Self { + self.sqrt_vartime() + } + + /// Perform checked sqrt, returning a [`CtOption`] which `is_some` + /// only if the √(`self`)² == self + pub fn checked_sqrt(&self) -> CtOption { + let r = self.sqrt(); + let s = r.wrapping_mul(&r); + CtOption::new(r, ConstantTimeEq::ct_eq(self, &s)) + } + + /// Perform checked sqrt, returning a [`CtOption`] which `is_some` + /// only if the √(`self`)² == self + pub fn checked_sqrt_vartime(&self) -> CtOption { + let r = self.sqrt_vartime(); + let s = r.wrapping_mul(&r); + CtOption::new(r, ConstantTimeEq::ct_eq(self, &s)) + } +} + +#[cfg(test)] +mod tests { + use crate::{BoxedUint, Limb}; + + #[cfg(feature = "rand")] + use { + crate::CheckedMul, + rand_chacha::ChaChaRng, + rand_core::{RngCore, SeedableRng}, + }; + + #[test] + fn edge() { + assert_eq!( + BoxedUint::zero_with_precision(256).sqrt(), + BoxedUint::zero_with_precision(256) + ); + assert_eq!( + BoxedUint::one_with_precision(256).sqrt(), + BoxedUint::one_with_precision(256) + ); + let mut half = BoxedUint::zero_with_precision(256); + for i in 0..half.limbs.len() / 2 { + half.limbs[i] = Limb::MAX; + } + let u256_max = !BoxedUint::zero_with_precision(256); + assert_eq!(u256_max.sqrt(), half); + + // Test edge cases that use up the maximum number of iterations. + + // `x = (r + 1)^2 - 583`, where `r` is the expected square root. + assert_eq!( + BoxedUint::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d", 192).sqrt(), + BoxedUint::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21", 192) + ); + assert_eq!( + BoxedUint::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d", 192) + .sqrt_vartime(), + BoxedUint::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21", 192) + ); + + // `x = (r + 1)^2 - 205`, where `r` is the expected square root. + assert_eq!( + BoxedUint::from_be_hex( + "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597", + 256 + ) + .sqrt(), + BoxedUint::from_be_hex( + "000000000000000000000000000000008b3956339e8315cff66eb6107b610075", + 256 + ) + ); + assert_eq!( + BoxedUint::from_be_hex( + "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597", + 256 + ) + .sqrt_vartime(), + BoxedUint::from_be_hex( + "000000000000000000000000000000008b3956339e8315cff66eb6107b610075", + 256 + ) + ); + } + + #[test] + fn edge_vartime() { + assert_eq!( + BoxedUint::zero_with_precision(256).sqrt_vartime(), + BoxedUint::zero_with_precision(256) + ); + assert_eq!( + BoxedUint::one_with_precision(256).sqrt_vartime(), + BoxedUint::one_with_precision(256) + ); + let mut half = BoxedUint::zero_with_precision(256); + for i in 0..half.limbs.len() / 2 { + half.limbs[i] = Limb::MAX; + } + let u256_max = !BoxedUint::zero_with_precision(256); + assert_eq!(u256_max.sqrt_vartime(), half); + } + + #[test] + fn simple() { + let tests = [ + (4u8, 2u8), + (9, 3), + (16, 4), + (25, 5), + (36, 6), + (49, 7), + (64, 8), + (81, 9), + (100, 10), + (121, 11), + (144, 12), + (169, 13), + ]; + for (a, e) in &tests { + let l = BoxedUint::from(*a); + let r = BoxedUint::from(*e); + assert_eq!(l.sqrt(), r); + assert_eq!(l.sqrt_vartime(), r); + assert_eq!(l.checked_sqrt().is_some().unwrap_u8(), 1u8); + assert_eq!(l.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8); + } + } + + #[test] + fn nonsquares() { + assert_eq!(BoxedUint::from(2u8).sqrt(), BoxedUint::from(1u8)); + assert_eq!(BoxedUint::from(2u8).checked_sqrt().is_some().unwrap_u8(), 0); + assert_eq!(BoxedUint::from(3u8).sqrt(), BoxedUint::from(1u8)); + assert_eq!(BoxedUint::from(3u8).checked_sqrt().is_some().unwrap_u8(), 0); + assert_eq!(BoxedUint::from(5u8).sqrt(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(6u8).sqrt(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(7u8).sqrt(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(8u8).sqrt(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(10u8).sqrt(), BoxedUint::from(3u8)); + } + + #[test] + fn nonsquares_vartime() { + assert_eq!(BoxedUint::from(2u8).sqrt_vartime(), BoxedUint::from(1u8)); + assert_eq!( + BoxedUint::from(2u8) + .checked_sqrt_vartime() + .is_some() + .unwrap_u8(), + 0 + ); + assert_eq!(BoxedUint::from(3u8).sqrt_vartime(), BoxedUint::from(1u8)); + assert_eq!( + BoxedUint::from(3u8) + .checked_sqrt_vartime() + .is_some() + .unwrap_u8(), + 0 + ); + assert_eq!(BoxedUint::from(5u8).sqrt_vartime(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(6u8).sqrt_vartime(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(7u8).sqrt_vartime(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(8u8).sqrt_vartime(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(10u8).sqrt_vartime(), BoxedUint::from(3u8)); + } + + #[cfg(feature = "rand")] + #[test] + fn fuzz() { + let mut rng = ChaChaRng::from_seed([7u8; 32]); + for _ in 0..50 { + let t = rng.next_u32() as u64; + let s = BoxedUint::from(t); + let s2 = s.checked_mul(&s).unwrap(); + assert_eq!(s2.sqrt(), s); + assert_eq!(s2.sqrt_vartime(), s); + assert_eq!(s2.checked_sqrt().is_some().unwrap_u8(), 1); + assert_eq!(s2.checked_sqrt_vartime().is_some().unwrap_u8(), 1); + } + + for _ in 0..50 { + let s = BoxedUint::random(&mut rng, 512); + let mut s2 = BoxedUint::zero_with_precision(512); + s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs); + assert_eq!(s.square().sqrt(), s2); + assert_eq!(s.square().sqrt_vartime(), s2); + } + } +} diff --git a/src/uint/div.rs b/src/uint/div.rs index d650235d..0f9c5044 100644 --- a/src/uint/div.rs +++ b/src/uint/div.rs @@ -1,7 +1,9 @@ //! [`Uint`] division operations. -use super::div_limb::{div_rem_limb_with_reciprocal, Reciprocal}; -use crate::{CheckedDiv, ConstChoice, Limb, NonZero, Uint, Word, Wrapping}; +use super::div_limb::div_rem_limb_with_reciprocal; +use crate::{ + uint::reciprocal::Reciprocal, CheckedDiv, ConstChoice, Limb, NonZero, Uint, Word, Wrapping, +}; use core::ops::{Div, DivAssign, Rem, RemAssign}; use subtle::CtOption; diff --git a/src/uint/div_limb.rs b/src/uint/div_limb.rs index f8d47f8b..7f01343b 100644 --- a/src/uint/div_limb.rs +++ b/src/uint/div_limb.rs @@ -1,213 +1,11 @@ //! Implementation of constant-time division via reciprocal precomputation, as described in //! "Improved Division by Invariant Integers" by Niels Möller and Torbjorn Granlund //! (DOI: 10.1109/TC.2010.143, ). -use subtle::{Choice, ConditionallySelectable}; - use crate::{ - primitives::{addhilo, mulhilo}, - ConstChoice, Limb, NonZero, Uint, Word, + uint::reciprocal::{div2by1, Reciprocal}, + Limb, Uint, }; -/// Calculates the reciprocal of the given 32-bit divisor with the highmost bit set. -#[cfg(target_pointer_width = "32")] -pub const fn reciprocal(d: Word) -> Word { - debug_assert!(d >= (1 << (Word::BITS - 1))); - - let d0 = d & 1; - let d10 = d >> 22; - let d21 = (d >> 11) + 1; - let d31 = (d >> 1) + d0; - let v0 = short_div((1 << 24) - (1 << 14) + (1 << 9), 24, d10, 10); - let (hi, _lo) = mulhilo(v0 * v0, d21); - let v1 = (v0 << 4) - hi - 1; - - // Checks that the expression for `e` can be simplified in the way we did below. - debug_assert!(mulhilo(v1, d31).0 == (1 << 16) - 1); - let e = Word::MAX - v1.wrapping_mul(d31) + 1 + (v1 >> 1) * d0; - - let (hi, _lo) = mulhilo(v1, e); - // Note: the paper does not mention a wrapping add here, - // but the 64-bit version has it at this stage, and the function panics without it - // when calculating a reciprocal for `Word::MAX`. - let v2 = (v1 << 15).wrapping_add(hi >> 1); - - // The paper has `(v2 + 1) * d / 2^32` (there's another 2^32, but it's accounted for later). - // If `v2 == 2^32-1` this should give `d`, but we can't achieve this in our wrapping arithmetic. - // Hence the `ct_select()`. - let x = v2.wrapping_add(1); - let (hi, _lo) = mulhilo(x, d); - let hi = ConstChoice::from_u32_nonzero(x).select_word(d, hi); - - v2.wrapping_sub(hi).wrapping_sub(d) -} - -/// Calculates the reciprocal of the given 64-bit divisor with the highmost bit set. -#[cfg(target_pointer_width = "64")] -pub const fn reciprocal(d: Word) -> Word { - debug_assert!(d >= (1 << (Word::BITS - 1))); - - let d0 = d & 1; - let d9 = d >> 55; - let d40 = (d >> 24) + 1; - let d63 = (d >> 1) + d0; - let v0 = short_div((1 << 19) - 3 * (1 << 8), 19, d9 as u32, 9) as u64; - let v1 = (v0 << 11) - ((v0 * v0 * d40) >> 40) - 1; - let v2 = (v1 << 13) + ((v1 * ((1 << 60) - v1 * d40)) >> 47); - - // Checks that the expression for `e` can be simplified in the way we did below. - debug_assert!(mulhilo(v2, d63).0 == (1 << 32) - 1); - let e = Word::MAX - v2.wrapping_mul(d63) + 1 + (v2 >> 1) * d0; - - let (hi, _lo) = mulhilo(v2, e); - let v3 = (v2 << 31).wrapping_add(hi >> 1); - - // The paper has `(v3 + 1) * d / 2^64` (there's another 2^64, but it's accounted for later). - // If `v3 == 2^64-1` this should give `d`, but we can't achieve this in our wrapping arithmetic. - // Hence the `ct_select()`. - let x = v3.wrapping_add(1); - let (hi, _lo) = mulhilo(x, d); - let hi = ConstChoice::from_word_nonzero(x).select_word(d, hi); - - v3.wrapping_sub(hi).wrapping_sub(d) -} - -/// Returns `u32::MAX` if `a < b` and `0` otherwise. -#[inline] -const fn lt(a: u32, b: u32) -> u32 { - let bit = (((!a) & b) | (((!a) | b) & (a.wrapping_sub(b)))) >> (u32::BITS - 1); - bit.wrapping_neg() -} - -/// Returns `a` if `c == 0` and `b` if `c == u32::MAX`. -#[inline(always)] -const fn select(a: u32, b: u32, c: u32) -> u32 { - a ^ (c & (a ^ b)) -} - -/// Calculates `dividend / divisor`, given `dividend` and `divisor` -/// along with their maximum bitsizes. -#[inline(always)] -const fn short_div(dividend: u32, dividend_bits: u32, divisor: u32, divisor_bits: u32) -> u32 { - // TODO: this may be sped up even more using the fact that `dividend` is a known constant. - - // In the paper this is a table lookup, but since we want it to be constant-time, - // we have to access all the elements of the table, which is quite large. - // So this shift-and-subtract approach is actually faster. - - // Passing `dividend_bits` and `divisor_bits` because calling `.leading_zeros()` - // causes a significant slowdown, and we know those values anyway. - - let mut dividend = dividend; - let mut divisor = divisor << (dividend_bits - divisor_bits); - let mut quotient: u32 = 0; - let mut i = dividend_bits - divisor_bits + 1; - - while i > 0 { - i -= 1; - let bit = lt(dividend, divisor); - dividend = select(dividend.wrapping_sub(divisor), dividend, bit); - divisor >>= 1; - let inv_bit = !bit; - quotient |= (inv_bit >> (u32::BITS - 1)) << i; - } - - quotient -} - -/// Calculate the quotient and the remainder of the division of a wide word -/// (supplied as high and low words) by `d`, with a precalculated reciprocal `v`. -#[inline(always)] -const fn div2by1(u1: Word, u0: Word, reciprocal: &Reciprocal) -> (Word, Word) { - let d = reciprocal.divisor_normalized; - - debug_assert!(d >= (1 << (Word::BITS - 1))); - debug_assert!(u1 < d); - - let (q1, q0) = mulhilo(reciprocal.reciprocal, u1); - let (q1, q0) = addhilo(q1, q0, u1, u0); - let q1 = q1.wrapping_add(1); - let r = u0.wrapping_sub(q1.wrapping_mul(d)); - - let r_gt_q0 = ConstChoice::from_word_lt(q0, r); - let q1 = r_gt_q0.select_word(q1, q1.wrapping_sub(1)); - let r = r_gt_q0.select_word(r, r.wrapping_add(d)); - - // If this was a normal `if`, we wouldn't need wrapping ops, because there would be no overflow. - // But since we calculate both results either way, we have to wrap. - // Added an assert to still check the lack of overflow in debug mode. - debug_assert!(r < d || q1 < Word::MAX); - let r_ge_d = ConstChoice::from_word_le(d, r); - let q1 = r_ge_d.select_word(q1, q1.wrapping_add(1)); - let r = r_ge_d.select_word(r, r.wrapping_sub(d)); - - (q1, r) -} - -/// A pre-calculated reciprocal for division by a single limb. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct Reciprocal { - divisor_normalized: Word, - shift: u32, - reciprocal: Word, -} - -impl Reciprocal { - /// Pre-calculates a reciprocal for a known divisor, - /// to be used in the single-limb division later. - pub const fn new(divisor: NonZero) -> Self { - let divisor = divisor.0; - - // Assuming this is constant-time for primitive types. - let shift = divisor.0.leading_zeros(); - - // Will not panic since divisor is non-zero - let divisor_normalized = divisor.0 << shift; - - Self { - divisor_normalized, - shift, - reciprocal: reciprocal(divisor_normalized), - } - } - - /// Returns a default instance of this object. - /// It is a self-consistent `Reciprocal` that will not cause panics in functions that take it. - /// - /// NOTE: intended for using it as a placeholder during compile-time array generation, - /// don't rely on the contents. - pub const fn default() -> Self { - Self { - divisor_normalized: Word::MAX, - shift: 0, - // The result of calling `reciprocal(Word::MAX)` - // This holds both for 32- and 64-bit versions. - reciprocal: 1, - } - } -} - -impl ConditionallySelectable for Reciprocal { - fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { - Self { - divisor_normalized: Word::conditional_select( - &a.divisor_normalized, - &b.divisor_normalized, - choice, - ), - shift: u32::conditional_select(&a.shift, &b.shift, choice), - reciprocal: Word::conditional_select(&a.reciprocal, &b.reciprocal, choice), - } - } -} - -// `CtOption.map()` needs this; for some reason it doesn't use the value it already has -// for the `None` branch. -impl Default for Reciprocal { - fn default() -> Self { - Self::default() - } -} - /// Divides `u` by the divisor encoded in the `reciprocal`, and returns /// the quotient and the remainder. #[inline(always)] @@ -215,7 +13,7 @@ pub(crate) const fn div_rem_limb_with_reciprocal( u: &Uint, reciprocal: &Reciprocal, ) -> (Uint, Limb) { - let (u_shifted, u_hi) = u.shl_limb(reciprocal.shift); + let (u_shifted, u_hi) = u.shl_limb(reciprocal.shift()); let mut r = u_hi.0; let mut q = [Limb::ZERO; L]; @@ -226,7 +24,7 @@ pub(crate) const fn div_rem_limb_with_reciprocal( q[j] = Limb(qj); r = rj; } - (Uint::::new(q), Limb(r >> reciprocal.shift)) + (Uint::::new(q), Limb(r >> reciprocal.shift())) } #[cfg(test)] diff --git a/src/uint/encoding.rs b/src/uint/encoding.rs index b89e980b..cdd73e8c 100644 --- a/src/uint/encoding.rs +++ b/src/uint/encoding.rs @@ -236,7 +236,7 @@ const fn decode_nibble(src: u8) -> u16 { /// Second element of the tuple is non-zero if the `bytes` values are not in the valid range /// (0-9, a-z, A-Z). #[inline(always)] -const fn decode_hex_byte(bytes: [u8; 2]) -> (u8, u16) { +pub(crate) const fn decode_hex_byte(bytes: [u8; 2]) -> (u8, u16) { let hi = decode_nibble(bytes[0]); let lo = decode_nibble(bytes[1]); let byte = (hi << 4) | lo; diff --git a/src/uint/reciprocal.rs b/src/uint/reciprocal.rs new file mode 100644 index 00000000..1f001cbc --- /dev/null +++ b/src/uint/reciprocal.rs @@ -0,0 +1,208 @@ +//! Reciprocal, shared across Uint and BoxedUint +use crate::{primitives, ConstChoice, Limb, NonZero, Word}; +use subtle::{Choice, ConditionallySelectable}; + +/// Calculates the reciprocal of the given 32-bit divisor with the highmost bit set. +#[cfg(target_pointer_width = "32")] +pub const fn reciprocal(d: Word) -> Word { + debug_assert!(d >= (1 << (Word::BITS - 1))); + + let d0 = d & 1; + let d10 = d >> 22; + let d21 = (d >> 11) + 1; + let d31 = (d >> 1) + d0; + let v0 = short_div((1 << 24) - (1 << 14) + (1 << 9), 24, d10, 10); + let (hi, _lo) = primitives::mulhilo(v0 * v0, d21); + let v1 = (v0 << 4) - hi - 1; + + // Checks that the expression for `e` can be simplified in the way we did below. + debug_assert!(primitives::mulhilo(v1, d31).0 == (1 << 16) - 1); + let e = Word::MAX - v1.wrapping_mul(d31) + 1 + (v1 >> 1) * d0; + + let (hi, _lo) = primitives::mulhilo(v1, e); + // Note: the paper does not mention a wrapping add here, + // but the 64-bit version has it at this stage, and the function panics without it + // when calculating a reciprocal for `Word::MAX`. + let v2 = (v1 << 15).wrapping_add(hi >> 1); + + // The paper has `(v2 + 1) * d / 2^32` (there's another 2^32, but it's accounted for later). + // If `v2 == 2^32-1` this should give `d`, but we can't achieve this in our wrapping arithmetic. + // Hence the `ct_select()`. + let x = v2.wrapping_add(1); + let (hi, _lo) = primitives::mulhilo(x, d); + let hi = ConstChoice::from_u32_nonzero(x).select_word(d, hi); + + v2.wrapping_sub(hi).wrapping_sub(d) +} + +/// Calculates the reciprocal of the given 64-bit divisor with the highmost bit set. +#[cfg(target_pointer_width = "64")] +pub const fn reciprocal(d: Word) -> Word { + debug_assert!(d >= (1 << (Word::BITS - 1))); + + let d0 = d & 1; + let d9 = d >> 55; + let d40 = (d >> 24) + 1; + let d63 = (d >> 1) + d0; + let v0 = short_div((1 << 19) - 3 * (1 << 8), 19, d9 as u32, 9) as u64; + let v1 = (v0 << 11) - ((v0 * v0 * d40) >> 40) - 1; + let v2 = (v1 << 13) + ((v1 * ((1 << 60) - v1 * d40)) >> 47); + + // Checks that the expression for `e` can be simplified in the way we did below. + debug_assert!(primitives::mulhilo(v2, d63).0 == (1 << 32) - 1); + let e = Word::MAX - v2.wrapping_mul(d63) + 1 + (v2 >> 1) * d0; + + let (hi, _lo) = primitives::mulhilo(v2, e); + let v3 = (v2 << 31).wrapping_add(hi >> 1); + + // The paper has `(v3 + 1) * d / 2^64` (there's another 2^64, but it's accounted for later). + // If `v3 == 2^64-1` this should give `d`, but we can't achieve this in our wrapping arithmetic. + // Hence the `ct_select()`. + let x = v3.wrapping_add(1); + let (hi, _lo) = primitives::mulhilo(x, d); + let hi = ConstChoice::from_word_nonzero(x).select_word(d, hi); + + v3.wrapping_sub(hi).wrapping_sub(d) +} + +/// Returns `u32::MAX` if `a < b` and `0` otherwise. +#[inline] +const fn lt(a: u32, b: u32) -> u32 { + let bit = (((!a) & b) | (((!a) | b) & (a.wrapping_sub(b)))) >> (u32::BITS - 1); + bit.wrapping_neg() +} + +/// Returns `a` if `c == 0` and `b` if `c == u32::MAX`. +#[inline(always)] +const fn select(a: u32, b: u32, c: u32) -> u32 { + a ^ (c & (a ^ b)) +} + +/// Calculates `dividend / divisor`, given `dividend` and `divisor` +/// along with their maximum bitsizes. +#[inline(always)] +const fn short_div(dividend: u32, dividend_bits: u32, divisor: u32, divisor_bits: u32) -> u32 { + // TODO: this may be sped up even more using the fact that `dividend` is a known constant. + + // In the paper this is a table lookup, but since we want it to be constant-time, + // we have to access all the elements of the table, which is quite large. + // So this shift-and-subtract approach is actually faster. + + // Passing `dividend_bits` and `divisor_bits` because calling `.leading_zeros()` + // causes a significant slowdown, and we know those values anyway. + + let mut dividend = dividend; + let mut divisor = divisor << (dividend_bits - divisor_bits); + let mut quotient: u32 = 0; + let mut i = dividend_bits - divisor_bits + 1; + + while i > 0 { + i -= 1; + let bit = lt(dividend, divisor); + dividend = select(dividend.wrapping_sub(divisor), dividend, bit); + divisor >>= 1; + let inv_bit = !bit; + quotient |= (inv_bit >> (u32::BITS - 1)) << i; + } + + quotient +} + +/// Calculate the quotient and the remainder of the division of a wide word +/// (supplied as high and low words) by `d`, with a precalculated reciprocal `v`. +#[inline(always)] +pub(crate) const fn div2by1(u1: Word, u0: Word, reciprocal: &Reciprocal) -> (Word, Word) { + let d = reciprocal.divisor_normalized; + + debug_assert!(d >= (1 << (Word::BITS - 1))); + debug_assert!(u1 < d); + + let (q1, q0) = primitives::mulhilo(reciprocal.reciprocal, u1); + let (q1, q0) = primitives::addhilo(q1, q0, u1, u0); + let q1 = q1.wrapping_add(1); + let r = u0.wrapping_sub(q1.wrapping_mul(d)); + + let r_gt_q0 = ConstChoice::from_word_lt(q0, r); + let q1 = r_gt_q0.select_word(q1, q1.wrapping_sub(1)); + let r = r_gt_q0.select_word(r, r.wrapping_add(d)); + + // If this was a normal `if`, we wouldn't need wrapping ops, because there would be no overflow. + // But since we calculate both results either way, we have to wrap. + // Added an assert to still check the lack of overflow in debug mode. + debug_assert!(r < d || q1 < Word::MAX); + let r_ge_d = ConstChoice::from_word_le(d, r); + let q1 = r_ge_d.select_word(q1, q1.wrapping_add(1)); + let r = r_ge_d.select_word(r, r.wrapping_sub(d)); + + (q1, r) +} + +/// A pre-calculated reciprocal for division by a single limb. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct Reciprocal { + divisor_normalized: Word, + shift: u32, + reciprocal: Word, +} + +impl Reciprocal { + /// return the shift + pub const fn shift(&self) -> u32 { + self.shift + } + + /// Pre-calculates a reciprocal for a known divisor, + /// to be used in the single-limb division later. + pub const fn new(divisor: NonZero) -> Self { + let divisor = divisor.0; + + // Assuming this is constant-time for primitive types. + let shift = divisor.0.leading_zeros(); + + // Will not panic since divisor is non-zero + let divisor_normalized = divisor.0 << shift; + + Self { + divisor_normalized, + shift, + reciprocal: reciprocal(divisor_normalized), + } + } + + /// Returns a default instance of this object. + /// It is a self-consistent `Reciprocal` that will not cause panics in functions that take it. + /// + /// NOTE: intended for using it as a placeholder during compile-time array generation, + /// don't rely on the contents. + pub const fn default() -> Self { + Self { + divisor_normalized: Word::MAX, + shift: 0, + // The result of calling `reciprocal(Word::MAX)` + // This holds both for 32- and 64-bit versions. + reciprocal: 1, + } + } +} + +impl ConditionallySelectable for Reciprocal { + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + Self { + divisor_normalized: Word::conditional_select( + &a.divisor_normalized, + &b.divisor_normalized, + choice, + ), + shift: u32::conditional_select(&a.shift, &b.shift, choice), + reciprocal: Word::conditional_select(&a.reciprocal, &b.reciprocal, choice), + } + } +} + +// `CtOption.map()` needs this; for some reason it doesn't use the value it already has +// for the `None` branch. +impl Default for Reciprocal { + fn default() -> Self { + Self::default() + } +} diff --git a/tests/boxed_uint_proptests.rs b/tests/boxed_uint_proptests.rs index 52b45c88..f793878c 100644 --- a/tests/boxed_uint_proptests.rs +++ b/tests/boxed_uint_proptests.rs @@ -7,6 +7,7 @@ use crypto_bigint::{BoxedUint, CheckedAdd, Integer, Limb, NonZero}; use num_bigint::{BigUint, ModInverse}; use num_traits::identities::One; use proptest::prelude::*; +use subtle::Choice; fn to_biguint(uint: &BoxedUint) -> BigUint { BigUint::from_bytes_be(&uint.to_be_bytes()) @@ -239,13 +240,13 @@ proptest! { let shift = u32::from(shift) % (a.bits_precision() * 2); let expected = to_uint((a_bi << shift as usize) & ((BigUint::one() << a.bits_precision() as usize) - BigUint::one())); - let actual = a.shl_vartime(shift); + let (actual, overflow) = a.overflowing_shl(shift); if shift >= a.bits_precision() { - assert!(actual.is_none()); + assert!(>::into(overflow)); } else { - assert_eq!(expected, actual.unwrap()); + assert_eq!(expected, actual); } } @@ -275,13 +276,12 @@ proptest! { let shift = u32::from(shift) % (a.bits_precision() * 2); let expected = to_uint(a_bi >> shift as usize); - let actual = a.shr_vartime(shift); + let (actual, overflow) = a.overflowing_shr(shift); if shift >= a.bits_precision() { - assert!(actual.is_none()); - } - else { - assert_eq!(expected, actual.unwrap()); + assert!(>::into(overflow)); + } else { + assert_eq!(expected, actual); } } }