diff --git a/src/core/error.rs b/src/core/error.rs index 42ba0915..60e8f808 100644 --- a/src/core/error.rs +++ b/src/core/error.rs @@ -11,9 +11,11 @@ pub enum RuntimeError { DivideBy0, UnrepresentableResult, FunctionNotFound, + ModuleNotFound, StackSmash, // https://github.com/wasmi-labs/wasmi/blob/37d1449524a322817c55026eb21eb97dd693b9ce/crates/core/src/trap.rs#L265C5-L265C27 BadConversionToInteger, + UnmetImport, } #[derive(Debug, PartialEq, Eq, Clone)] @@ -127,9 +129,13 @@ impl Display for RuntimeError { match self { RuntimeError::DivideBy0 => f.write_str("Divide by zero is not permitted"), RuntimeError::UnrepresentableResult => f.write_str("Result is unrepresentable"), + RuntimeError::ModuleNotFound => f.write_str("Module not found"), RuntimeError::FunctionNotFound => f.write_str("Function not found"), RuntimeError::StackSmash => f.write_str("Stack smashed"), RuntimeError::BadConversionToInteger => f.write_str("Bad conversion to integer"), + RuntimeError::UnmetImport => { + f.write_str("There is at least one import which has no corresponding export") + } } } } diff --git a/src/execution/execution_info.rs b/src/execution/execution_info.rs new file mode 100644 index 00000000..eec16eaf --- /dev/null +++ b/src/execution/execution_info.rs @@ -0,0 +1,29 @@ +use alloc::string::{String, ToString}; +use alloc::vec::Vec; + +use crate::core::reader::types::FuncType; +use crate::core::reader::WasmReader; +use crate::execution::Store; + +/// ExecutionInfo is a compilation of relevant information needed by the [interpreter loop]( +/// crate::execution::interpreter_loop::run). The lifetime annotation `'r` represents that this structure needs to be +/// valid at least as long as the [RuntimeInstance](crate::execution::RuntimeInstance) that creates it. 
+pub struct ExecutionInfo<'r> { + pub name: String, + pub wasm_bytecode: &'r [u8], + pub wasm_reader: WasmReader<'r>, + pub fn_types: Vec, + pub store: Store, +} + +impl<'r> ExecutionInfo<'r> { + pub fn new(name: &str, wasm_bytecode: &'r [u8], fn_types: Vec, store: Store) -> Self { + ExecutionInfo { + name: name.to_string(), + wasm_bytecode, + wasm_reader: WasmReader::new(wasm_bytecode), + fn_types, + store, + } + } +} diff --git a/src/execution/interpreter_loop.rs b/src/execution/interpreter_loop.rs index 6892b52e..0bd9184d 100644 --- a/src/execution/interpreter_loop.rs +++ b/src/execution/interpreter_loop.rs @@ -16,13 +16,12 @@ use crate::{ assert_validated::UnwrapValidatedExt, core::{ indices::{FuncIdx, GlobalIdx, LocalIdx}, - reader::{ - types::{memarg::MemArg, FuncType}, - WasmReadable, WasmReader, - }, + reader::{types::memarg::MemArg, WasmReadable}, }, + execution::execution_info::ExecutionInfo, + execution::Lut, locals::Locals, - store::Store, + store::FuncInst, value, value_stack::Stack, NumType, RuntimeError, ValType, Value, @@ -33,19 +32,22 @@ use crate::execution::hooks::HookSet; /// Interprets a functions. Parameters and return values are passed on the stack. 
pub(super) fn run( - wasm_bytecode: &[u8], - types: &[FuncType], - store: &mut Store, + modules: &mut [ExecutionInfo], + current_module_idx: &mut usize, + lut: &Lut, stack: &mut Stack, mut hooks: H, ) -> Result<(), RuntimeError> { - let func_inst = store + let func_inst = modules[*current_module_idx] + .store .funcs .get(stack.current_stackframe().func_idx) + .unwrap_validated() + .try_into_local() .unwrap_validated(); // Start reading the function's instructions - let mut wasm = WasmReader::new(wasm_bytecode); + let mut wasm = &mut modules[*current_module_idx].wasm_reader; // unwrap is sound, because the validation assures that the function points to valid subslice of the WASM binary wasm.move_start_to(func_inst.code_expr).unwrap(); @@ -54,13 +56,14 @@ pub(super) fn run( loop { // call the instruction hook #[cfg(feature = "hooks")] - hooks.instruction_hook(wasm_bytecode, wasm.pc); + hooks.instruction_hook(modules[*current_module_idx].wasm_bytecode, wasm.pc); let first_instr_byte = wasm.read_u8().unwrap_validated(); match first_instr_byte { END => { - let maybe_return_address = stack.pop_stackframe(); + let (return_module, maybe_return_address) = stack.pop_stackframe(); + *current_module_idx = return_module; // We finished this entire invocation if there is no stackframe left. 
If there are // one or more stack frames, we need to continue from where the callee was called @@ -70,6 +73,7 @@ pub(super) fn run( } trace!("end of function reached, returning to previous stack frame"); + wasm = &mut modules[return_module].wasm_reader; wasm.pc = maybe_return_address; } RETURN => { @@ -77,8 +81,15 @@ pub(super) fn run( let func_to_call_idx = stack.current_stackframe().func_idx; - let func_to_call_inst = store.funcs.get(func_to_call_idx).unwrap_validated(); - let func_to_call_ty = types.get(func_to_call_inst.ty).unwrap_validated(); + let func_to_call_inst = modules[*current_module_idx] + .store + .funcs + .get(func_to_call_idx) + .unwrap_validated(); + let func_to_call_ty = modules[*current_module_idx] + .fn_types + .get(func_to_call_inst.ty()) + .unwrap_validated(); let ret_vals = stack .pop_tail_iter(func_to_call_ty.returns.valtypes.len()) @@ -94,23 +105,71 @@ pub(super) fn run( } trace!("end of function reached, returning to previous stack frame"); - wasm.pc = stack.pop_stackframe(); + let (return_module, return_pc) = stack.pop_stackframe(); + *current_module_idx = return_module; + wasm = &mut modules[return_module].wasm_reader; + wasm.pc = return_pc; } CALL => { let func_to_call_idx = wasm.read_var_u32().unwrap_validated() as FuncIdx; - let func_to_call_inst = store.funcs.get(func_to_call_idx).unwrap_validated(); - let func_to_call_ty = types.get(func_to_call_inst.ty).unwrap_validated(); + let func_to_call_inst = modules[*current_module_idx] + .store + .funcs + .get(func_to_call_idx) + .unwrap_validated(); + let func_to_call_ty = modules[*current_module_idx] + .fn_types + .get(func_to_call_inst.ty()) + .unwrap_validated(); let params = stack.pop_tail_iter(func_to_call_ty.params.valtypes.len()); - let remaining_locals = func_to_call_inst.locals.iter().cloned(); trace!("Instruction: call [{func_to_call_idx:?}]"); - let locals = Locals::new(params, remaining_locals); - stack.push_stackframe(func_to_call_idx, func_to_call_ty, locals, wasm.pc); - 
wasm.move_start_to(func_to_call_inst.code_expr) - .unwrap_validated(); + match func_to_call_inst { + FuncInst::Local(local_func_inst) => { + let remaining_locals = local_func_inst.locals.iter().cloned(); + let locals = Locals::new(params, remaining_locals); + + stack.push_stackframe( + *current_module_idx, + func_to_call_idx, + func_to_call_ty, + locals, + wasm.pc, + ); + + wasm.move_start_to(local_func_inst.code_expr) + .unwrap_validated(); + } + FuncInst::Imported(_imported_func_inst) => { + let (next_module, next_func_idx) = lut + .lookup(*current_module_idx, func_to_call_idx) + .expect("invalid state for lookup"); + + let local_func_inst = modules[next_module].store.funcs[next_func_idx] + .try_into_local() + .unwrap(); + + let remaining_locals = local_func_inst.locals.iter().cloned(); + let locals = Locals::new(params, remaining_locals); + + stack.push_stackframe( + *current_module_idx, + func_to_call_idx, + func_to_call_ty, + locals, + wasm.pc, + ); + + wasm = &mut modules[next_module].wasm_reader; + *current_module_idx = next_module; + + wasm.move_start_to(local_func_inst.code_expr) + .unwrap_validated(); + } + } } LOCAL_GET => { stack.get_local(wasm.read_var_u32().unwrap_validated() as LocalIdx); @@ -119,21 +178,38 @@ pub(super) fn run( LOCAL_TEE => stack.tee_local(wasm.read_var_u32().unwrap_validated() as LocalIdx), GLOBAL_GET => { let global_idx = wasm.read_var_u32().unwrap_validated() as GlobalIdx; - let global = store.globals.get(global_idx).unwrap_validated(); + let global = modules[*current_module_idx] + .store + .globals + .get(global_idx) + .unwrap_validated(); + + // TODO: imported global stack.push_value(global.value); } GLOBAL_SET => { let global_idx = wasm.read_var_u32().unwrap_validated() as GlobalIdx; - let global = store.globals.get_mut(global_idx).unwrap_validated(); + let global = modules[*current_module_idx] + .store + .globals + .get_mut(global_idx) + .unwrap_validated(); + + // TODO: imported global (?) ... 
can imported globals be set as mutable? global.value = stack.pop_value(global.global.ty.ty) } I32_LOAD => { - let memarg = MemArg::read_unvalidated(&mut wasm); + let memarg = MemArg::read_unvalidated(wasm); let relative_address: u32 = stack.pop_value(ValType::NumType(NumType::I32)).into(); - let mem = store.mems.first().unwrap_validated(); // there is only one memory allowed as of now + // TODO: how does this interact with imports? + let mem = modules[*current_module_idx] + .store + .mems + .first() + .unwrap_validated(); // there is only one memory allowed as of now let data: u32 = { // The spec states that this should be a 33 bit integer @@ -156,10 +232,15 @@ pub(super) fn run( trace!("Instruction: i32.load [{relative_address}] -> [{data}]"); } F32_LOAD => { - let memarg = MemArg::read_unvalidated(&mut wasm); + let memarg = MemArg::read_unvalidated(wasm); let relative_address: u32 = stack.pop_value(ValType::NumType(NumType::I32)).into(); - let mem = store.mems.first().unwrap_validated(); // there is only one memory allowed as of now + // TODO: how does this interact with imports? + let mem = modules[*current_module_idx] + .store + .mems + .first() + .unwrap_validated(); // there is only one memory allowed as of now let data: f32 = { // The spec states that this should be a 33 bit integer @@ -182,12 +263,17 @@ pub(super) fn run( trace!("Instruction: f32.load [{relative_address}] -> [{data}]"); } I32_STORE => { - let memarg = MemArg::read_unvalidated(&mut wasm); + let memarg = MemArg::read_unvalidated(wasm); let data_to_store: u32 = stack.pop_value(ValType::NumType(NumType::I32)).into(); let relative_address: u32 = stack.pop_value(ValType::NumType(NumType::I32)).into(); - let mem = store.mems.get_mut(0).unwrap_validated(); // there is only one memory allowed as of now + // TODO: How does this interact with imports? 
+ let mem = modules[*current_module_idx] + .store + .mems + .get_mut(0) + .unwrap_validated(); // there is only one memory allowed as of now // The spec states that this should be a 33 bit integer // See: https://webassembly.github.io/spec/core/syntax/instructions.html#memory-instructions @@ -203,12 +289,17 @@ pub(super) fn run( trace!("Instruction: i32.store [{relative_address} {data_to_store}] -> []"); } F32_STORE => { - let memarg = MemArg::read_unvalidated(&mut wasm); + let memarg = MemArg::read_unvalidated(wasm); let data_to_store: f32 = stack.pop_value(ValType::NumType(NumType::F32)).into(); let relative_address: u32 = stack.pop_value(ValType::NumType(NumType::I32)).into(); - let mem = store.mems.get_mut(0).unwrap_validated(); // there is only one memory allowed as of now + // TODO: how does this interact with imports? + let mem = modules[*current_module_idx] + .store + .mems + .get_mut(0) + .unwrap_validated(); // there is only one memory allowed as of now // The spec states that this should be a 33 bit integer // See: https://webassembly.github.io/spec/core/syntax/instructions.html#memory-instructions @@ -224,12 +315,17 @@ pub(super) fn run( trace!("Instruction: f32.store [{relative_address} {data_to_store}] -> []"); } F64_STORE => { - let memarg = MemArg::read_unvalidated(&mut wasm); + let memarg = MemArg::read_unvalidated(wasm); let data_to_store: f64 = stack.pop_value(ValType::NumType(NumType::F64)).into(); let relative_address: u32 = stack.pop_value(ValType::NumType(NumType::I32)).into(); - let mem = store.mems.get_mut(0).unwrap_validated(); // there is only one memory allowed as of now + // TODO: how does this interact with imports? 
+ let mem = modules[*current_module_idx] + .store + .mems + .get_mut(0) + .unwrap_validated(); // there is only one memory allowed as of now // The spec states that this should be a 33 bit integer // See: https://webassembly.github.io/spec/core/syntax/instructions.html#memory-instructions diff --git a/src/execution/lut.rs b/src/execution/lut.rs new file mode 100644 index 00000000..3081b7bc --- /dev/null +++ b/src/execution/lut.rs @@ -0,0 +1,113 @@ +use crate::{core::reader::types::export::ExportDesc, execution::execution_info::ExecutionInfo}; +use alloc::{collections::btree_map::BTreeMap, string::String, vec::Vec}; + +pub struct Lut { + /// function_lut\[local_module_idx\]\[function_local_idx\] = (foreign_module_idx, function_foreign_idx) + /// + /// - Module A imports a function "foo". Inside module A, the function has the index "function_local_idx". Module A + /// is assigned the index "local_module_idx". + /// - Module B exports a function "foo". Inside module B, the function has the index "function_foreign_idx". Module + /// B is assigned the index "foreign_module_idx". + function_lut: Vec>, +} + +impl Lut { + /// Create a new linker lookup-table. + /// + /// # Arguments + /// - `modules`: The modules to link together. + /// - `module_map`: A map from module name to module index within the `modules` array. + /// + /// # Returns + /// A new linker lookup-table. Can return `None` if there are import directives that cannot be resolved. + pub fn new(modules: &[ExecutionInfo], module_map: &BTreeMap) -> Option { + let mut function_lut = Vec::new(); + for module in modules { + let module_lut = module + .store + .funcs + .iter() + .filter_map(|f| f.try_into_imported()) + .map(|import| { + Self::manual_lookup( + modules, + module_map, + &import.module_name, + &import.function_name, + ) + }) + .collect::>>()?; + + // TODO: what do we want to do if there is a missing import/export pair? Currently we fail the entire + // operation. 
Should it be a RuntimeError if said missing pair is called? + + function_lut.push(module_lut); + } + + Some(Self { function_lut }) + } + + /// Lookup a function by its module and function index. + /// + /// # Arguments + /// - `module_idx`: The index of the module within the `modules` array passed in [Lut::new]. + /// - `function_idx`: The index of the function within the module. This index is considered in-bounds only if it is + /// an index of an imported function. + /// + /// # Returns + /// - `None`, if the indicies are out of bound + /// - `Some(export_module_idx, export_function_idx)`, where the new indicies are the indicies of the module which + /// contains the implementation of the imported function, and the implementation has the returned index within. + pub fn lookup(&self, module_idx: usize, function_idx: usize) -> Option<(usize, usize)> { + self.function_lut + .get(module_idx)? + .get(function_idx) + .copied() + } + + /// Manually lookup a function by its module and function name. + /// + /// This function is used to resolve import directives before the [Lut] is created, and can be used to resolve + /// imports even after the [Lut] is created at the cost of speed. + /// + /// # Arguments + /// - `modules`: The modules to link together. + /// - `module_map`: A map from module name to module index within the `modules` array. + /// - `module_name`: The name of the module which imports the function. + /// - `function_name`: The name of the function to import. + /// + /// # Returns + /// - `None`, if the module or function is not found. + /// - `Some(export_module_idx, export_function_idx)`, where the new indicies are the indicies of the module which + /// contains the implementation of the imported function, and the implementation has the returned index within. + /// Note that this function returns the first matching function, if there are multiple functions with the same + /// name. 
+ pub fn manual_lookup( + modules: &[ExecutionInfo], + module_map: &BTreeMap, + module_name: &str, + function_name: &str, + ) -> Option<(usize, usize)> { + let module_idx = module_map.get(module_name)?; + let module = &modules[*module_idx]; + + module + .store + .exports + .iter() + .filter_map(|export| { + if export.name == function_name { + Some(&export.desc) + } else { + None + } + }) + .find_map(|desc| { + if let ExportDesc::FuncIdx(func_idx) = desc { + Some((*module_idx, *func_idx)) + } else { + None + } + }) + } +} diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 3480690f..f577c265 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -1,14 +1,19 @@ -use alloc::string::ToString; +use alloc::collections::btree_map::BTreeMap; +use alloc::string::{String, ToString}; use alloc::vec::Vec; use const_interpreter_loop::run_const; +use execution_info::ExecutionInfo; use function_ref::FunctionRef; use interpreter_loop::run; use locals::Locals; +use lut::Lut; +use store::{ImportedFuncInst, LocalFuncInst}; use value_stack::Stack; -use crate::core::reader::types::export::{Export, ExportDesc}; -use crate::core::reader::types::FuncType; +use crate::core::reader::types::export::ExportDesc; +use crate::core::reader::types::import::ImportDesc; +use crate::core::reader::types::ValType; use crate::core::reader::WasmReader; use crate::execution::assert_validated::UnwrapValidatedExt; use crate::execution::hooks::{EmptyHookSet, HookSet}; @@ -16,15 +21,17 @@ use crate::execution::store::{FuncInst, GlobalInst, MemInst, Store}; use crate::execution::value::Value; use crate::validation::code::read_declared_locals; use crate::value::InteropValueList; -use crate::{RuntimeError, ValType, ValidationInfo}; +use crate::{RuntimeError, ValidationInfo}; // TODO pub(crate) mod assert_validated; mod const_interpreter_loop; +pub(crate) mod execution_info; pub mod function_ref; pub mod hooks; mod interpreter_loop; pub(crate) mod locals; +pub(crate) mod lut; pub(crate) mod 
store; pub mod value; pub mod value_stack; @@ -36,10 +43,9 @@ pub struct RuntimeInstance<'b, H = EmptyHookSet> where H: HookSet, { - pub wasm_bytecode: &'b [u8], - types: Vec, - exports: Vec, - store: Store, + pub modules: Vec>, + module_map: BTreeMap, + lut: Option, pub hook_set: H, } @@ -67,16 +73,15 @@ where ) -> Result { trace!("Starting instantiation of bytecode"); - let store = Self::init_store(validation_info); - let mut instance = RuntimeInstance { - wasm_bytecode: validation_info.wasm, - types: validation_info.types.clone(), - exports: validation_info.exports.clone(), - store, + modules: Vec::new(), + module_map: BTreeMap::new(), + lut: None, hook_set, }; + instance.add_module(module_name, validation_info); + // TODO: how do we handle the start function, if we don't have a LUT yet? if let Some(start) = validation_info.start { // "start" is not always exported, so we need create a non-API exposed function reference. // Note: function name is not important here, as it is not used in the verification process. 
@@ -114,8 +119,13 @@ where module_idx: usize, function_idx: usize, ) -> Result { - // TODO: Module resolution - let function_name = self + let module = self + .modules + .get(module_idx) + .ok_or(RuntimeError::ModuleNotFound)?; + + let function_name = module + .store .exports .iter() .find(|export| match &export.desc { @@ -126,8 +136,7 @@ where .ok_or(RuntimeError::FunctionNotFound)?; Ok(FunctionRef { - // TODO: get the module name from the module index - module_name: DEFAULT_MODULE.to_string(), + module_name: module.name.clone(), function_name, module_index: module_idx, function_index: function_idx, @@ -135,14 +144,20 @@ where }) } - // TODO: remove this annotation when implementing the function - #[allow(clippy::result_unit_err)] - pub fn add_module( - &mut self, - _module_name: &str, - _validation_info: &'_ ValidationInfo<'b>, - ) -> Result<(), ()> { - todo!("Implement module linking"); + pub fn add_module(&mut self, module_name: &str, validation_info: &'_ ValidationInfo<'b>) { + let store = Self::init_store(validation_info); + let exec_info = ExecutionInfo::new( + module_name, + validation_info.wasm, + validation_info.types.clone(), + store, + ); + + self.module_map + .insert(module_name.to_string(), self.modules.len()); + self.modules.push(exec_info); + + self.lut = Lut::new(&self.modules, &self.module_map); } pub fn invoke( @@ -151,11 +166,21 @@ where params: Param, ) -> Result { // First, verify that the function reference is valid - let (_module_idx, func_idx) = self.verify_function_ref(function_ref)?; + let (module_idx, func_idx) = self.verify_function_ref(function_ref)?; // -=-= Verification =-=- - let func_inst = self.store.funcs.get(func_idx).expect("valid FuncIdx"); - let func_ty = self.types.get(func_inst.ty).unwrap_validated(); + trace!("{:?}", self.modules[module_idx].store.funcs); + let func_inst = self.modules[module_idx] + .store + .funcs + .get(func_idx) + .ok_or(RuntimeError::FunctionNotFound)? 
+ .try_into_local() + .ok_or(RuntimeError::FunctionNotFound)?; + let func_ty = self.modules[module_idx] + .fn_types + .get(func_inst.ty) + .unwrap_validated(); // Check correct function parameters and return types if func_ty.params.valtypes != Param::TYS { @@ -174,13 +199,15 @@ where // setting `usize::MAX` as return address for the outermost function ensures that we // observably fail upon errornoeusly continuing execution after that function returns. - stack.push_stackframe(func_idx, func_ty, locals, usize::MAX); + stack.push_stackframe(module_idx, func_idx, func_ty, locals, usize::MAX); + + let mut current_module_idx = module_idx; // Run the interpreter run( - self.wasm_bytecode, - &self.types, - &mut self.store, + &mut self.modules, + &mut current_module_idx, + self.lut.as_ref().ok_or(RuntimeError::UnmetImport)?, &mut stack, EmptyHookSet, )?; @@ -206,11 +233,21 @@ where ret_types: &[ValType], ) -> Result, RuntimeError> { // First, verify that the function reference is valid - let (_module_idx, func_idx) = self.verify_function_ref(function_ref)?; + let (module_idx, func_idx) = self.verify_function_ref(function_ref)?; // -=-= Verification =-=- - let func_inst = self.store.funcs.get(func_idx).expect("valid FuncIdx"); - let func_ty = self.types.get(func_inst.ty).unwrap_validated(); + let func_inst = self.modules[module_idx] + .store + .funcs + .get(func_idx) + .ok_or(RuntimeError::FunctionNotFound)? 
+ .try_into_local() + .ok_or(RuntimeError::FunctionNotFound)?; + let func_ty = self.modules[module_idx] + .fn_types + .get(func_inst.ty) + .unwrap_validated() + .clone(); // Verify that the given parameters match the function parameters let param_types = params.iter().map(|v| v.to_ty()).collect::>(); @@ -227,20 +264,19 @@ where // Prepare a new stack with the locals for the entry function let mut stack = Stack::new(); let locals = Locals::new(params.into_iter(), func_inst.locals.iter().cloned()); - stack.push_stackframe(func_idx, func_ty, locals, 0); + stack.push_stackframe(module_idx, func_idx, &func_ty, locals, 0); + + let mut current_module_idx = module_idx; // Run the interpreter run( - self.wasm_bytecode, - &self.types, - &mut self.store, + &mut self.modules, + &mut current_module_idx, + self.lut.as_ref().ok_or(RuntimeError::UnmetImport)?, &mut stack, EmptyHookSet, )?; - let func_inst = self.store.funcs.get(func_idx).expect("valid FuncIdx"); - let func_ty = self.types.get(func_inst.ty).unwrap_validated(); - // Pop return values from stack let return_values = func_ty .returns @@ -256,13 +292,30 @@ where Ok(ret) } - // TODO: replace this with the lookup table when implmenting the linker + /// Get the indicies of a module and function by their names. + /// + /// # Arguments + /// - `module_name`: The module in which to find the function. + /// - `function_name`: The name of the function to find inside the module. The function must be a local function and + /// not an import. + /// + /// # Returns + /// - `Ok((module_idx, func_idx))`, where `module_idx` is the internal index of the module inside the + /// [RuntimeInstance], and `func_idx` is the internal index of the function inside the module. + /// - `Err(RuntimeError::ModuleNotFound)`, if the module is not found. + /// - `Err(RuntimeError::FunctionNotFound`, if the function is not found within the module. 
fn get_indicies( &self, - _module_name: &str, + module_name: &str, function_name: &str, ) -> Result<(usize, usize), RuntimeError> { - let func_idx = self + let module_idx = *self + .module_map + .get(module_name) + .ok_or(RuntimeError::ModuleNotFound)?; + + let func_idx = self.modules[module_idx] + .store .exports .iter() .find_map(|export| { @@ -277,9 +330,25 @@ where }) .ok_or(RuntimeError::FunctionNotFound)?; - Ok((0, func_idx)) + Ok((module_idx, func_idx)) } + /// Verify that the function reference is still valid. A function reference may be invalid if it created from + /// another [RuntimeInstance] or the modules inside the instance have been changed in a way that the indicies inside + /// the [FunctionRef] would be invalid. + /// + /// Note: this function ensures that making an unchecked indexation will not cause a panic. + /// + /// # Returns + /// - `Ok((function_ref.module_idx, function_ref.func_idx))` + /// - `Err(RuntimeError::FunctionNotFound)`, or `Err(RuntimeError::ModuleNotFound)` if the function is not valid. + /// + /// # Implementation details + /// For an exported function (i.e. created by the same [RuntimeInstance]), the names are re-resolved using + /// [RuntimeInstance::get_indicies], and the indicies are compared with the indicies in the [FunctionRef]. + /// + /// For a [FunctionRef] with the [export](FunctionRef::exported) flag set to `false`, the indicies are checked to be + /// in-bounds, and that the module name matches the module name in the [FunctionRef]. The function name is ignored. fn verify_function_ref( &self, function_ref: &FunctionRef, @@ -288,8 +357,10 @@ where let (module_idx, func_idx) = self.get_indicies(&function_ref.module_name, &function_ref.function_name)?; - if module_idx != function_ref.module_index || func_idx != function_ref.function_index { - // TODO: should we return a different error here? 
+ if module_idx != function_ref.module_index { + return Err(RuntimeError::FunctionNotFound); + } + if func_idx != function_ref.function_index { return Err(RuntimeError::FunctionNotFound); } @@ -297,11 +368,19 @@ where } else { let (module_idx, func_idx) = (function_ref.module_index, function_ref.function_index); - // TODO: verify module named - index mapping. + let module = self + .modules + .get(module_idx) + .ok_or(RuntimeError::ModuleNotFound)?; + + if module.name != function_ref.module_name { + return Err(RuntimeError::ModuleNotFound); + } // Sanity check that the function index is at least in the bounds of the store, though this doesn't mean // that it's a valid function. - self.store + module + .store .funcs .get(func_idx) .ok_or(RuntimeError::FunctionNotFound)?; @@ -317,28 +396,40 @@ where let functions = validation_info.functions.iter(); let func_blocks = validation_info.func_blocks.iter(); - functions - .zip(func_blocks) - .map(|(ty, func)| { - wasm_reader - .move_start_to(*func) - .expect("function index to be in the bounds of the WASM binary"); - - let (locals, bytes_read) = wasm_reader - .measure_num_read_bytes(read_declared_locals) - .unwrap_validated(); - - let code_expr = wasm_reader - .make_span(func.len() - bytes_read) - .expect("TODO remove this expect"); - - FuncInst { - ty: *ty, - locals, - code_expr, - } + let local_function_inst = functions.zip(func_blocks).map(|(ty, func)| { + wasm_reader + .move_start_to(*func) + .expect("function index to be in the bounds of the WASM binary"); + + let (locals, bytes_read) = wasm_reader + .measure_num_read_bytes(read_declared_locals) + .unwrap_validated(); + + let code_expr = wasm_reader + .make_span(func.len() - bytes_read) + .expect("TODO remove this expect"); + + FuncInst::Local(LocalFuncInst { + ty: *ty, + locals, + code_expr, }) - .collect() + }); + + let imported_function_inst = + validation_info + .imports + .iter() + .filter_map(|import| match &import.desc { + ImportDesc::Func(type_idx) => 
Some(FuncInst::Imported(ImportedFuncInst { + ty: *type_idx, + module_name: import.module_name.clone(), + function_name: import.name.clone(), + })), + _ => None, + }); + + imported_function_inst.chain(local_function_inst).collect() }; let memory_instances: Vec = validation_info @@ -369,10 +460,12 @@ where }) .collect(); + let exports = validation_info.exports.clone(); Store { funcs: function_instances, mems: memory_instances, globals: global_instances, + exports, } } } diff --git a/src/execution/store.rs b/src/execution/store.rs index 6d897b8f..3e7fb782 100644 --- a/src/execution/store.rs +++ b/src/execution/store.rs @@ -1,9 +1,11 @@ +use alloc::string::String; use alloc::vec; use alloc::vec::Vec; use core::iter; use crate::core::indices::TypeIdx; use crate::core::reader::span::Span; +use crate::core::reader::types::export::Export; use crate::core::reader::types::global::Global; use crate::core::reader::types::{MemType, TableType, ValType}; use crate::execution::value::{Ref, Value}; @@ -13,25 +15,68 @@ use crate::execution::value::{Ref, Value}; /// globals, element segments, and data segments that have been allocated during the life time of /// the abstract machine. 
/// +#[derive(Debug)] pub struct Store { pub funcs: Vec, // tables: Vec, pub mems: Vec, pub globals: Vec, + pub exports: Vec, } -pub struct FuncInst { +#[derive(Debug)] +pub enum FuncInst { + Local(LocalFuncInst), + Imported(ImportedFuncInst), +} + +impl FuncInst { + pub fn ty(&self) -> TypeIdx { + match self { + FuncInst::Local(f) => f.ty, + FuncInst::Imported(f) => f.ty, + } + } + + pub fn try_into_local(&self) -> Option<&LocalFuncInst> { + match self { + FuncInst::Local(f) => Some(f), + FuncInst::Imported(_) => None, + } + } + + pub fn try_into_imported(&self) -> Option<&ImportedFuncInst> { + match self { + FuncInst::Local(_) => None, + FuncInst::Imported(f) => Some(f), + } + } +} + +#[derive(Debug)] +pub struct LocalFuncInst { pub ty: TypeIdx, pub locals: Vec, pub code_expr: Span, } +#[derive(Debug)] +pub struct ImportedFuncInst { + pub ty: TypeIdx, + #[allow(dead_code)] + pub module_name: String, + #[allow(dead_code)] + pub function_name: String, +} + #[allow(dead_code)] +#[derive(Debug)] pub struct TableInst { pub ty: TableType, pub elem: Vec, } +#[derive(Debug)] pub struct MemInst { #[allow(warnings)] pub ty: MemType, @@ -61,6 +106,7 @@ impl MemInst { } } +#[derive(Debug)] pub struct GlobalInst { pub global: Global, /// Must be of the same type as specified in `ty` diff --git a/src/execution/value_stack.rs b/src/execution/value_stack.rs index 27cf4ddf..902a0243 100644 --- a/src/execution/value_stack.rs +++ b/src/execution/value_stack.rs @@ -105,9 +105,10 @@ impl Stack { self.frames.last_mut().unwrap_validated() } - /// Pop a [`CallFrame`] from the call stack, returning the return address - pub fn pop_stackframe(&mut self) -> usize { + /// Pop a [`CallFrame`] from the call stack, returning the module id and return address + pub fn pop_stackframe(&mut self) -> (usize, usize) { let CallFrame { + module_idx, return_addr, value_stack_base_idx, return_value_count, @@ -123,7 +124,7 @@ impl Stack { "after a function call finished, the stack must have exactly as many 
values as it had before calling the function plus the number of function return values" ); - return_addr + (module_idx, return_addr) } /// Push a stackframe to the call stack @@ -131,12 +132,14 @@ impl Stack { /// Takes the current [`Self::values`]'s length as [`CallFrame::value_stack_base_idx`]. pub fn push_stackframe( &mut self, + module_idx: usize, func_idx: FuncIdx, func_ty: &FuncType, locals: Locals, return_addr: usize, ) { self.frames.push(CallFrame { + module_idx, func_idx, locals, return_addr, @@ -169,6 +172,9 @@ impl Stack { /// The [WASM spec](https://webassembly.github.io/spec/core/exec/runtime.html#stack) calls this `Activations`, however it refers to the call frames of functions. pub(crate) struct CallFrame { + /// Index of the module given to [`Stack::push_stackframe`]; [`Stack::pop_stackframe`] hands it back together with the return address + pub module_idx: usize, + /// Index to the function of this [`CallFrame`] pub func_idx: FuncIdx, diff --git a/src/validation/code.rs b/src/validation/code.rs index 52400d73..6543f73b 100644 --- a/src/validation/code.rs +++ b/src/validation/code.rs @@ -1,7 +1,7 @@ use alloc::vec::Vec; use core::iter; -use crate::core::indices::{FuncIdx, GlobalIdx, LocalIdx}; +use crate::core::indices::{FuncIdx, GlobalIdx, LocalIdx, TypeIdx}; use crate::core::reader::section_header::{SectionHeader, SectionTy}; use crate::core::reader::span::Span; use crate::core::reader::types::global::Global; @@ -11,17 +11,34 @@ use crate::core::reader::{WasmReadable, WasmReader}; use crate::validation_stack::ValidationStack; use crate::{Error, Result}; +/// Validates the code section of a module. +/// +/// # Arguments +/// - `wasm`: The reader over the whole wasm binary. It is expected to be at the beginning of the code section, and +/// after execution it will be at the beginning of the next section if the result is `Ok(...)`. +/// - `section_header`: The header of the code section. +/// - `fn_types`: The types of all functions in the module, including imported functions. +/// - `type_idx_of_fn`: The index of the type of each function in `fn_types`, including imported functions. 
As per the +/// specification, the indices of the type of imported functions come first. +/// - `num_imported_funcs`: The number of imported functions. This is used as an offset, to determine the first index of +/// a local function in `type_idx_of_fn`. +/// - `globals`: The global variables of the module. +/// +/// # Returns +/// Presumably the spans of the validated function bodies collected in `code_block_spans`, in code-section order (TODO confirm against the full function body). pub fn validate_code_section( wasm: &mut WasmReader, section_header: SectionHeader, fn_types: &[FuncType], - type_idx_of_fn: &[usize], + type_idx_of_fn: &[TypeIdx], + num_imported_funcs: usize, globals: &[Global], ) -> Result> { assert_eq!(section_header.ty, SectionTy::Code); let code_block_spans = wasm.read_vec_enumerated(|wasm, idx| { - let ty_idx = type_idx_of_fn[idx]; + // We need to offset the index by the number of functions that were imported + let ty_idx = type_idx_of_fn[idx + num_imported_funcs]; let func_ty = fn_types[ty_idx].clone(); debug!("{:x?}", wasm.full_wasm_binary); @@ -39,7 +56,7 @@ pub fn validate_code_section( let mut stack = ValidationStack::new(); read_instructions( - idx, + idx + num_imported_funcs, wasm, &mut stack, &locals, diff --git a/src/validation/mod.rs b/src/validation/mod.rs index d6a52fa9..fc8b40bc 100644 --- a/src/validation/mod.rs +++ b/src/validation/mod.rs @@ -5,7 +5,7 @@ use crate::core::reader::section_header::{SectionHeader, SectionTy}; use crate::core::reader::span::Span; use crate::core::reader::types::export::Export; use crate::core::reader::types::global::Global; -use crate::core::reader::types::import::Import; +use crate::core::reader::types::import::{Import, ImportDesc}; use crate::core::reader::types::{FuncType, MemType, TableType}; use crate::core::reader::{WasmReadable, WasmReader}; use crate::{Error, Result}; @@ -73,10 +73,30 @@ pub fn validate(wasm: &[u8]) -> Result { while (skip_section(&mut wasm, &mut header)?).is_some() {} - let functions = handle_section(&mut wasm, &mut header, SectionTy::Function, |wasm, _| { - wasm.read_vec(|wasm| wasm.read_var_u32().map(|u| u as usize)) - })? 
- .unwrap_or_default(); + // The `Function` section only covers module-level (or "local") functions. Imported functions have their types known + // in the `import` section. Both local and imported functions share the same index space. + // + // Imported functions are given priority and have the first indices, and only after that do the local functions get + // assigned their indices. + let local_functions = + handle_section(&mut wasm, &mut header, SectionTy::Function, |wasm, _| { + wasm.read_vec(|wasm| wasm.read_var_u32().map(|u| u as usize)) + })? + .unwrap_or_default(); + + let imported_functions = imports + .iter() + .filter_map(|import| match &import.desc { + ImportDesc::Func(type_idx) => Some(*type_idx), + _ => None, + }) + .collect::>(); + + let all_functions = imported_functions + .iter() + .chain(local_functions.iter()) + .cloned() + .collect::>(); while (skip_section(&mut wasm, &mut header)?).is_some() {} @@ -130,11 +150,22 @@ pub fn validate(wasm: &[u8]) -> Result { while (skip_section(&mut wasm, &mut header)?).is_some() {} let func_blocks = handle_section(&mut wasm, &mut header, SectionTy::Code, |wasm, h| { - code::validate_code_section(wasm, h, &types, &all_functions, &globals) + code::validate_code_section( + wasm, + h, + &types, + &all_functions, + imported_functions.len(), + &globals, + ) })? 
.unwrap_or_default(); - assert_eq!(func_blocks.len(), functions.len(), "these should be equal"); // TODO check if this is in the spec + assert_eq!( + func_blocks.len(), + local_functions.len(), + "these should be equal" + ); // TODO check if this is in the spec while (skip_section(&mut wasm, &mut header)?).is_some() {} @@ -154,7 +185,7 @@ pub fn validate(wasm: &[u8]) -> Result { wasm: wasm.into_inner(), types, imports, - functions, + functions: local_functions, tables, memories, globals, @@ -189,3 +220,12 @@ fn handle_section Result>( _ => Ok(None), } } + +impl ValidationInfo<'_> { + pub fn get_imported_funcs(&self) -> impl Iterator { + self.imports.iter().filter_map(|import| match &import.desc { + ImportDesc::Func(type_idx) => Some(type_idx), + _ => None, + }) + } +} diff --git a/tests/imports.rs b/tests/imports.rs new file mode 100644 index 00000000..68a3105f --- /dev/null +++ b/tests/imports.rs @@ -0,0 +1,82 @@ +use wasm::{validate, RuntimeError, RuntimeInstance, DEFAULT_MODULE}; + +const UNMET_IMPORTS: &str = r#" +(module + (import "env" "dummy1" (func (param i32))) + (import "env" "dummy2" (func (param i32))) + (func (export "get_three") (param) (result i32) + i32.const 1 + i32.const 2 + i32.add + ) +)"#; + +const SIMPLE_IMPORT_BASE: &str = r#" +(module + (import "env" "get_one" (func $get_one (param) (result i32))) + (func (export "get_three") (param) (result i32) + call $get_one + i32.const 2 + i32.add + ) +)"#; + +const SIMPLE_IMPORT_ADDON: &str = r#" +(module + (func (export "get_one") (param) (result i32) + i32.const 1 + ) +)"#; + +#[test_log::test] +pub fn unmet_imports() { + let wasm_bytes = wat::parse_str(UNMET_IMPORTS).unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let get_three = instance + .get_function_by_name(DEFAULT_MODULE, "get_three") + .unwrap(); + + assert_eq!( + RuntimeError::UnmetImport, + instance + 
.invoke::<(), i32>(&get_three, ()) + .expect_err("Expected invoke to fail due to unmet imports") + ); +} + +#[test_log::test] +pub fn compile_simple_import() { + let wasm_bytes = wat::parse_str(SIMPLE_IMPORT_BASE).unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = + RuntimeInstance::new_named("base", &validation_info).expect("instantiation failed"); + + let wasm_bytes = wat::parse_str(SIMPLE_IMPORT_ADDON).unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + instance.add_module("addon", &validation_info); + + // assert_eq!((), instance.invoke_named("print_three", ()).unwrap()); + // Function 0 should be the imported function + // assert_eq!((), instance.invoke_func(1, ()).unwrap()); +} + +#[test_log::test] +pub fn run_simple_import() { + let wasm_bytes = wat::parse_str(SIMPLE_IMPORT_BASE).unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = + RuntimeInstance::new_named("base", &validation_info).expect("instantiation failed"); + + let wasm_bytes = wat::parse_str(SIMPLE_IMPORT_ADDON).unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + instance.add_module("env", &validation_info); + + let get_three = instance.get_function_by_name("base", "get_three").unwrap(); + assert_eq!(3, instance.invoke(&get_three, ()).unwrap()); + + // Function 0 should be the imported function + let get_three = instance.get_function_by_index(0, 1).unwrap(); + assert_eq!(3, instance.invoke(&get_three, ()).unwrap()); +}