Skip to content

Commit

Permalink
chore: split FuncInst into local and imported variant. In preparation…
Browse files Browse the repository at this point in the history
… for linker

Signed-off-by: George Cosma <[email protected]>
  • Loading branch information
george-cosma committed Sep 26, 2024
1 parent 09c3435 commit ee05bf3
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 43 deletions.
12 changes: 10 additions & 2 deletions src/execution/interpreter_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ pub(super) fn run<H: HookSet>(
let func_inst = store
.funcs
.get(stack.current_stackframe().func_idx)
.unwrap_validated()
.try_into_local()
.unwrap_validated();

// Start reading the function's instructions
Expand Down Expand Up @@ -78,7 +80,7 @@ pub(super) fn run<H: HookSet>(
let func_to_call_idx = stack.current_stackframe().func_idx;

let func_to_call_inst = store.funcs.get(func_to_call_idx).unwrap_validated();
let func_to_call_ty = types.get(func_to_call_inst.ty).unwrap_validated();
let func_to_call_ty = types.get(func_to_call_inst.ty()).unwrap_validated();

let ret_vals = stack
.pop_tail_iter(func_to_call_ty.returns.valtypes.len())
Expand All @@ -99,7 +101,13 @@ pub(super) fn run<H: HookSet>(
CALL => {
let func_to_call_idx = wasm.read_var_u32().unwrap_validated() as FuncIdx;

let func_to_call_inst = store.funcs.get(func_to_call_idx).unwrap_validated();
// TODO: if it is imported, defer to linking
let func_to_call_inst = store
.funcs
.get(func_to_call_idx)
.unwrap_validated()
.try_into_local()
.expect("TODO: call imported functions");
let func_to_call_ty = types.get(func_to_call_inst.ty).unwrap_validated();

let params = stack.pop_tail_iter(func_to_call_ty.params.valtypes.len());
Expand Down
89 changes: 61 additions & 28 deletions src/execution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ use const_interpreter_loop::run_const;
use function_ref::FunctionRef;
use interpreter_loop::run;
use locals::Locals;
use store::{ImportedFuncInst, LocalFuncInst};
use value_stack::Stack;

use crate::core::reader::types::export::{Export, ExportDesc};
use crate::core::reader::types::import::ImportDesc;
use crate::core::reader::types::FuncType;
use crate::core::reader::types::{FuncType, ValType};
use crate::core::reader::WasmReader;
use crate::execution::assert_validated::UnwrapValidatedExt;
use crate::execution::hooks::{EmptyHookSet, HookSet};
Expand Down Expand Up @@ -154,15 +157,30 @@ where
let (_module_idx, func_idx) = self.verify_function_ref(function_ref)?;

// -=-= Verification =-=-
let func_inst = self.store.funcs.get(func_idx).expect("valid FuncIdx");
trace!("{:?}", self.store.funcs);
let func_inst = self
.store
.funcs
.get(func_idx)
.ok_or(RuntimeError::FunctionNotFound)?
.try_into_local()
.ok_or(RuntimeError::FunctionNotFound)?;
let func_ty = self.types.get(func_inst.ty).unwrap_validated();

// Check correct function parameters and return types
if func_ty.params.valtypes != Param::TYS {
panic!("Invalid `Param` generics");
panic!(
"Invalid `Param` generics. Expected: {:?}, Found: {:?}",
func_ty.params.valtypes,
Param::TYS
);
}
if func_ty.returns.valtypes != Returns::TYS {
panic!("Invalid `Returns` generics");
panic!(
"Invalid `Returns` generics. Expected: {:?}, Found: {:?}",
func_ty.returns.valtypes,
Returns::TYS
);
}

// Prepare a new stack with the locals for the entry function
Expand Down Expand Up @@ -209,7 +227,13 @@ where
let (_module_idx, func_idx) = self.verify_function_ref(function_ref)?;

// -=-= Verification =-=-
let func_inst = self.store.funcs.get(func_idx).expect("valid FuncIdx");
let func_inst = self
.store
.funcs
.get(func_idx)
.ok_or(RuntimeError::FunctionNotFound)?
.try_into_local()
.ok_or(RuntimeError::FunctionNotFound)?;
let func_ty = self.types.get(func_inst.ty).unwrap_validated();

// Verify that the given parameters match the function parameters
Expand Down Expand Up @@ -238,9 +262,6 @@ where
EmptyHookSet,
)?;

let func_inst = self.store.funcs.get(func_idx).expect("valid FuncIdx");
let func_ty = self.types.get(func_inst.ty).unwrap_validated();

// Pop return values from stack
let return_values = func_ty
.returns
Expand Down Expand Up @@ -317,28 +338,40 @@ where
let functions = validation_info.functions.iter();
let func_blocks = validation_info.func_blocks.iter();

functions
.zip(func_blocks)
.map(|(ty, func)| {
wasm_reader
.move_start_to(*func)
.expect("function index to be in the bounds of the WASM binary");

let (locals, bytes_read) = wasm_reader
.measure_num_read_bytes(read_declared_locals)
.unwrap_validated();

let code_expr = wasm_reader
.make_span(func.len() - bytes_read)
.expect("TODO remove this expect");

FuncInst {
ty: *ty,
locals,
code_expr,
}
let local_function_inst = functions.zip(func_blocks).map(|(ty, func)| {
wasm_reader
.move_start_to(*func)
.expect("function index to be in the bounds of the WASM binary");

let (locals, bytes_read) = wasm_reader
.measure_num_read_bytes(read_declared_locals)
.unwrap_validated();

let code_expr = wasm_reader
.make_span(func.len() - bytes_read)
.expect("TODO remove this expect");

FuncInst::Local(LocalFuncInst {
ty: *ty,
locals,
code_expr,
})
.collect()
});

let imported_function_inst =
validation_info
.imports
.iter()
.filter_map(|import| match &import.desc {
ImportDesc::Func(type_idx) => Some(FuncInst::Imported(ImportedFuncInst {
ty: *type_idx,
module_name: import.module_name.clone(),
function_name: import.name.clone(),
})),
_ => None,
});

imported_function_inst.chain(local_function_inst).collect()
};

let memory_instances: Vec<MemInst> = validation_info
Expand Down
39 changes: 38 additions & 1 deletion src/execution/store.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;
use core::iter;
Expand All @@ -13,25 +14,60 @@ use crate::execution::value::{Ref, Value};
/// globals, element segments, and data segments that have been allocated during the life time of
/// the abstract machine.
/// <https://webassembly.github.io/spec/core/exec/runtime.html#store>
#[derive(Debug)]
pub struct Store {
pub funcs: Vec<FuncInst>,
// tables: Vec<TableInst>,
pub mems: Vec<MemInst>,
pub globals: Vec<GlobalInst>,
}

pub struct FuncInst {
#[derive(Debug)]
pub enum FuncInst {
Local(LocalFuncInst),
Imported(ImportedFuncInst),
}

impl FuncInst {
pub fn ty(&self) -> TypeIdx {
match self {
FuncInst::Local(f) => f.ty,
FuncInst::Imported(f) => f.ty,
}
}

pub fn try_into_local(&self) -> Option<&LocalFuncInst> {
match self {
FuncInst::Local(f) => Some(f),
FuncInst::Imported(_) => None,
}
}
}

#[derive(Debug)]
pub struct LocalFuncInst {
pub ty: TypeIdx,
pub locals: Vec<ValType>,
pub code_expr: Span,
}

#[derive(Debug)]
pub struct ImportedFuncInst {
pub ty: TypeIdx,
#[allow(dead_code)]
pub module_name: String,
#[allow(dead_code)]
pub function_name: String,
}

#[allow(dead_code)]
#[derive(Debug)]
pub struct TableInst {
pub ty: TableType,
pub elem: Vec<Ref>,
}

#[derive(Debug)]
pub struct MemInst {
#[allow(warnings)]
pub ty: MemType,
Expand Down Expand Up @@ -61,6 +97,7 @@ impl MemInst {
}
}

#[derive(Debug)]
pub struct GlobalInst {
pub global: Global,
/// Must be of the same type as specified in `ty`
Expand Down
25 changes: 21 additions & 4 deletions src/validation/code.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use alloc::vec::Vec;
use core::iter;

use crate::core::indices::{FuncIdx, GlobalIdx, LocalIdx};
use crate::core::indices::{FuncIdx, GlobalIdx, LocalIdx, TypeIdx};
use crate::core::reader::section_header::{SectionHeader, SectionTy};
use crate::core::reader::span::Span;
use crate::core::reader::types::global::Global;
Expand All @@ -11,17 +11,34 @@ use crate::core::reader::{WasmReadable, WasmReader};
use crate::validation_stack::ValidationStack;
use crate::{Error, Result};

/// <todo! summary>
///
/// # Arguments
/// - `wasm`: The reader over the whole wasm binary. It is expected to be at the beginning of the code section, and
/// after execution it will be at the beginning of the next section if the result is `Ok(...)`.
/// - `section_header`: The header of the code section.
/// - `fn_types`: The types of all functions in the module, including imported functions.
/// - `type_idx_of_fn`: The index of the type of each function in `fn_types`, including imported functions. As per the
/// specification, the indicies of the type of imported functions come first.
/// - `num_imported_funcs`: The number of imported functions. This is used as an offset, to determine the first index of
/// a local function in `type_idx_of_fn`.
/// - `globals`: The global variables of the module.
///
/// # Returns
/// <todo! returns>
pub fn validate_code_section(
wasm: &mut WasmReader,
section_header: SectionHeader,
fn_types: &[FuncType],
type_idx_of_fn: &[usize],
type_idx_of_fn: &[TypeIdx],
num_imported_funcs: usize,
globals: &[Global],
) -> Result<Vec<Span>> {
assert_eq!(section_header.ty, SectionTy::Code);

let code_block_spans = wasm.read_vec_enumerated(|wasm, idx| {
let ty_idx = type_idx_of_fn[idx];
// We need to offset the index by the number of functions that were imported
let ty_idx = type_idx_of_fn[idx + num_imported_funcs];
let func_ty = fn_types[ty_idx].clone();

debug!("{:x?}", wasm.full_wasm_binary);
Expand All @@ -39,7 +56,7 @@ pub fn validate_code_section(
let mut stack = ValidationStack::new();

read_instructions(
idx,
idx + num_imported_funcs,
wasm,
&mut stack,
&locals,
Expand Down
56 changes: 48 additions & 8 deletions src/validation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::core::reader::section_header::{SectionHeader, SectionTy};
use crate::core::reader::span::Span;
use crate::core::reader::types::export::Export;
use crate::core::reader::types::global::Global;
use crate::core::reader::types::import::Import;
use crate::core::reader::types::import::{Import, ImportDesc};
use crate::core::reader::types::{FuncType, MemType, TableType};
use crate::core::reader::{WasmReadable, WasmReader};
use crate::{Error, Result};
Expand Down Expand Up @@ -73,10 +73,30 @@ pub fn validate(wasm: &[u8]) -> Result<ValidationInfo> {

while (skip_section(&mut wasm, &mut header)?).is_some() {}

let functions = handle_section(&mut wasm, &mut header, SectionTy::Function, |wasm, _| {
wasm.read_vec(|wasm| wasm.read_var_u32().map(|u| u as usize))
})?
.unwrap_or_default();
// The `Function` section only covers module-level (or "local") functions. Imported functions have their types known
// in the `import` section. Both local and imported functions share the same index space.
//
// Imported functions are given priority and have the first indicies, and only after that do the local functions get
// assigned their indices.
let local_functions =
handle_section(&mut wasm, &mut header, SectionTy::Function, |wasm, _| {
wasm.read_vec(|wasm| wasm.read_var_u32().map(|u| u as usize))
})?
.unwrap_or_default();

let imported_functions = imports
.iter()
.filter_map(|import| match &import.desc {
ImportDesc::Func(type_idx) => Some(*type_idx),
_ => None,
})
.collect::<Vec<_>>();

let all_functions = imported_functions
.iter()
.chain(local_functions.iter())
.cloned()
.collect::<Vec<TypeIdx>>();

while (skip_section(&mut wasm, &mut header)?).is_some() {}

Expand Down Expand Up @@ -130,11 +150,22 @@ pub fn validate(wasm: &[u8]) -> Result<ValidationInfo> {
while (skip_section(&mut wasm, &mut header)?).is_some() {}

let func_blocks = handle_section(&mut wasm, &mut header, SectionTy::Code, |wasm, h| {
code::validate_code_section(wasm, h, &types, &functions, &globals)
code::validate_code_section(
wasm,
h,
&types,
&all_functions,
imported_functions.len(),
&globals,
)
})?
.unwrap_or_default();

assert_eq!(func_blocks.len(), functions.len(), "these should be equal"); // TODO check if this is in the spec
assert_eq!(
func_blocks.len(),
local_functions.len(),
"these should be equal"
); // TODO check if this is in the spec

while (skip_section(&mut wasm, &mut header)?).is_some() {}

Expand All @@ -154,7 +185,7 @@ pub fn validate(wasm: &[u8]) -> Result<ValidationInfo> {
wasm: wasm.into_inner(),
types,
imports,
functions,
functions: local_functions,
tables,
memories,
globals,
Expand Down Expand Up @@ -189,3 +220,12 @@ fn handle_section<T, F: FnOnce(&mut WasmReader, SectionHeader) -> Result<T>>(
_ => Ok(None),
}
}

impl ValidationInfo<'_> {
pub fn get_imported_funcs(&self) -> impl Iterator<Item = &TypeIdx> {
self.imports.iter().filter_map(|import| match &import.desc {
ImportDesc::Func(type_idx) => Some(type_idx),
_ => None,
})
}
}
Loading

0 comments on commit ee05bf3

Please sign in to comment.