Skip to content

Commit

Permalink
math: Improve sqrt using bit-wise operations (#1562)
Browse files Browse the repository at this point in the history
* math: Improve sqrt guess using bit-wise operations

* Run fmt and bump up instruction for failed test

* Bump up compute cost from CI failure

* Update CI version of toolchain

* Address feedback
  • Loading branch information
joncinque authored Apr 5, 2021
1 parent 0e2b080 commit 8f325dc
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 27 deletions.
2 changes: 1 addition & 1 deletion ci/solana-version.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
if [[ -n $SOLANA_VERSION ]]; then
solana_version="$SOLANA_VERSION"
else
solana_version=v1.5.15
solana_version=v1.6.2
fi

export solana_version="$solana_version"
Expand Down
53 changes: 30 additions & 23 deletions libraries/math/src/approximations.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,42 @@
//! Approximation calculations
use {
num_traits::{CheckedAdd, CheckedDiv, One, Zero},
std::cmp::Eq,
num_traits::{CheckedShl, CheckedShr, PrimInt},
std::cmp::Ordering,
};

const SQRT_ITERATIONS: u8 = 50;

/// Perform square root
pub fn sqrt<T: CheckedAdd + CheckedDiv + One + Zero + Eq + Copy>(radicand: T) -> Option<T> {
if radicand == T::zero() {
return Some(T::zero());
/// Calculate square root of the given number
///
/// Code lovingly adapted from the excellent work at:
/// https://github.com/derekdreery/integer-sqrt-rs
///
/// The algorithm is based on the implementation in:
/// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)
pub fn sqrt<T: PrimInt + CheckedShl + CheckedShr>(radicand: T) -> Option<T> {
match radicand.cmp(&T::zero()) {
Ordering::Less => return None, // fail for less than 0
Ordering::Equal => return Some(T::zero()), // do nothing for 0
_ => {}
}
// A good initial guess is the average of the interval that contains the
// input number. For all numbers, that will be between 1 and the given number.
let one = T::one();
let two = one.checked_add(&one)?;
let mut guess = radicand.checked_div(&two)?.checked_add(&one)?;
let mut last_guess = guess;
for _ in 0..SQRT_ITERATIONS {
// x_k+1 = (x_k + radicand / x_k) / 2
guess = last_guess
.checked_add(&radicand.checked_div(&last_guess)?)?
.checked_div(&two)?;
if last_guess == guess {
break;

// Compute bit, the largest power of 4 <= n
let max_shift: u32 = T::zero().leading_zeros() - 1;
let shift: u32 = (max_shift - radicand.leading_zeros()) & !1;
let mut bit = T::one().checked_shl(shift)?;

let mut n = radicand;
let mut result = T::zero();
while bit != T::zero() {
let result_with_bit = result.checked_add(&bit)?;
if n >= result_with_bit {
n = n.checked_sub(&result_with_bit)?;
result = result.checked_shr(1)?.checked_add(&bit)?;
} else {
last_guess = guess;
result = result.checked_shr(1)?;
}
bit = bit.checked_shr(2)?;
}
Some(guess)
Some(result)
}

#[cfg(test)]
Expand Down
5 changes: 2 additions & 3 deletions libraries/math/tests/instruction_count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async fn test_sqrt_u128() {
let mut pc = ProgramTest::new("spl_math", id(), processor!(process_instruction));

// Dial down the BPF compute budget to detect if the operation gets bloated in the future
pc.set_bpf_compute_max_units(5_500);
pc.set_bpf_compute_max_units(4_000);

let (mut banks_client, payer, recent_blockhash) = pc.start().await;

Expand All @@ -78,8 +78,7 @@ async fn test_sqrt_u128() {
async fn test_sqrt_u128_max() {
let mut pc = ProgramTest::new("spl_math", id(), processor!(process_instruction));

// This is pretty big too!
pc.set_bpf_compute_max_units(90_000);
pc.set_bpf_compute_max_units(6_000);

let (mut banks_client, payer, recent_blockhash) = pc.start().await;

Expand Down

0 comments on commit 8f325dc

Please sign in to comment.