From dff4ddfe073940d0b9136ee808a48579a9c40f03 Mon Sep 17 00:00:00 2001 From: Cem Onem Date: Fri, 6 Dec 2024 15:47:01 +0100 Subject: [PATCH] feat: br_if, return validation/execution, throw correct errors --- src/execution/interpreter_loop.rs | 83 +++++------- src/execution/value_stack.rs | 1 + src/validation/code.rs | 177 ++++++++++++++++--------- src/validation/validation_stack.rs | 17 ++- tests/structured_control_flow/block.rs | 2 +- 5 files changed, 166 insertions(+), 114 deletions(-) diff --git a/src/execution/interpreter_loop.rs b/src/execution/interpreter_loop.rs index a55f74ae..3eae9c59 100644 --- a/src/execution/interpreter_loop.rs +++ b/src/execution/interpreter_loop.rs @@ -102,62 +102,28 @@ pub(super) fn run( .unwrap_validated() .sidetable; } - BR => { - //skip n of BR n + BR_IF => { wasm.read_var_u32().unwrap_validated(); - let sidetable_entry = ¤t_sidetable[stp]; - - // TODO fix this corner cutting implementation - let jump_vals = stack - .pop_tail_iter(sidetable_entry.valcnt) - .collect::>(); - stack.pop_n_values(sidetable_entry.popcnt); + let c: i32 = stack.pop_value(ValType::NumType(NumType::I32)).into(); - for val in jump_vals { - stack.push_value(val); + if c != 0 { + do_sidetable_control_transfer(&mut wasm, stack, &mut stp, current_sidetable); + } else { + stp += 1; } - - // TODO ugly - stp = (stp as isize + sidetable_entry.delta_stp) - .try_into() - .unwrap_validated(); - wasm.pc = (wasm.pc as isize + sidetable_entry.delta_pc) - .try_into() - .unwrap_validated(); + } + BR => { + //skip n of BR n + wasm.read_var_u32().unwrap_validated(); + do_sidetable_control_transfer(&mut wasm, stack, &mut stp, current_sidetable); } BLOCK => { BlockType::read_unvalidated(&mut wasm); } RETURN => { - trace!("returning from function"); - - let func_to_call_idx = stack.current_stackframe().func_idx; - - let func_to_call_inst = store.funcs.get(func_to_call_idx).unwrap_validated(); - let func_to_call_ty = types.get(func_to_call_inst.ty).unwrap_validated(); - - let ret_vals = stack - .pop_tail_iter(func_to_call_ty.returns.valtypes.len()) - .collect::>(); - stack.clear_callframe_values(); - - for val in ret_vals { - stack.push_value(val); - } - - if stack.callframe_count() == 1 { - break; - } - - trace!("end of function reached, returning to previous stack frame"); - (wasm.pc, stp) = stack.pop_stackframe(); - - current_sidetable = &store - .funcs - .get(stack.current_stackframe().func_idx) - .unwrap_validated() - .sidetable; + //same as BR, except no need to skip n of BR n + do_sidetable_control_transfer(&mut wasm, stack, &mut stp, current_sidetable); } CALL => { let func_to_call_idx = wasm.read_var_u32().unwrap_validated() as FuncIdx; @@ -2232,3 +2198,26 @@ pub(super) fn run( } Ok(()) } + +//helper function for avoiding code duplication at intraprocedural jumps +fn do_sidetable_control_transfer( + wasm: &mut WasmReader, + stack: &mut Stack, + current_stp: &mut usize, + current_sidetable: &Sidetable, +) { + let sidetable_entry = ¤t_sidetable[*current_stp]; + + // TODO fix this corner cutting implementation + let jump_vals = stack + .pop_tail_iter(sidetable_entry.valcnt) + .collect::>(); + stack.pop_n_values(sidetable_entry.popcnt); + + for val in jump_vals { + stack.push_value(val); + } + + *current_stp = (*current_stp as isize + sidetable_entry.delta_stp) as usize; + wasm.pc = (wasm.pc as isize + sidetable_entry.delta_pc) as usize; +} diff --git a/src/execution/value_stack.rs b/src/execution/value_stack.rs index 3d0c5127..056a9e72 100644 --- a/src/execution/value_stack.rs +++ b/src/execution/value_stack.rs @@ -186,6 +186,7 @@ impl Stack { } /// Clear all of the values pushed to the value stack by the current stack frame + #[allow(unused)] // TODO remove this once sidetable implementation lands pub fn clear_callframe_values(&mut self) { self.values .truncate(self.current_stackframe().value_stack_base_idx); diff --git a/src/validation/code.rs b/src/validation/code.rs index 5be4f1ca..1146b1a9 100644 --- a/src/validation/code.rs +++ b/src/validation/code.rs @@ -42,7 +42,6 @@ pub fn validate_code_section( let mut sidetable: Sidetable = Sidetable::default(); read_instructions( - idx, wasm, &mut stack, &mut sidetable, @@ -89,9 +88,107 @@ pub fn read_declared_locals(wasm: &mut WasmReader) -> Result> { Ok(locals) } +//helper function to avoid code duplication in jump validations +//the entries, except for the loop label, need to be correctly backpatched later +//the temporary values of fields (delta_pc, delta_stp) of the entries are the (ip, stp) of the relevant label +//the label is also updated with the additional information of the index of this sidetable +//entry itself so that the entry can be backpatched when the end instruction of the label +//is hit. +fn generate_unbackpatched_sidetable_entry( + wasm: &WasmReader, + sidetable: &mut Sidetable, + valcnt: usize, + popcnt: usize, + label_info: &mut LabelInfo, +) { + let stp_here = sidetable.len(); + + sidetable.push(SidetableEntry { + delta_pc: wasm.pc as isize, + delta_stp: stp_here as isize, + popcnt, + valcnt, + }); + + match label_info { + LabelInfo::Block { stps_to_backpatch } => stps_to_backpatch.push(stp_here), + LabelInfo::Loop { ip, stp } => { + //we already know where to jump to for loops + sidetable[stp_here].delta_pc = *ip as isize - wasm.pc as isize; + sidetable[stp_here].delta_stp = *stp as isize - stp_here as isize; + } + LabelInfo::If { + stps_to_backpatch, .. + } => stps_to_backpatch.push(stp_here), + LabelInfo::Func { stps_to_backpatch } => stps_to_backpatch.push(stp_here), + LabelInfo::Untyped => { + unreachable!("this label is for untyped wasm sequences") + } + } +} + +//helper function to avoid code duplication for common stuff in br, return +fn validate_unconditional_jump_and_generate_sidetable_entry( + wasm: &WasmReader, + label_idx: usize, + stack: &mut ValidationStack, + sidetable: &mut Sidetable, +) -> Result<()> { + let ctrl_stack_len = stack.ctrl_stack.len(); + + stack.assert_val_types_of_label_jump_types_on_top(label_idx)?; + + let targeted_ctrl_block_entry = stack + .ctrl_stack + .get(ctrl_stack_len - label_idx - 1) + .ok_or(Error::InvalidLabelIdx(label_idx))?; + + let valcnt = targeted_ctrl_block_entry.label_types().len(); + let popcnt = stack.len() - targeted_ctrl_block_entry.height - valcnt; + + let label_info = &mut stack + .ctrl_stack + .get_mut(ctrl_stack_len - label_idx - 1) + .unwrap() + .label_info; + + generate_unbackpatched_sidetable_entry(wasm, sidetable, valcnt, popcnt, label_info); + + stack.make_unspecified() +} + +//helper function to avoid code duplication for common stuff in if, else, br_if +fn validate_conditional_jump_and_generate_sidetable_entry( + wasm: &WasmReader, + label_idx: usize, + stack: &mut ValidationStack, + sidetable: &mut Sidetable, +) -> Result<()> { + let ctrl_stack_len = stack.ctrl_stack.len(); + + stack.assert_val_types_of_label_jump_types(label_idx)?; + + let targeted_ctrl_block_entry = stack + .ctrl_stack + .get(ctrl_stack_len - label_idx - 1) + .ok_or(Error::InvalidLabelIdx(label_idx))?; + + let valcnt = targeted_ctrl_block_entry.label_types().len(); + let popcnt = 0; //otherwise the above assert would fail. + + let label_info = &mut stack + .ctrl_stack + .get_mut(ctrl_stack_len - label_idx - 1) + .unwrap() + .label_info; + + generate_unbackpatched_sidetable_entry(wasm, sidetable, valcnt, popcnt, label_info); + + Ok(()) +} + #[allow(clippy::too_many_arguments)] fn read_instructions( - this_function_idx: usize, wasm: &mut WasmReader, stack: &mut ValidationStack, sidetable: &mut Sidetable, @@ -129,47 +226,16 @@ fn read_instructions( } BR => { let label_idx = wasm.read_var_u32()? as LabelIdx; - let ctrl_stack_len = stack.ctrl_stack.len(); - - stack.assert_val_types_of_label_jump_types_on_top(label_idx)?; - - let targeted_ctrl_block_entry = stack - .ctrl_stack - .get(ctrl_stack_len - label_idx - 1) - .ok_or(Error::InvalidLabelIdx(label_idx))?; - - let valcnt = targeted_ctrl_block_entry.label_types().len(); - let popcnt = stack.len() - targeted_ctrl_block_entry.height - valcnt; - - let targeted_ctrl_block_entry = stack - .ctrl_stack - .get_mut(ctrl_stack_len - label_idx - 1) - .unwrap(); - - let stp_here = sidetable.len(); - - sidetable.push(SidetableEntry { - delta_pc: wasm.pc as isize, - delta_stp: stp_here as isize, - popcnt, - valcnt, - }); - - match &mut targeted_ctrl_block_entry.label_info { - LabelInfo::Block { stps_to_backpatch } => stps_to_backpatch.push(stp_here), - LabelInfo::Loop { .. } => { - todo!("implement loop") - } - LabelInfo::If { .. } => { - todo!("implement if") - } - LabelInfo::Func { stps_to_backpatch } => stps_to_backpatch.push(stp_here), - LabelInfo::Untyped => { - unreachable!("this label is for untyped wasm sequences") - } - } - - stack.make_unspecified()?; + validate_unconditional_jump_and_generate_sidetable_entry( + wasm, label_idx, stack, sidetable, + )?; + } + BR_IF => { + let label_idx = wasm.read_var_u32()? as LabelIdx; + stack.assert_pop_val_type(ValType::NumType(NumType::I32))?; + validate_conditional_jump_and_generate_sidetable_entry( + wasm, label_idx, stack, sidetable, + )?; } // end END => { @@ -213,27 +279,10 @@ fn read_instructions( } } RETURN => { - let this_func_ty = &fn_types[type_idx_of_fn[this_function_idx]]; - - stack - .assert_val_types_on_top(&this_func_ty.returns.valtypes) - .map_err(|_| Error::EndInvalidValueStack)?; - - stack.make_unspecified()?; - - // TODO(george-cosma): a `return Ok(());` should probably be introduced here, but since we don't have - // controls flows implemented, the only way to test `return` is to place it at the end of function. - // However, an `end` is introduced after it, which is invalid. Compilation for this test case should - // probably fail. - - // TODO(wucke13) I believe we must not drain the validation stack here; only if we - // know this return is actually taken during execution we may drain the stack. This - // could however be a conditional return (return in an `if`), and the other side - // past the `else` might need the values on the `ValidationStack` that do belong - // to the current function (but not the current block), so draining would make - // continued validation of the current function impossible. We should most - // definitely not `return Ok(())` here, because there might be still more of the - // current function to validate. + let label_idx = stack.ctrl_stack.len() - 1; // return behaves the same as br + validate_unconditional_jump_and_generate_sidetable_entry( + wasm, label_idx, stack, sidetable, + )?; } // call [t1*] -> [t2*] CALL => { diff --git a/src/validation/validation_stack.rs b/src/validation/validation_stack.rs index 5469fe08..f1029b1b 100644 --- a/src/validation/validation_stack.rs +++ b/src/validation/validation_stack.rs @@ -158,14 +158,14 @@ impl ValidationStack { match actual_ty { ValidationStackEntry::Val(actual_val_ty) => { if *actual_val_ty != *expected_ty { - return Err(Error::InvalidValidationStackValType(Some(*actual_val_ty))); + return Err(Error::EndInvalidValueStack); } } ValidationStackEntry::NumOrVecType => match expected_ty { // unify the NumOrVecType to expected_ty ValType::NumType(_) => *actual_ty = ValidationStackEntry::Val(*expected_ty), ValType::VecType => *actual_ty = ValidationStackEntry::Val(*expected_ty), - _ => return Err(Error::InvalidValidationStackValType(None)), + _ => return Err(Error::EndInvalidValueStack), }, ValidationStackEntry::UnspecifiedValTypes => { unreachable!("bottom type should not exist in the stack") @@ -248,6 +248,19 @@ impl ValidationStack { ) } + pub fn assert_val_types_of_label_jump_types(&mut self, label_idx: usize) -> Result<()> { + let label_types = self + .ctrl_stack + .get(self.ctrl_stack.len() - label_idx - 1) + .ok_or(Error::InvalidLabelIdx(label_idx))? + .label_types(); + ValidationStack::assert_val_types_with_custom_stacks( + &mut self.stack, + &self.ctrl_stack, + label_types, + ) + } + // TODO is moving block_ty ok? pub fn assert_push_ctrl(&mut self, label_info: LabelInfo, block_ty: FuncType) -> Result<()> { self.assert_val_types_on_top(&block_ty.params.valtypes)?; diff --git a/tests/structured_control_flow/block.rs b/tests/structured_control_flow/block.rs index bb9be985..24257cc1 100644 --- a/tests/structured_control_flow/block.rs +++ b/tests/structured_control_flow/block.rs @@ -119,7 +119,7 @@ fn param_and_result() { let wasm_bytes = wat::parse_str( r#" (module - (func (export "add_one") (param $x i32) (result) + (func (export "add_one") (param $x i32) (result i32) local.get $x (block $my_block (param i32) (result i32) i32.const 1