Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: bitshift takes shift: usize #989

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion crates/evm/src/instructions/comparison_operations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ pub impl ComparisonAndBitwiseOperations of ComparisonAndBitwiseOperationsTrait {
if i > 31 {
return self.stack.push(0);
}
let i: usize = i.try_into().unwrap(); // Safe because i <= 31

// Right shift value by offset bits and then take the least significant byte.
let result = x.shr((31 - i) * 8) & 0xFF;
Expand All @@ -150,7 +151,7 @@ pub impl ComparisonAndBitwiseOperations of ComparisonAndBitwiseOperationsTrait {
if shift > 255 {
return self.stack.push(0);
}

let shift: usize = shift.try_into().unwrap(); // Safe because shift <= 255
let result = val.wrapping_shl(shift);
self.stack.push(result)
}
Expand All @@ -163,6 +164,11 @@ pub impl ComparisonAndBitwiseOperations of ComparisonAndBitwiseOperationsTrait {
let shift = *popped[0];
let value = *popped[1];

// if shift is bigger than 255 return 0
if shift > 255 {
return self.stack.push(0);
}
let shift: usize = shift.try_into().unwrap(); // Safe because shift <= 255
let result = value.wrapping_shr(shift);
self.stack.push(result)
}
Expand All @@ -187,6 +193,7 @@ pub impl ComparisonAndBitwiseOperations of ComparisonAndBitwiseOperationsTrait {
if (shift > 256) {
self.stack.push(sign)
} else {
let shift: usize = shift.try_into().unwrap(); // Safe because shift <= 256
// XORing with sign before and after the shift propagates the sign bit of the operation
let result = (sign ^ value.value).shr(shift) ^ sign;
self.stack.push(result)
Expand Down
2 changes: 1 addition & 1 deletion crates/evm/src/memory.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl MemoryImpl of MemoryTrait {

// First erase byte value at offset, then set the new value using bitwise ops
let word: u128 = self.items.get(chunk_index.into());
let new_word = (word & ~mask) | (value.into().shl(right_offset.into() * 8));
let new_word = (word & ~mask) | (value.into().shl(right_offset * 8));
self.items.insert(chunk_index.into(), new_word);
}

Expand Down
2 changes: 1 addition & 1 deletion crates/utils/src/crypto/blake2_compress.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ fn rotate_right(value: u64, n: u32) -> u64 {
let bits = BitSize::<u64>::bits(); // The number of bits in a u64
let n = n % bits; // Ensure n is less than 64

let res = value.wrapping_shr(n.into()) | value.wrapping_shl((bits - n).into());
let res = value.wrapping_shr(n) | value.wrapping_shl((bits - n));
res
}
}
Expand Down
24 changes: 12 additions & 12 deletions crates/utils/src/crypto/modexp/arith.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ pub fn mod_inv(x: Word) -> Word {
break;
}

let mask: u64 = 1_u64.shl(i.into()) - 1;
let mask: u64 = 1_u64.shl(i) - 1;
let xy = x.wrapping_mul(y) & mask;
let q = (mask + 1) / 2;
if xy >= q {
Expand All @@ -310,7 +310,7 @@ pub fn mod_inv(x: Word) -> Word {
};

let xy = x.wrapping_mul(y);
let q = 1_u64.wrapping_shl((WORD_BITS - 1).into());
let q = 1_u64.wrapping_shl((WORD_BITS - 1));
if xy >= q {
y += q;
}
Expand Down Expand Up @@ -415,7 +415,7 @@ pub fn borrowing_sub(x: Word, y: Word, borrow: bool) -> (Word, bool) {
/// The double word obtained by joining `hi` and `lo`
pub fn join_as_double(hi: Word, lo: Word) -> DoubleWord {
let hi: DoubleWord = hi.into();
(hi.shl(WORD_BITS.into())).into() + lo.into()
hi.shl(WORD_BITS).into() + lo.into()
}

/// Computes `x^2`, storing the result in `out`.
Expand Down Expand Up @@ -457,14 +457,14 @@ fn big_sq(ref x: MPNat, ref out: Felt252Vec<Word>) {
}

out.set(i + j, res.as_u64());
c = new_c + res.shr(WORD_BITS.into());
c = new_c + res.shr(WORD_BITS);

j += 1;
};

let (sum, carry) = carrying_add(out[i + s], c.as_u64(), false);
out.set(i + s, sum);
out.set(i + s + 1, (c.shr(WORD_BITS.into()) + (carry.into())).as_u64());
out.set(i + s + 1, (c.shr(WORD_BITS) + (carry.into())).as_u64());

i += 1;
}
Expand All @@ -482,8 +482,8 @@ pub fn in_place_shl(ref a: Felt252Vec<Word>, shift: u32) -> Word {
}

let mut a_digit = a[i];
let carry = a_digit.wrapping_shr(carry_shift.into());
a_digit = a_digit.wrapping_shl(shift.into()) | c;
let carry = a_digit.wrapping_shr(carry_shift);
a_digit = a_digit.wrapping_shl(shift) | c;
a.set(i, a_digit);

c = carry;
Expand All @@ -508,8 +508,8 @@ pub fn in_place_shr(ref a: Felt252Vec<Word>, shift: u32) -> Word {
let j = i - 1;

let mut a_digit = a[j];
let borrow = a_digit.wrapping_shl(borrow_shift.into());
a_digit = a_digit.wrapping_shr(shift.into()) | b;
let borrow = a_digit.wrapping_shl(borrow_shift);
a_digit = a_digit.wrapping_shr(shift) | b;
a.set(j, a_digit);

b = borrow;
Expand Down Expand Up @@ -574,7 +574,7 @@ pub fn in_place_mul_sub(ref a: Felt252Vec<Word>, ref x: Felt252Vec<Word>, y: Wor
+ offset_carry.into()
- ((x_digit.into()) * (y.into()));

let new_offset_carry = (offset_sum.shr(WORD_BITS.into())).as_u64();
let new_offset_carry = (offset_sum.shr(WORD_BITS)).as_u64();
let new_x = offset_sum.as_u64();
offset_carry = new_offset_carry;
a.set(i, new_x);
Expand Down Expand Up @@ -661,15 +661,15 @@ mod tests {
let mut result = mp_nat_to_u128(ref x);

let mask = BASE.wrapping_pow(x.digits.len().into()).wrapping_sub(1);
assert_eq!(result, n.wrapping_shl(shift.into()) & mask);
assert_eq!(result, n.wrapping_shl(shift) & mask);
}

fn check_in_place_shr(n: u128, shift: u32) {
let mut x = MPNatTrait::from_big_endian(n.to_be_bytes_padded());
in_place_shr(ref x.digits, shift);
let mut result = mp_nat_to_u128(ref x);

assert_eq!(result, n.wrapping_shr(shift.into()));
assert_eq!(result, n.wrapping_shr(shift));
}

fn check_mod_inv(n: Word) {
Expand Down
11 changes: 5 additions & 6 deletions crates/utils/src/crypto/modexp/mpnat.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ pub impl MPNatTraitImpl of MPNatTrait {

in_place_shr(ref b.digits, 1);

res.digits.set(wordpos, res.digits[wordpos] | (x.shl(bitpos.into())));
res.digits.set(wordpos, res.digits[wordpos] | (x.shl(bitpos)));

bitpos += 1;
if bitpos == WORD_BITS {
Expand Down Expand Up @@ -404,7 +404,7 @@ pub impl MPNatTraitImpl of MPNatTrait {
let mut digits = Felt252VecImpl::new();
digits.expand(trailing_zeros + 1).unwrap();
let mut tmp = MPNat { digits };
tmp.digits.set(trailing_zeros, 1_u64.shl(additional_zero_bits.into()));
tmp.digits.set(trailing_zeros, 1_u64.shl(additional_zero_bits));
tmp
};

Expand All @@ -415,7 +415,7 @@ pub impl MPNatTraitImpl of MPNatTrait {
digits.expand(num_digits).unwrap();
let mut tmp = MPNat { digits };
if additional_zero_bits > 0 {
tmp.digits.set(0, modulus.digits[trailing_zeros].shr(additional_zero_bits.into()));
tmp.digits.set(0, modulus.digits[trailing_zeros].shr(additional_zero_bits));
let mut i = 1;
loop {
if i == num_digits {
Expand All @@ -429,10 +429,9 @@ pub impl MPNatTraitImpl of MPNatTrait {
i - 1,
tmp.digits[i
- 1]
+ (d & power_of_two_mask)
.shl((WORD_BITS - additional_zero_bits).into())
+ (d & power_of_two_mask).shl(WORD_BITS - additional_zero_bits)
);
tmp.digits.set(i, d.shr(additional_zero_bits.into()));
tmp.digits.set(i, d.shr(additional_zero_bits));

i += 1;
};
Expand Down
53 changes: 39 additions & 14 deletions crates/utils/src/math.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use core::integer::{u512};
use core::num::traits::{Zero, One, BitSize, OverflowingAdd, OverflowingMul};
use core::num::traits::{Zero, One, BitSize, OverflowingAdd, OverflowingMul, Bounded};
use core::panic_with_felt252;
use core::traits::{BitAnd};

Expand Down Expand Up @@ -203,7 +203,7 @@ pub trait Bitshift<T> {
///
/// Panics if the shift is greater than 255.
/// Panics if the result overflows the type T.
fn shl(self: T, shift: T) -> T;
fn shl(self: T, shift: usize) -> T;

/// Shift a number right by a given number of bits.
///
Expand All @@ -219,7 +219,7 @@ pub trait Bitshift<T> {
/// # Panics
///
/// Panics if the shift is greater than 255.
fn shr(self: T, shift: T) -> T;
fn shr(self: T, shift: usize) -> T;
}

impl BitshiftImpl<
Expand All @@ -237,23 +237,35 @@ impl BitshiftImpl<
+BitSize<T>,
+TryInto<usize, T>,
> of Bitshift<T> {
fn shl(self: T, shift: T) -> T {
fn shl(self: T, shift: usize) -> T {
// if we shift by more than nb_bits of T, the result is 0
// we early return to save gas and prevent unexpected behavior
if shift > BitSize::<T>::bits().try_into().unwrap() - One::one() {
if shift > BitSize::<T>::bits() - One::one() {
panic_with_felt252('mul Overflow');
}
// if the shift is within the bit size of u256 (<= 255 bits),
// use the POW_2 lookup table to get 2^shift for efficient multiplication
if shift <= BitSize::<u256>::bits() - One::<u32>::one() {
// In case the pow2 is greater than the max value of T, we have an overflow
// so we can panic
return self * (*POW_2_256.span().at(shift)).try_into().expect('mul Overflow');
}
// for shifts greater than 255 bits, perform the shift manually
let two = One::one() + One::one();
self * two.pow(shift)
self * two.pow(shift.try_into().expect('mul Overflow'))
}

fn shr(self: T, shift: T) -> T {
fn shr(self: T, shift: usize) -> T {
// early return to save gas if shift > nb_bits of T
if shift > BitSize::<T>::bits().try_into().unwrap() - One::one() {
if shift > BitSize::<T>::bits() - One::one() {
panic_with_felt252('mul Overflow');
}
// use the POW_2 lookup table when the bit size
if shift <= BitSize::<u256>::bits() - One::<u32>::one() {
return self / (*POW_2_256.span().at(shift)).try_into().expect('mul Overflow');
}
let two = One::one() + One::one();
self / two.pow(shift)
self / two.pow(shift.try_into().expect('mul Overflow'))
}
}

Expand All @@ -270,7 +282,7 @@ pub trait WrappingBitshift<T> {
/// # Returns
///
/// The result of shifting `self` left by `shift` bits, wrapped if necessary
fn wrapping_shl(self: T, shift: T) -> T;
fn wrapping_shl(self: T, shift: usize) -> T;

/// Shift a number right by a given number of bits.
/// If the shift is greater than 255, the result is 0.
Expand All @@ -283,7 +295,7 @@ pub trait WrappingBitshift<T> {
/// # Returns
///
/// The result of shifting `self` right by `shift` bits, or 0 if shift > 255
fn wrapping_shr(self: T, shift: T) -> T;
fn wrapping_shr(self: T, shift: usize) -> T;
}

pub impl WrappingBitshiftImpl<
Expand All @@ -300,18 +312,31 @@ pub impl WrappingBitshiftImpl<
+OverflowingMul<T>,
+WrappingExponentiation<T>,
+BitSize<T>,
+Bounded<T>,
+Into<T, u256>,
+TryInto<usize, T>,
> of WrappingBitshift<T> {
fn wrapping_shl(self: T, shift: T) -> T {
fn wrapping_shl(self: T, shift: usize) -> T {
if shift <= BitSize::<u256>::bits() - One::<u32>::one() {
let pow_2: u256 = (*POW_2_256.span().at(shift));
let pow2_mod_t: u256 = pow_2 % Bounded::<T>::MAX.into();
let (result, _) = self.overflowing_mul(pow2_mod_t.try_into().unwrap());
return result;
}
let two = One::<T>::one() + One::<T>::one();
let (result, _) = self.overflowing_mul(two.wrapping_pow(shift));
result
}

fn wrapping_shr(self: T, shift: T) -> T {
fn wrapping_shr(self: T, shift: usize) -> T {
if shift <= BitSize::<u256>::bits() - One::<u32>::one() {
let pow_2: u256 = (*POW_2_256.span().at(shift));
let pow2_mod_t: u256 = pow_2 % Bounded::<T>::MAX.into();
return self / pow2_mod_t.try_into().unwrap();
}
let two = One::<T>::one() + One::<T>::one();

if shift > BitSize::<T>::bits().try_into().unwrap() - One::one() {
if shift > BitSize::<T>::bits() - One::one() {
return Zero::zero();
}
self / two.pow(shift)
Expand Down
2 changes: 1 addition & 1 deletion crates/utils/src/traits.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pub impl SpanU8TryIntoResultEthAddress of TryIntoResult<Span<u8>, EthAddress> {
let mut i: u32 = 0;
while i != len {
let byte: u256 = (*self.at(i)).into();
result += byte.shl(8 * (offset - i).into());
result += byte.shl(8 * (offset - i));
i += 1;
};
let address: felt252 = result.try_into_result()?;
Expand Down
19 changes: 6 additions & 13 deletions crates/utils/src/traits/bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub impl U8SpanExImpl of U8SpanExTrait {
Option::Some(byte) => {
let byte: u64 = (*byte.unbox()).into();
// Accumulate pending_word in a little endian manner
byte.shl(8_u64 * byte_counter.into())
byte.shl(8_u32 * byte_counter.into())
},
Option::None => { break; },
};
Expand All @@ -69,7 +69,7 @@ pub impl U8SpanExImpl of U8SpanExTrait {
last_input_word += match self.get(full_u64_word_count * 8 + byte_counter.into()) {
Option::Some(byte) => {
let byte: u64 = (*byte.unbox()).into();
byte.shl(8_u64 * byte_counter.into())
byte.shl(8_u32 * byte_counter.into())
},
Option::None => { break; },
};
Expand Down Expand Up @@ -249,17 +249,13 @@ pub impl ToBytesImpl<
fn to_be_bytes(self: T) -> Span<u8> {
let bytes_used = self.bytes_used();

let one = One::<T>::one();
let two = one + one;
let eight = two * two * two;

// 0xFF
let mask = Bounded::<u8>::MAX.into();

let mut bytes: Array<u8> = Default::default();
let mut i: u8 = 0;
while i != bytes_used {
let val = Bitshift::<T>::shr(self, eight * (bytes_used - i - 1).into());
let val = Bitshift::<T>::shr(self, 8_u32 * (bytes_used.into() - i.into() - 1));
bytes.append((val & mask).try_into().unwrap());
i += 1;
};
Expand All @@ -274,9 +270,6 @@ pub impl ToBytesImpl<

fn to_le_bytes(mut self: T) -> Span<u8> {
let bytes_used = self.bytes_used();
let one = One::<T>::one();
let two = one + one;
let eight = two * two * two;

// 0xFF
let mask = Bounded::<u8>::MAX.into();
Expand All @@ -285,7 +278,7 @@ pub impl ToBytesImpl<

let mut i: u8 = 0;
while i != bytes_used {
let val = self.shr(eight * i.into());
let val = self.shr(8_u32 * i.into());
bytes.append((val & mask).try_into().unwrap());
i += 1;
};
Expand Down Expand Up @@ -536,7 +529,7 @@ pub impl ByteArrayExt of ByteArrayExTrait {
Option::Some(byte) => {
let byte: u64 = byte.into();
// Accumulate pending_word in a little endian manner
byte.shl(8_u64 * byte_counter.into())
byte.shl(8_u32 * byte_counter.into())
},
Option::None => { break; },
};
Expand All @@ -555,7 +548,7 @@ pub impl ByteArrayExt of ByteArrayExTrait {
last_input_word += match self.at(full_u64_word_count * 8 + byte_counter.into()) {
Option::Some(byte) => {
let byte: u64 = byte.into();
byte.shl(8_u64 * byte_counter.into())
byte.shl(8_u32 * byte_counter.into())
},
Option::None => { break; },
};
Expand Down
Loading
Loading