From 8f325dcd2d1aea1c908d6570e79f06b15195fea0 Mon Sep 17 00:00:00 2001 From: Jon Cinque Date: Mon, 5 Apr 2021 14:48:27 +0200 Subject: [PATCH] math: Improve sqrt using bit-wise operations (#1562) * math: Improve sqrt guess using bit-wise operations * Run fmt and bump up instruction for failed test * Bump up compute cost from CI failure * Update CI version of toolchain * Address feedback --- ci/solana-version.sh | 2 +- libraries/math/src/approximations.rs | 53 +++++++++++++---------- libraries/math/tests/instruction_count.rs | 5 +-- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/ci/solana-version.sh b/ci/solana-version.sh index 4a8346d101c..4a0dc7945dc 100755 --- a/ci/solana-version.sh +++ b/ci/solana-version.sh @@ -14,7 +14,7 @@ if [[ -n $SOLANA_VERSION ]]; then solana_version="$SOLANA_VERSION" else - solana_version=v1.5.15 + solana_version=v1.6.2 fi export solana_version="$solana_version" diff --git a/libraries/math/src/approximations.rs b/libraries/math/src/approximations.rs index a061e754387..570c32b1c84 100644 --- a/libraries/math/src/approximations.rs +++ b/libraries/math/src/approximations.rs @@ -1,35 +1,42 @@ //! Approximation calculations use { - num_traits::{CheckedAdd, CheckedDiv, One, Zero}, - std::cmp::Eq, + num_traits::{CheckedShl, CheckedShr, PrimInt}, + std::cmp::Ordering, }; -const SQRT_ITERATIONS: u8 = 50; - -/// Perform square root -pub fn sqrt(radicand: T) -> Option { - if radicand == T::zero() { - return Some(T::zero()); +/// Calculate square root of the given number +/// +/// Code lovingly adapted from the excellent work at: +/// https://github.com/derekdreery/integer-sqrt-rs +/// +/// The algorithm is based on the implementation in: +/// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2) +pub fn sqrt(radicand: T) -> Option { + match radicand.cmp(&T::zero()) { + Ordering::Less => return None, // fail for less than 0 + Ordering::Equal => return Some(T::zero()), // do nothing for 0 + _ => {} } - // A good initial guess is the average of the interval that contains the - // input number. For all numbers, that will be between 1 and the given number. - let one = T::one(); - let two = one.checked_add(&one)?; - let mut guess = radicand.checked_div(&two)?.checked_add(&one)?; - let mut last_guess = guess; - for _ in 0..SQRT_ITERATIONS { - // x_k+1 = (x_k + radicand / x_k) / 2 - guess = last_guess - .checked_add(&radicand.checked_div(&last_guess)?)? - .checked_div(&two)?; - if last_guess == guess { - break; + + // Compute bit, the largest power of 4 <= n + let max_shift: u32 = T::zero().leading_zeros() - 1; + let shift: u32 = (max_shift - radicand.leading_zeros()) & !1; + let mut bit = T::one().checked_shl(shift)?; + + let mut n = radicand; + let mut result = T::zero(); + while bit != T::zero() { + let result_with_bit = result.checked_add(&bit)?; + if n >= result_with_bit { + n = n.checked_sub(&result_with_bit)?; + result = result.checked_shr(1)?.checked_add(&bit)?; } else { - last_guess = guess; + result = result.checked_shr(1)?; } + bit = bit.checked_shr(2)?; } - Some(guess) + Some(result) } #[cfg(test)] diff --git a/libraries/math/tests/instruction_count.rs b/libraries/math/tests/instruction_count.rs index 0b81e1fcfc2..087f5ac0119 100644 --- a/libraries/math/tests/instruction_count.rs +++ b/libraries/math/tests/instruction_count.rs @@ -62,7 +62,7 @@ async fn test_sqrt_u128() { let mut pc = ProgramTest::new("spl_math", id(), processor!(process_instruction)); // Dial down the BPF compute budget to detect if the operation gets bloated in the future - pc.set_bpf_compute_max_units(5_500); + pc.set_bpf_compute_max_units(4_000); let (mut banks_client, payer, recent_blockhash) = pc.start().await; @@ -78,8 +78,7 @@ async fn test_sqrt_u128() { async fn test_sqrt_u128_max() { let mut pc = ProgramTest::new("spl_math", id(), processor!(process_instruction)); - // This is pretty big too! - pc.set_bpf_compute_max_units(90_000); + pc.set_bpf_compute_max_units(6_000); let (mut banks_client, payer, recent_blockhash) = pc.start().await;