From e41bcaef9273c0f61216acbc236a86c6b8ca324c Mon Sep 17 00:00:00 2001 From: enitrat Date: Sat, 28 Sep 2024 16:44:53 +0200 Subject: [PATCH] refactor: Bitshift takes usize as arg --- .../instructions/comparison_operations.cairo | 9 +++- crates/evm/src/memory.cairo | 2 +- crates/utils/src/crypto/blake2_compress.cairo | 2 +- crates/utils/src/crypto/modexp/arith.cairo | 24 ++++----- crates/utils/src/crypto/modexp/mpnat.cairo | 11 ++-- crates/utils/src/math.cairo | 53 ++++++++++++++----- crates/utils/src/traits.cairo | 2 +- crates/utils/src/traits/bytes.cairo | 19 +++---- crates/utils/src/traits/eth_address.cairo | 10 ++-- crates/utils/src/traits/integer.cairo | 4 +- 10 files changed, 79 insertions(+), 57 deletions(-) diff --git a/crates/evm/src/instructions/comparison_operations.cairo b/crates/evm/src/instructions/comparison_operations.cairo index e99f9e079..44e8ca163 100644 --- a/crates/evm/src/instructions/comparison_operations.cairo +++ b/crates/evm/src/instructions/comparison_operations.cairo @@ -132,6 +132,7 @@ pub impl ComparisonAndBitwiseOperations of ComparisonAndBitwiseOperationsTrait { if i > 31 { return self.stack.push(0); } + let i: usize = i.try_into().unwrap(); // Safe because i <= 31 // Right shift value by offset bits and then take the least significant byte. let result = x.shr((31 - i) * 8) & 0xFF; @@ -150,7 +151,7 @@ pub impl ComparisonAndBitwiseOperations of ComparisonAndBitwiseOperationsTrait { if shift > 255 { return self.stack.push(0); } - + let shift: usize = shift.try_into().unwrap(); // Safe because shift <= 255 let result = val.wrapping_shl(shift); self.stack.push(result) } @@ -163,6 +164,11 @@ pub impl ComparisonAndBitwiseOperations of ComparisonAndBitwiseOperationsTrait { let shift = *popped[0]; let value = *popped[1]; + // if shift is bigger than 255 return 0 + if shift > 255 { + return self.stack.push(0); + } + let shift: usize = shift.try_into().unwrap(); // Safe because shift <= 255 let result = value.wrapping_shr(shift); self.stack.push(result) } @@ -187,6 +193,7 @@ pub impl ComparisonAndBitwiseOperations of ComparisonAndBitwiseOperationsTrait { if (shift > 256) { self.stack.push(sign) } else { + let shift: usize = shift.try_into().unwrap(); // Safe because shift <= 256 // XORing with sign before and after the shift propagates the sign bit of the operation let result = (sign ^ value.value).shr(shift) ^ sign; self.stack.push(result) diff --git a/crates/evm/src/memory.cairo b/crates/evm/src/memory.cairo index b24ec7ca7..062febdf7 100644 --- a/crates/evm/src/memory.cairo +++ b/crates/evm/src/memory.cairo @@ -88,7 +88,7 @@ impl MemoryImpl of MemoryTrait { // First erase byte value at offset, then set the new value using bitwise ops let word: u128 = self.items.get(chunk_index.into()); - let new_word = (word & ~mask) | (value.into().shl(right_offset.into() * 8)); + let new_word = (word & ~mask) | (value.into().shl(right_offset * 8)); self.items.insert(chunk_index.into(), new_word); } diff --git a/crates/utils/src/crypto/blake2_compress.cairo b/crates/utils/src/crypto/blake2_compress.cairo index 4ce2840e4..b1b158ea2 100644 --- a/crates/utils/src/crypto/blake2_compress.cairo +++ b/crates/utils/src/crypto/blake2_compress.cairo @@ -54,7 +54,7 @@ fn rotate_right(value: u64, n: u32) -> u64 { let bits = BitSize::::bits(); // The number of bits in a u64 let n = n % bits; // Ensure n is less than 64 - let res = value.wrapping_shr(n.into()) | value.wrapping_shl((bits - n).into()); + let res = value.wrapping_shr(n) | value.wrapping_shl((bits - n)); res } } diff --git a/crates/utils/src/crypto/modexp/arith.cairo b/crates/utils/src/crypto/modexp/arith.cairo index d82bf2f61..d302048a8 100644 --- a/crates/utils/src/crypto/modexp/arith.cairo +++ b/crates/utils/src/crypto/modexp/arith.cairo @@ -300,7 +300,7 @@ pub fn mod_inv(x: Word) -> Word { break; } - let mask: u64 = 1_u64.shl(i.into()) - 1; + let mask: u64 = 1_u64.shl(i) - 1; let xy = x.wrapping_mul(y) & mask; let q = (mask + 1) / 2; if xy >= q { @@ -310,7 +310,7 @@ pub fn mod_inv(x: Word) -> Word { }; let xy = x.wrapping_mul(y); - let q = 1_u64.wrapping_shl((WORD_BITS - 1).into()); + let q = 1_u64.wrapping_shl((WORD_BITS - 1)); if xy >= q { y += q; } @@ -415,7 +415,7 @@ pub fn borrowing_sub(x: Word, y: Word, borrow: bool) -> (Word, bool) { /// The double word obtained by joining `hi` and `lo` pub fn join_as_double(hi: Word, lo: Word) -> DoubleWord { let hi: DoubleWord = hi.into(); - (hi.shl(WORD_BITS.into())).into() + lo.into() + hi.shl(WORD_BITS).into() + lo.into() } /// Computes `x^2`, storing the result in `out`. @@ -457,14 +457,14 @@ fn big_sq(ref x: MPNat, ref out: Felt252Vec) { } out.set(i + j, res.as_u64()); - c = new_c + res.shr(WORD_BITS.into()); + c = new_c + res.shr(WORD_BITS); j += 1; }; let (sum, carry) = carrying_add(out[i + s], c.as_u64(), false); out.set(i + s, sum); - out.set(i + s + 1, (c.shr(WORD_BITS.into()) + (carry.into())).as_u64()); + out.set(i + s + 1, (c.shr(WORD_BITS) + (carry.into())).as_u64()); i += 1; } @@ -482,8 +482,8 @@ pub fn in_place_shl(ref a: Felt252Vec, shift: u32) -> Word { } let mut a_digit = a[i]; - let carry = a_digit.wrapping_shr(carry_shift.into()); - a_digit = a_digit.wrapping_shl(shift.into()) | c; + let carry = a_digit.wrapping_shr(carry_shift); + a_digit = a_digit.wrapping_shl(shift) | c; a.set(i, a_digit); c = carry; @@ -508,8 +508,8 @@ pub fn in_place_shr(ref a: Felt252Vec, shift: u32) -> Word { let j = i - 1; let mut a_digit = a[j]; - let borrow = a_digit.wrapping_shl(borrow_shift.into()); - a_digit = a_digit.wrapping_shr(shift.into()) | b; + let borrow = a_digit.wrapping_shl(borrow_shift); + a_digit = a_digit.wrapping_shr(shift) | b; a.set(j, a_digit); b = borrow; @@ -574,7 +574,7 @@ pub fn in_place_mul_sub(ref a: Felt252Vec, ref x: Felt252Vec, y: Wor + offset_carry.into() - ((x_digit.into()) * (y.into())); - let new_offset_carry = (offset_sum.shr(WORD_BITS.into())).as_u64(); + let new_offset_carry = (offset_sum.shr(WORD_BITS)).as_u64(); let new_x = offset_sum.as_u64(); offset_carry = new_offset_carry; a.set(i, new_x); @@ -661,7 +661,7 @@ mod tests { let mut result = mp_nat_to_u128(ref x); let mask = BASE.wrapping_pow(x.digits.len().into()).wrapping_sub(1); - assert_eq!(result, n.wrapping_shl(shift.into()) & mask); + assert_eq!(result, n.wrapping_shl(shift) & mask); } fn check_in_place_shr(n: u128, shift: u32) { @@ -669,7 +669,7 @@ mod tests { in_place_shr(ref x.digits, shift); let mut result = mp_nat_to_u128(ref x); - assert_eq!(result, n.wrapping_shr(shift.into())); + assert_eq!(result, n.wrapping_shr(shift)); } fn check_mod_inv(n: Word) { diff --git a/crates/utils/src/crypto/modexp/mpnat.cairo b/crates/utils/src/crypto/modexp/mpnat.cairo index 5f461837f..6cd2d7fff 100644 --- a/crates/utils/src/crypto/modexp/mpnat.cairo +++ b/crates/utils/src/crypto/modexp/mpnat.cairo @@ -323,7 +323,7 @@ pub impl MPNatTraitImpl of MPNatTrait { in_place_shr(ref b.digits, 1); - res.digits.set(wordpos, res.digits[wordpos] | (x.shl(bitpos.into()))); + res.digits.set(wordpos, res.digits[wordpos] | (x.shl(bitpos))); bitpos += 1; if bitpos == WORD_BITS { @@ -404,7 +404,7 @@ pub impl MPNatTraitImpl of MPNatTrait { let mut digits = Felt252VecImpl::new(); digits.expand(trailing_zeros + 1).unwrap(); let mut tmp = MPNat { digits }; - tmp.digits.set(trailing_zeros, 1_u64.shl(additional_zero_bits.into())); + tmp.digits.set(trailing_zeros, 1_u64.shl(additional_zero_bits)); tmp }; @@ -415,7 +415,7 @@ pub impl MPNatTraitImpl of MPNatTrait { digits.expand(num_digits).unwrap(); let mut tmp = MPNat { digits }; if additional_zero_bits > 0 { - tmp.digits.set(0, modulus.digits[trailing_zeros].shr(additional_zero_bits.into())); + tmp.digits.set(0, modulus.digits[trailing_zeros].shr(additional_zero_bits)); let mut i = 1; loop { if i == num_digits { @@ -429,10 +429,9 @@ pub impl MPNatTraitImpl of MPNatTrait { i - 1, tmp.digits[i - 1] - + (d & power_of_two_mask) - .shl((WORD_BITS - additional_zero_bits).into()) + + (d & power_of_two_mask).shl(WORD_BITS - additional_zero_bits) ); - tmp.digits.set(i, d.shr(additional_zero_bits.into())); + tmp.digits.set(i, d.shr(additional_zero_bits)); i += 1; }; diff --git a/crates/utils/src/math.cairo b/crates/utils/src/math.cairo index 6e77fdbc9..e65b720ef 100644 --- a/crates/utils/src/math.cairo +++ b/crates/utils/src/math.cairo @@ -1,5 +1,5 @@ use core::integer::{u512}; -use core::num::traits::{Zero, One, BitSize, OverflowingAdd, OverflowingMul}; +use core::num::traits::{Zero, One, BitSize, OverflowingAdd, OverflowingMul, Bounded}; use core::panic_with_felt252; use core::traits::{BitAnd}; @@ -203,7 +203,7 @@ pub trait Bitshift { /// /// Panics if the shift is greater than 255. /// Panics if the result overflows the type T. - fn shl(self: T, shift: T) -> T; + fn shl(self: T, shift: usize) -> T; /// Shift a number right by a given number of bits. /// @@ -219,7 +219,7 @@ pub trait Bitshift { /// # Panics /// /// Panics if the shift is greater than 255. - fn shr(self: T, shift: T) -> T; + fn shr(self: T, shift: usize) -> T; } impl BitshiftImpl< @@ -237,23 +237,35 @@ impl BitshiftImpl< +BitSize, +TryInto, > of Bitshift { - fn shl(self: T, shift: T) -> T { + fn shl(self: T, shift: usize) -> T { // if we shift by more than nb_bits of T, the result is 0 // we early return to save gas and prevent unexpected behavior - if shift > BitSize::::bits().try_into().unwrap() - One::one() { + if shift > BitSize::::bits() - One::one() { panic_with_felt252('mul Overflow'); } + // if the shift is within the bit size of u256 (<= 255 bits), + // use the POW_2 lookup table to get 2^shift for efficient multiplication + if shift <= BitSize::::bits() - One::::one() { + // In case the pow2 is greater than the max value of T, we have an overflow + // so we can panic + return self * (*POW_2_256.span().at(shift)).try_into().expect('mul Overflow'); + } + // for shifts greater than 255 bits, perform the shift manually let two = One::one() + One::one(); - self * two.pow(shift) + self * two.pow(shift.try_into().expect('mul Overflow')) } - fn shr(self: T, shift: T) -> T { + fn shr(self: T, shift: usize) -> T { // early return to save gas if shift > nb_bits of T - if shift > BitSize::::bits().try_into().unwrap() - One::one() { + if shift > BitSize::::bits() - One::one() { panic_with_felt252('mul Overflow'); } + // use the POW_2 lookup table when the bit size + if shift <= BitSize::::bits() - One::::one() { + return self / (*POW_2_256.span().at(shift)).try_into().expect('mul Overflow'); + } let two = One::one() + One::one(); - self / two.pow(shift) + self / two.pow(shift.try_into().expect('mul Overflow')) } } @@ -270,7 +282,7 @@ pub trait WrappingBitshift { /// # Returns /// /// The result of shifting `self` left by `shift` bits, wrapped if necessary - fn wrapping_shl(self: T, shift: T) -> T; + fn wrapping_shl(self: T, shift: usize) -> T; /// Shift a number right by a given number of bits. /// If the shift is greater than 255, the result is 0. @@ -283,7 +295,7 @@ pub trait WrappingBitshift { /// # Returns /// /// The result of shifting `self` right by `shift` bits, or 0 if shift > 255 - fn wrapping_shr(self: T, shift: T) -> T; + fn wrapping_shr(self: T, shift: usize) -> T; } pub impl WrappingBitshiftImpl< @@ -300,18 +312,31 @@ pub impl WrappingBitshiftImpl< +OverflowingMul, +WrappingExponentiation, +BitSize, + +Bounded, + +Into, +TryInto, > of WrappingBitshift { - fn wrapping_shl(self: T, shift: T) -> T { + fn wrapping_shl(self: T, shift: usize) -> T { + if shift <= BitSize::::bits() - One::::one() { + let pow_2: u256 = (*POW_2_256.span().at(shift)); + let pow2_mod_t: u256 = pow_2 % Bounded::::MAX.into(); + let (result, _) = self.overflowing_mul(pow2_mod_t.try_into().unwrap()); + return result; + } let two = One::::one() + One::::one(); let (result, _) = self.overflowing_mul(two.wrapping_pow(shift)); result } - fn wrapping_shr(self: T, shift: T) -> T { + fn wrapping_shr(self: T, shift: usize) -> T { + if shift <= BitSize::::bits() - One::::one() { + let pow_2: u256 = (*POW_2_256.span().at(shift)); + let pow2_mod_t: u256 = pow_2 % Bounded::::MAX.into(); + return self / pow2_mod_t.try_into().unwrap(); + } let two = One::::one() + One::::one(); - if shift > BitSize::::bits().try_into().unwrap() - One::one() { + if shift > BitSize::::bits() - One::one() { return Zero::zero(); } self / two.pow(shift) diff --git a/crates/utils/src/traits.cairo b/crates/utils/src/traits.cairo index 3cf2fbf2f..58a8cb94a 100644 --- a/crates/utils/src/traits.cairo +++ b/crates/utils/src/traits.cairo @@ -109,7 +109,7 @@ pub impl SpanU8TryIntoResultEthAddress of TryIntoResult, EthAddress> { let mut i: u32 = 0; while i != len { let byte: u256 = (*self.at(i)).into(); - result += byte.shl(8 * (offset - i).into()); + result += byte.shl(8 * (offset - i)); i += 1; }; let address: felt252 = result.try_into_result()?; diff --git a/crates/utils/src/traits/bytes.cairo b/crates/utils/src/traits/bytes.cairo index c780098f2..85f66ec36 100644 --- a/crates/utils/src/traits/bytes.cairo +++ b/crates/utils/src/traits/bytes.cairo @@ -50,7 +50,7 @@ pub impl U8SpanExImpl of U8SpanExTrait { Option::Some(byte) => { let byte: u64 = (*byte.unbox()).into(); // Accumulate pending_word in a little endian manner - byte.shl(8_u64 * byte_counter.into()) + byte.shl(8_u32 * byte_counter.into()) }, Option::None => { break; }, }; @@ -69,7 +69,7 @@ pub impl U8SpanExImpl of U8SpanExTrait { last_input_word += match self.get(full_u64_word_count * 8 + byte_counter.into()) { Option::Some(byte) => { let byte: u64 = (*byte.unbox()).into(); - byte.shl(8_u64 * byte_counter.into()) + byte.shl(8_u32 * byte_counter.into()) }, Option::None => { break; }, }; @@ -249,17 +249,13 @@ pub impl ToBytesImpl< fn to_be_bytes(self: T) -> Span { let bytes_used = self.bytes_used(); - let one = One::::one(); - let two = one + one; - let eight = two * two * two; - // 0xFF let mask = Bounded::::MAX.into(); let mut bytes: Array = Default::default(); let mut i: u8 = 0; while i != bytes_used { - let val = Bitshift::::shr(self, eight * (bytes_used - i - 1).into()); + let val = Bitshift::::shr(self, 8_u32 * (bytes_used.into() - i.into() - 1)); bytes.append((val & mask).try_into().unwrap()); i += 1; }; @@ -274,9 +270,6 @@ pub impl ToBytesImpl< fn to_le_bytes(mut self: T) -> Span { let bytes_used = self.bytes_used(); - let one = One::::one(); - let two = one + one; - let eight = two * two * two; // 0xFF let mask = Bounded::::MAX.into(); @@ -285,7 +278,7 @@ pub impl ToBytesImpl< let mut i: u8 = 0; while i != bytes_used { - let val = self.shr(eight * i.into()); + let val = self.shr(8_u32 * i.into()); bytes.append((val & mask).try_into().unwrap()); i += 1; }; @@ -536,7 +529,7 @@ pub impl ByteArrayExt of ByteArrayExTrait { Option::Some(byte) => { let byte: u64 = byte.into(); // Accumulate pending_word in a little endian manner - byte.shl(8_u64 * byte_counter.into()) + byte.shl(8_u32 * byte_counter.into()) }, Option::None => { break; }, }; @@ -555,7 +548,7 @@ pub impl ByteArrayExt of ByteArrayExTrait { last_input_word += match self.at(full_u64_word_count * 8 + byte_counter.into()) { Option::Some(byte) => { let byte: u64 = byte.into(); - byte.shl(8_u64 * byte_counter.into()) + byte.shl(8_u32 * byte_counter.into()) }, Option::None => { break; }, }; diff --git a/crates/utils/src/traits/eth_address.cairo b/crates/utils/src/traits/eth_address.cairo index c2f260716..a7bb5e338 100644 --- a/crates/utils/src/traits/eth_address.cairo +++ b/crates/utils/src/traits/eth_address.cairo @@ -4,18 +4,18 @@ use crate::traits::EthAddressIntoU256; #[generate_trait] pub impl EthAddressExImpl of EthAddressExTrait { + const BYTES_USED: u8 = 20; /// Converts an EthAddress to an array of bytes. /// /// # Returns /// /// * `Array` - A 20-byte array representation of the EthAddress. fn to_bytes(self: EthAddress) -> Array { - let bytes_used: u256 = 20; let value: u256 = self.into(); let mut bytes: Array = Default::default(); - let mut i = 0; - while i != bytes_used { - let val = value.shr(8 * (bytes_used - i - 1)); + let mut i: u8 = 0; + while i != Self::BYTES_USED { + let val = value.shr(8_u32 * (Self::BYTES_USED.into() - i.into() - 1)); bytes.append((val & 0xFF).try_into().unwrap()); i += 1; }; @@ -43,7 +43,7 @@ pub impl EthAddressExImpl of EthAddressExTrait { let mut i: u32 = 0; while i != len { let byte: u256 = (*input.at(i)).into(); - result += byte.shl((8 * (offset - i)).into()); + result += byte.shl((8 * (offset - i))); i += 1; }; diff --git a/crates/utils/src/traits/integer.cairo b/crates/utils/src/traits/integer.cairo index 9ee504f46..0e0dc3b41 100644 --- a/crates/utils/src/traits/integer.cairo +++ b/crates/utils/src/traits/integer.cairo @@ -186,11 +186,9 @@ pub impl BitsUsedImpl< if self == Zero::zero() { return 0; } - let two: T = One::one() + One::one(); - let eight: T = two * two * two; let bytes_used = self.bytes_used(); - let last_byte = self.shr(eight * (bytes_used.into() - One::one())); + let last_byte = self.shr(8_u32 * (bytes_used.into() - One::one())); // safe unwrap since we know atmost 8 bits are used let bits_used: u8 = bits_used_internal::bits_used_in_byte(last_byte.try_into().unwrap());