diff --git a/src/core/reader/mod.rs b/src/core/reader/mod.rs index 545f25a8..6e5f26f8 100644 --- a/src/core/reader/mod.rs +++ b/src/core/reader/mod.rs @@ -7,6 +7,7 @@ pub mod types; /// A struct for managing and reading WASM bytecode /// /// Its purpose is to abstract parsing basic WASM values from the bytecode. +#[derive(Clone)] pub struct WasmReader<'a> { /// Entire WASM binary as slice pub full_wasm_binary: &'a [u8], @@ -148,9 +149,29 @@ impl<'a> WasmReader<'a> { pub fn into_inner(self) -> &'a [u8] { self.full_wasm_binary } + + /// A wrapper function for reads with transaction-like behavior. + /// + /// The provided closure will be called with `&mut self` and its result will be returned. + /// However if the closure returns `Err(_)`, `self` will be reset as if the closure was never called. + #[allow(unused)] + pub fn handle_transaction( + &mut self, + f: impl FnOnce(&mut WasmReader) -> Result, + ) -> Result { + let original = self.clone(); + f(self).inspect_err(|_| { + *self = original; + }) + } } pub trait WasmReadable: Sized { + /// Reads a new [`Self`] from given [`WasmReader`]. + /// + /// Note that if this function returns `Err(_)`, the [`WasmReader`] may still have been advanced, + /// which may lead to unexpected behaviour. + /// To avoid this consider using the [`WasmReader::handle_transaction`] method to wrap this function call. fn read(wasm: &mut WasmReader) -> Result; fn read_unvalidated(wasm: &mut WasmReader) -> Self; } @@ -315,4 +336,25 @@ mod test { assert_eq!(wasm_reader.remaining_bytes(), my_bytes); assert_eq!(wasm_reader.skip(6), Err(Error::Eof)); } + + #[test] + fn reader_transaction() { + let bytes = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]; + let mut reader = WasmReader::new(&bytes); + + assert_eq!( + reader.handle_transaction(|reader| { reader.strip_bytes::<2>() }), + Ok([0x1, 0x2]), + ); + + let transaction_result: Result<()> = reader.handle_transaction(|reader| { + assert_eq!(reader.strip_bytes::<2>(), Ok([0x3, 0x4])); + + // The exact error type does not matter + Err(Error::InvalidMagic) + }); + assert_eq!(transaction_result, Err(Error::InvalidMagic)); + + assert_eq!(reader.strip_bytes::<3>(), Ok([0x3, 0x4, 0x5])); + } }