Skip to content

Commit

Permalink
use from_be_bytes_partial if input type isnt known"
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Sep 5, 2024
1 parent b5d4bcf commit bd54558
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 38 deletions.
8 changes: 4 additions & 4 deletions crates/utils/src/crypto/modexp/mpnat.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ pub impl MPNatTraitImpl of MPNatTrait {
buf.copy_from_bytes_le((WORD_BYTES - r), bytes.slice(0, r)).unwrap();

// safe unwrap, since we know that bytes won't overflow
let word = buf.to_le_bytes().from_be_bytes().unwrap();
let word = buf.to_le_bytes().from_be_bytes().expect('mpnat_from_big_endian_word');
digits.set(i, word);

if i == 0 {
Expand All @@ -88,7 +88,7 @@ pub impl MPNatTraitImpl of MPNatTrait {
buf.copy_from_bytes_le(0, bytes.slice(j, next_j - j)).unwrap();

// safe unwrap, since we know that bytes won't overflow
let word: u64 = buf.to_le_bytes().from_be_bytes().unwrap();
let word: u64 = buf.to_le_bytes().from_be_bytes().expect('mpnat_from_big_endian_word');
digits.set(i, word);

if i == 0 {
Expand Down Expand Up @@ -369,7 +369,7 @@ pub impl MPNatTraitImpl of MPNatTrait {
}

if exp.len() <= (ByteSize::<usize>::byte_size()) {
let exp_as_number: usize = exp.from_le_bytes().unwrap();
let exp_as_number: usize = exp.from_le_bytes_partial().expect('modpow_exp_as_number');

match self.digits.len().checked_mul(exp_as_number) {
Option::Some(max_output_digits) => {
Expand Down Expand Up @@ -714,7 +714,7 @@ mod tests {

i += 1;
};
result.from_le_bytes().unwrap()
result.from_le_bytes_partial().expect('mpnat_to_u128')
}

fn check_modpow_even(base: u128, exp: u128, modulus: u128, expected: u128) {
Expand Down
254 changes: 224 additions & 30 deletions crates/utils/src/helpers.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -550,17 +550,39 @@ pub trait FromBytes<T> {
///
/// # Returns
/// * The Option::(value) represented by the bytes in big endian, Option::None if the span is
/// longer than the byte size of T.
/// not the byte size of T.
fn from_be_bytes(self: Span<u8>) -> Option<T>;

/// Parses a span of big endian bytes into a type T, allowing for partial input
///
/// # Arguments
/// * `self` a span of big endian bytes.
///
/// # Returns
/// * The Option::(value) represented by the bytes in big endian, Option::None if the span is
/// longer than the byte size of T.
fn from_be_bytes_partial(self: Span<u8>) -> Option<T>;


/// Parses a span of little endian bytes into a type T
///
/// # Arguments
/// * `self` a span of little endian bytes.
///
/// # Returns
/// * The Option::(value) represented by the bytes in little endian, Option::None if the span is
/// longer than the byte size of T.
/// not the byte size of T.
fn from_le_bytes(self: Span<u8>) -> Option<T>;

/// Parses a span of little endian bytes into a type T, allowing for partial input
///
/// # Arguments
/// * `self` a span of little endian bytes.
///
/// # Returns
/// * The Option::(value) represented by the bytes in little endian, Option::None if the span is
/// longer than the byte size of T.
fn from_le_bytes_partial(self: Span<u8>) -> Option<T>;
}

pub impl FromBytesImpl<
Expand Down Expand Up @@ -597,6 +619,22 @@ pub impl FromBytesImpl<
Option::Some(result)
}

fn from_be_bytes_partial(self: Span<u8>) -> Option<T> {
let byte_size = ByteSize::<T>::byte_size();

if self.len() > byte_size {
return Option::None;
}

let mut result: T = Zero::zero();
for byte in self {
let tmp = result * 256_u16.into();
result = tmp + (*byte).into();
};

Option::Some(result)
}

fn from_le_bytes(self: Span<u8>) -> Option<T> {
let byte_size = ByteSize::<T>::byte_size();

Expand All @@ -613,6 +651,23 @@ pub impl FromBytesImpl<
};
Option::Some(result)
}

fn from_le_bytes_partial(self: Span<u8>) -> Option<T> {
let byte_size = ByteSize::<T>::byte_size();

if self.len() > byte_size {
return Option::None;
}

let mut result: T = Zero::zero();
let mut i = self.len();
while i != 0 {
i -= 1;
let tmp = result * 256_u16.into();
result = tmp + (*self[i]).into();
};
Option::Some(result)
}
}


Expand Down Expand Up @@ -1574,39 +1629,178 @@ mod tests {
let input: Array<u8> = array![0xf4, 0x32, 0x15, 0x62];
let res: Option<u32> = input.span().from_be_bytes();

assert(res.is_some(), 'should have a value');
assert(res.unwrap() == 0xf4321562, 'wrong result value');
assert!(res.is_some());
assert_eq!(res.unwrap(), 0xf4321562);
}

#[test]
fn test_u32_from_be_bytes_too_big() {
fn test_u32_from_be_bytes_too_big_should_return_none() {
let input: Array<u8> = array![0xf4, 0x32, 0x15, 0x62, 0x01];
let res: Option<u32> = input.span().from_be_bytes();

assert(res.is_none(), 'should not have a value');
assert!(res.is_none());
}

#[test]
fn test_u32_from_be_bytes_too_small_should_return_none() {
let input: Array<u8> = array![0xf4, 0x32, 0x15];
let res: Option<u32> = input.span().from_be_bytes();

assert!(res.is_none());
}

#[test]
fn test_u32_from_be_bytes_partial_full() {
let input: Array<u8> = array![0xf4, 0x32, 0x15, 0x62];
let res: Option<u32> = input.span().from_be_bytes_partial();

assert!(res.is_some());
assert_eq!(res.unwrap(), 0xf4321562);
}

#[test]
fn test_u32_from_be_bytes_partial_smaller_input() {
let input: Array<u8> = array![0xf4, 0x32, 0x15];
let res: Option<u32> = input.span().from_be_bytes_partial();

assert!(res.is_some());
assert_eq!(res.unwrap(), 0xf43215);
}

#[test]
fn test_u32_from_be_bytes_partial_single_byte() {
let input: Array<u8> = array![0xf4];
let res: Option<u32> = input.span().from_be_bytes_partial();

assert!(res.is_some());
assert_eq!(res.unwrap(), 0xf4);
}

#[test]
fn test_u32_from_be_bytes_partial_empty_input() {
let input: Array<u8> = array![];
let res: Option<u32> = input.span().from_be_bytes_partial();

assert!(res.is_some());
assert_eq!(res.unwrap(), 0);
}

#[test]
fn test_u32_from_be_bytes_partial_too_big_input() {
let input: Array<u8> = array![0xf4, 0x32, 0x15, 0x62, 0x01];
let res: Option<u32> = input.span().from_be_bytes_partial();

assert!(res.is_none());
}

#[test]
fn test_u32_from_le_bytes() {
let input: Array<u8> = array![0x62, 0x15, 0x32, 0xf4];
let res: Option<u32> = input.span().from_le_bytes();

assert!(res.is_some());
assert_eq!(res.unwrap(), 0xf4321562);
}

#[test]
fn test_u32_from_le_bytes_too_big() {
let input: Array<u8> = array![0x62, 0x15, 0x32, 0xf4, 0x01];
let res: Option<u32> = input.span().from_le_bytes();

assert!(res.is_none());
}

#[test]
fn test_u32_from_le_bytes_too_small() {
let input: Array<u8> = array![0x62, 0x15, 0x32];
let res: Option<u32> = input.span().from_le_bytes();

assert!(res.is_none());
}

#[test]
fn test_u32_from_le_bytes_zero() {
let input: Array<u8> = array![0x00, 0x00, 0x00, 0x00];
let res: Option<u32> = input.span().from_le_bytes();

assert!(res.is_some());
assert_eq!(res.unwrap(), 0);
}

#[test]
fn test_u32_from_le_bytes_max() {
let input: Array<u8> = array![0xff, 0xff, 0xff, 0xff];
let res: Option<u32> = input.span().from_le_bytes();

assert!(res.is_some());
assert_eq!(res.unwrap(), 0xffffffff);
}

#[test]
fn test_u32_from_le_bytes_partial() {
let input: Array<u8> = array![0x62, 0x15, 0x32];
let res: Option<u32> = input.span().from_le_bytes_partial();

assert!(res.is_some());
assert_eq!(res.unwrap(), 0x321562);
}

#[test]
fn test_u32_from_le_bytes_partial_full() {
let input: Array<u8> = array![0x62, 0x15, 0x32, 0xf4];
let res: Option<u32> = input.span().from_le_bytes_partial();

assert!(res.is_some());
assert_eq!(res.unwrap(), 0xf4321562);
}

#[test]
fn test_u32_from_le_bytes_partial_too_big() {
let input: Array<u8> = array![0x62, 0x15, 0x32, 0xf4, 0x01];
let res: Option<u32> = input.span().from_le_bytes_partial();

assert!(res.is_none());
}

#[test]
fn test_u32_from_le_bytes_partial_empty() {
let input: Array<u8> = array![];
let res: Option<u32> = input.span().from_le_bytes_partial();

assert!(res.is_some());
assert_eq!(res.unwrap(), 0);
}

#[test]
fn test_u32_from_le_bytes_partial_single_byte() {
let input: Array<u8> = array![0xff];
let res: Option<u32> = input.span().from_le_bytes_partial();

assert!(res.is_some());
assert_eq!(res.unwrap(), 0xff);
}

#[test]
fn test_u32_to_bytes_full() {
let input: u32 = 0xf4321562;
let res: Span<u8> = input.to_be_bytes();

assert(res.len() == 4, 'wrong result length');
assert(*res[0] == 0xf4, 'wrong result value');
assert(*res[1] == 0x32, 'wrong result value');
assert(*res[2] == 0x15, 'wrong result value');
assert(*res[3] == 0x62, 'wrong result value');
assert_eq!(res.len(), 4);
assert_eq!(*res[0], 0xf4);
assert_eq!(*res[1], 0x32);
assert_eq!(*res[2], 0x15);
assert_eq!(*res[3], 0x62);
}

#[test]
fn test_u32_to_bytes_partial() {
let input: u32 = 0xf43215;
let res: Span<u8> = input.to_be_bytes();

assert(res.len() == 3, 'wrong result length');
assert(*res[0] == 0xf4, 'wrong result value');
assert(*res[1] == 0x32, 'wrong result value');
assert(*res[2] == 0x15, 'wrong result value');
assert_eq!(res.len(), 3);
assert_eq!(*res[0], 0xf4);
assert_eq!(*res[1], 0x32);
assert_eq!(*res[2], 0x15);
}


Expand All @@ -1615,9 +1809,9 @@ mod tests {
let input: u32 = 0x00f432;
let res: Span<u8> = input.to_be_bytes();

assert(res.len() == 2, 'wrong result length');
assert(*res[0] == 0xf4, 'wrong result value');
assert(*res[1] == 0x32, 'wrong result value');
assert_eq!(res.len(), 2);
assert_eq!(*res[0], 0xf4);
assert_eq!(*res[1], 0x32);
}

#[test]
Expand Down Expand Up @@ -1650,7 +1844,7 @@ mod tests {
let len: u32 = 0x001234;
let bytes_count = len.bytes_used();

assert(bytes_count == 2, 'wrong bytes count');
assert_eq!(bytes_count, 2);
}
}

Expand Down Expand Up @@ -1787,10 +1981,10 @@ mod tests {
fn test_split_u256_into_u64_little() {
let value: u256 = 0xFAFFFFFF000000E500000077000000DEAD0000000004200000FADE0000450000;
let ((high_h, low_h), (high_l, low_l)) = value.split_into_u64_le();
assert(high_h == 0xDE00000077000000, 'split mismatch');
assert(low_h == 0xE5000000FFFFFFFA, 'split mismatch');
assert(high_l == 0x0000450000DEFA00, 'split mismatch');
assert(low_l == 0x00200400000000AD, 'split mismatch');
assert_eq!(high_h, 0xDE00000077000000);
assert_eq!(low_h, 0xE5000000FFFFFFFA);
assert_eq!(high_l, 0x0000450000DEFA00);
assert_eq!(low_l, 0x00200400000000AD);
}

#[test]
Expand Down Expand Up @@ -1888,13 +2082,13 @@ mod tests {
byte_arr.append_byte(0xFF);
byte_arr.append_byte(0xAA);
byte_arr.append_span_bytes(bytes.span());
assert(byte_arr.len() == 6, 'wrong length');
assert(byte_arr[0] == 0xFF, 'wrong value');
assert(byte_arr[1] == 0xAA, 'wrong value');
assert(byte_arr[2] == 0x01, 'wrong value');
assert(byte_arr[3] == 0x02, 'wrong value');
assert(byte_arr[4] == 0x03, 'wrong value');
assert(byte_arr[5] == 0x04, 'wrong value');
assert_eq!(byte_arr.len(), 6);
assert_eq!(byte_arr[0], 0xFF);
assert_eq!(byte_arr[1], 0xAA);
assert_eq!(byte_arr[2], 0x01);
assert_eq!(byte_arr[3], 0x02);
assert_eq!(byte_arr[4], 0x03);
assert_eq!(byte_arr[5], 0x04);
}

#[test]
Expand Down
Loading

0 comments on commit bd54558

Please sign in to comment.