diff --git a/tests/compiler/venom/test_duplicate_operands.py b/tests/compiler/venom/test_duplicate_operands.py new file mode 100644 index 0000000000..505f01e31b --- /dev/null +++ b/tests/compiler/venom/test_duplicate_operands.py @@ -0,0 +1,28 @@ +from vyper.compiler.settings import OptimizationLevel +from vyper.venom import generate_assembly_experimental +from vyper.venom.basicblock import IRLiteral +from vyper.venom.function import IRFunction + + +def test_duplicate_operands(): + """ + Test the duplicate operands code generation. + The venom code: + + %1 = 10 + %2 = add %1, %1 + %3 = mul %1, %2 + stop + + Should compile to: [PUSH1, 10, DUP1, DUP1, DUP1, ADD, MUL, STOP] + """ + ctx = IRFunction() + + op = ctx.append_instruction("store", [IRLiteral(10)]) + sum = ctx.append_instruction("add", [op, op]) + ctx.append_instruction("mul", [sum, op]) + ctx.append_instruction("stop", [], False) + + asm = generate_assembly_experimental(ctx, OptimizationLevel.CODESIZE) + + assert asm == ["PUSH1", 10, "DUP1", "DUP1", "DUP1", "ADD", "MUL", "STOP", "REVERT"] diff --git a/tests/compiler/venom/test_multi_entry_block.py b/tests/compiler/venom/test_multi_entry_block.py new file mode 100644 index 0000000000..bb57fa1065 --- /dev/null +++ b/tests/compiler/venom/test_multi_entry_block.py @@ -0,0 +1,96 @@ +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRLiteral +from vyper.venom.function import IRBasicBlock, IRFunction, IRLabel +from vyper.venom.passes.normalization import NormalizationPass + + +def test_multi_entry_block_1(): + ctx = IRFunction() + + finish_label = IRLabel("finish") + target_label = IRLabel("target") + block_1_label = IRLabel("block_1", ctx) + + op = ctx.append_instruction("store", [IRLiteral(10)]) + acc = ctx.append_instruction("add", [op, op]) + ctx.append_instruction("jnz", [acc, finish_label, block_1_label], False) + + block_1 = IRBasicBlock(block_1_label, ctx) + ctx.append_basic_block(block_1) + acc = ctx.append_instruction("add", [acc, op]) + op = ctx.append_instruction("store", [IRLiteral(10)]) + ctx.append_instruction("mstore", [acc, op], False) + ctx.append_instruction("jnz", [acc, finish_label, target_label], False) + + target_bb = IRBasicBlock(target_label, ctx) + ctx.append_basic_block(target_bb) + ctx.append_instruction("mul", [acc, acc]) + ctx.append_instruction("jmp", [finish_label], False) + + finish_bb = IRBasicBlock(finish_label, ctx) + ctx.append_basic_block(finish_bb) + ctx.append_instruction("stop", [], False) + + calculate_cfg(ctx) + assert not ctx.normalized, "CFG should not be normalized" + + NormalizationPass.run_pass(ctx) + + assert ctx.normalized, "CFG should be normalized" + + finish_bb = ctx.get_basic_block(finish_label.value) + cfg_in = list(finish_bb.cfg_in.keys()) + assert cfg_in[0].label.value == "target", "Should contain target" + assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + + +# more complicated one +def test_multi_entry_block_2(): + ctx = IRFunction() + + finish_label = IRLabel("finish") + target_label = IRLabel("target") + block_1_label = IRLabel("block_1", ctx) + block_2_label = IRLabel("block_2", ctx) + + op = ctx.append_instruction("store", [IRLiteral(10)]) + acc = ctx.append_instruction("add", [op, op]) + ctx.append_instruction("jnz", [acc, finish_label, block_1_label], False) + + block_1 = IRBasicBlock(block_1_label, ctx) + ctx.append_basic_block(block_1) + acc = ctx.append_instruction("add", [acc, op]) + op = ctx.append_instruction("store", [IRLiteral(10)]) + ctx.append_instruction("mstore", [acc, op], False) + ctx.append_instruction("jnz", [acc, target_label, finish_label], False) + + block_2 = IRBasicBlock(block_2_label, ctx) + ctx.append_basic_block(block_2) + acc = ctx.append_instruction("add", [acc, op]) + op = ctx.append_instruction("store", [IRLiteral(10)]) + ctx.append_instruction("mstore", [acc, op], False) + # switch the order of the labels, for fun + ctx.append_instruction("jnz", [acc, finish_label, target_label], False) + + target_bb = IRBasicBlock(target_label, ctx) + ctx.append_basic_block(target_bb) + ctx.append_instruction("mul", [acc, acc]) + ctx.append_instruction("jmp", [finish_label], False) + + finish_bb = IRBasicBlock(finish_label, ctx) + ctx.append_basic_block(finish_bb) + ctx.append_instruction("stop", [], False) + + calculate_cfg(ctx) + assert not ctx.normalized, "CFG should not be normalized" + + NormalizationPass.run_pass(ctx) + + assert ctx.normalized, "CFG should be normalized" + + finish_bb = ctx.get_basic_block(finish_label.value) + cfg_in = list(finish_bb.cfg_in.keys()) + assert cfg_in[0].label.value == "target", "Should contain target" + assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" diff --git a/tests/compiler/venom/test_stack_at_external_return.py b/tests/compiler/venom/test_stack_at_external_return.py new file mode 100644 index 0000000000..be9fa66e9a --- /dev/null +++ b/tests/compiler/venom/test_stack_at_external_return.py @@ -0,0 +1,5 @@ +def test_stack_at_external_return(): + """ + TODO: USE BOA DO GENERATE THIS TEST + """ + pass diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 82eba63f32..ca1792384e 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -141,6 +141,11 @@ def _parse_args(argv): "-p", help="Set the root path for contract imports", default=".", dest="root_folder" ) parser.add_argument("-o", help="Set the output path", dest="output_path") + parser.add_argument( + "--experimental-codegen", + help="The compiler use the new IR codegen. This is an experimental feature.", + action="store_true", + ) args = parser.parse_args(argv) @@ -188,6 +193,7 @@ def _parse_args(argv): settings, args.storage_layout, args.no_bytecode_metadata, + args.experimental_codegen, ) if args.output_path: @@ -225,6 +231,7 @@ def compile_files( settings: Optional[Settings] = None, storage_layout_paths: list[str] = None, no_bytecode_metadata: bool = False, + experimental_codegen: bool = False, ) -> dict: root_path = Path(root_folder).resolve() if not root_path.exists(): @@ -275,6 +282,7 @@ def compile_files( storage_layout_override=storage_layout_override, show_gas_estimates=show_gas_estimates, no_bytecode_metadata=no_bytecode_metadata, + experimental_codegen=experimental_codegen, ) ret[file_path] = output diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 1d24b6c6dd..c48f1256c3 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -162,5 +162,9 @@ def generate_ir_for_function( # (note: internal functions do not need to adjust gas estimate since mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore + ret.common_ir.passthrough_metadata["func_t"] = func_t # type: ignore + ret.common_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + else: + ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore return ret diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index 228191e3ca..cf01dbdab4 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -68,4 +68,6 @@ def generate_ir_for_internal_function( ["seq"] + nonreentrant_post + [["exit_to", "return_pc"]], ] - return IRnode.from_list(["seq", body, cleanup_routine]) + ir_node = IRnode.from_list(["seq", body, cleanup_routine]) + ir_node.passthrough_metadata["func_t"] = func_t + return ir_node diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index e17ef47c8f..ce26066968 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -171,6 +171,10 @@ class IRnode: valency: int args: List["IRnode"] value: Union[str, int] + is_self_call: bool + passthrough_metadata: dict[str, Any] + func_ir: Any + common_ir: Any def __init__( self, @@ -184,6 +188,8 @@ def __init__( mutable: bool = True, add_gas_estimate: int = 0, encoding: Encoding = Encoding.VYPER, + is_self_call: bool = False, + passthrough_metadata: dict[str, Any] = None, ): if args is None: args = [] @@ -201,6 +207,10 @@ def __init__( self.add_gas_estimate = add_gas_estimate self.encoding = encoding self.as_hex = AS_HEX_DEFAULT + self.is_self_call = is_self_call + self.passthrough_metadata = passthrough_metadata or {} + self.func_ir = None + self.common_ir = None assert self.value is not None, "None is not allowed as IRnode value" @@ -585,6 +595,8 @@ def from_list( error_msg: Optional[str] = None, mutable: bool = True, add_gas_estimate: int = 0, + is_self_call: bool = False, + passthrough_metadata: dict[str, Any] = None, encoding: Encoding = Encoding.VYPER, ) -> "IRnode": if isinstance(typ, str): @@ -617,6 +629,8 @@ def from_list( source_pos=source_pos, encoding=encoding, error_msg=error_msg, + is_self_call=is_self_call, + passthrough_metadata=passthrough_metadata, ) else: return cls( @@ -630,4 +644,6 @@ def from_list( add_gas_estimate=add_gas_estimate, encoding=encoding, error_msg=error_msg, + is_self_call=is_self_call, + passthrough_metadata=passthrough_metadata, ) diff --git a/vyper/codegen/return_.py b/vyper/codegen/return_.py index 56bea2b8da..41fa11ab56 100644 --- a/vyper/codegen/return_.py +++ b/vyper/codegen/return_.py @@ -40,7 +40,9 @@ def finalize(fill_return_buffer): cleanup_loops = "cleanup_repeat" if context.forvars else "seq" # NOTE: because stack analysis is incomplete, cleanup_repeat must # come after fill_return_buffer otherwise the stack will break - return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit]) + jump_to_exit_ir = IRnode.from_list(jump_to_exit) + jump_to_exit_ir.passthrough_metadata["func_t"] = func_t + return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit_ir]) if context.return_type is None: if context.is_internal: diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index c320e6889c..f03f2eb9c8 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -121,4 +121,6 @@ def ir_for_self_call(stmt_expr, context): add_gas_estimate=func_t._ir_info.gas_estimate, ) o.is_self_call = True + o.passthrough_metadata["func_t"] = func_t + o.passthrough_metadata["args_ir"] = args_ir return o diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 62ea05b243..61d7a7c229 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -55,6 +55,7 @@ def compile_code( no_bytecode_metadata: bool = False, show_gas_estimates: bool = False, exc_handler: Optional[Callable] = None, + experimental_codegen: bool = False, ) -> dict: """ Generate consumable compiler output(s) from a single contract source code. @@ -104,6 +105,7 @@ def compile_code( storage_layout_override, show_gas_estimates, no_bytecode_metadata, + experimental_codegen, ) ret = {} diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index bfbb336d54..4e32812fee 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -16,6 +16,7 @@ from vyper.semantics import set_data_positions, validate_semantics from vyper.semantics.types.function import ContractFunctionT from vyper.typing import StorageLayout +from vyper.venom import generate_assembly_experimental, generate_ir DEFAULT_CONTRACT_NAME = PurePath("VyperContract.vy") @@ -60,6 +61,7 @@ def __init__( storage_layout: StorageLayout = None, show_gas_estimates: bool = False, no_bytecode_metadata: bool = False, + experimental_codegen: bool = False, ) -> None: """ Initialization method. @@ -78,14 +80,18 @@ def __init__( Show gas estimates for abi and ir output modes no_bytecode_metadata: bool, optional Do not add metadata to bytecode. Defaults to False + experimental_codegen: bool, optional + Use experimental codegen. Defaults to False """ + # to force experimental codegen, uncomment: + # experimental_codegen = True self.contract_path = contract_path self.source_code = source_code self.source_id = source_id self.storage_layout_override = storage_layout self.show_gas_estimates = show_gas_estimates self.no_bytecode_metadata = no_bytecode_metadata - + self.experimental_codegen = experimental_codegen self.settings = settings or Settings() self.input_bundle = input_bundle or FilesystemInputBundle([Path(".")]) @@ -160,7 +166,11 @@ def global_ctx(self) -> GlobalContext: @cached_property def _ir_output(self): # fetch both deployment and runtime IR - return generate_ir_nodes(self.global_ctx, self.settings.optimize) + nodes = generate_ir_nodes(self.global_ctx, self.settings.optimize) + if self.experimental_codegen: + return [generate_ir(nodes[0]), generate_ir(nodes[1])] + else: + return nodes @property def ir_nodes(self) -> IRnode: @@ -183,11 +193,21 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: @cached_property def assembly(self) -> list: - return generate_assembly(self.ir_nodes, self.settings.optimize) + if self.experimental_codegen: + return generate_assembly_experimental( + self.ir_nodes, self.settings.optimize # type: ignore + ) + else: + return generate_assembly(self.ir_nodes, self.settings.optimize) @cached_property def assembly_runtime(self) -> list: - return generate_assembly(self.ir_runtime, self.settings.optimize) + if self.experimental_codegen: + return generate_assembly_experimental( + self.ir_runtime, self.settings.optimize # type: ignore + ) + else: + return generate_assembly(self.ir_runtime, self.settings.optimize) @cached_property def bytecode(self) -> bytes: diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 1c4dc1ef7c..1d3df8becb 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -9,6 +9,7 @@ from vyper.compiler.settings import OptimizationLevel from vyper.evm.opcodes import get_opcodes, version_check from vyper.exceptions import CodegenPanic, CompilerPanic +from vyper.ir.optimizer import COMMUTATIVE_OPS from vyper.utils import MemoryPositions from vyper.version import version_tuple @@ -164,7 +165,7 @@ def _add_postambles(asm_ops): # insert the postambles *before* runtime code # so the data section of the runtime code can't bork the postambles. runtime = None - if isinstance(asm_ops[-1], list) and isinstance(asm_ops[-1][0], _RuntimeHeader): + if isinstance(asm_ops[-1], list) and isinstance(asm_ops[-1][0], RuntimeHeader): runtime = asm_ops.pop() # for some reason there might not be a STOP at the end of asm_ops. @@ -229,7 +230,7 @@ def compile_to_assembly(code, optimize=OptimizationLevel.GAS): _relocate_segments(res) if optimize != OptimizationLevel.NONE: - _optimize_assembly(res) + optimize_assembly(res) return res @@ -531,7 +532,7 @@ def _height_of(witharg): # since the asm data structures are very primitive, to make sure # assembly_to_evm is able to calculate data offsets correctly, # we pass the memsize via magic opcodes to the subcode - subcode = [_RuntimeHeader(runtime_begin, memsize, immutables_len)] + subcode + subcode = [RuntimeHeader(runtime_begin, memsize, immutables_len)] + subcode # append the runtime code after the ctor code # `append(...)` call here is intentional. @@ -675,7 +676,7 @@ def _height_of(witharg): ) elif code.value == "data": - data_node = [_DataHeader("_sym_" + code.args[0].value)] + data_node = [DataHeader("_sym_" + code.args[0].value)] for c in code.args[1:]: if isinstance(c.value, int): @@ -837,6 +838,31 @@ def _prune_inefficient_jumps(assembly): return changed +def _optimize_inefficient_jumps(assembly): + # optimize sequences `_sym_common JUMPI _sym_x JUMP _sym_common JUMPDEST` + # to `ISZERO _sym_x JUMPI _sym_common JUMPDEST` + changed = False + i = 0 + while i < len(assembly) - 6: + if ( + is_symbol(assembly[i]) + and assembly[i + 1] == "JUMPI" + and is_symbol(assembly[i + 2]) + and assembly[i + 3] == "JUMP" + and assembly[i] == assembly[i + 4] + and assembly[i + 5] == "JUMPDEST" + ): + changed = True + assembly[i] = "ISZERO" + assembly[i + 1] = assembly[i + 2] + assembly[i + 2] = "JUMPI" + del assembly[i + 3 : i + 4] + else: + i += 1 + + return changed + + def _merge_jumpdests(assembly): # When we have multiple JUMPDESTs in a row, or when a JUMPDEST # is immediately followed by another JUMP, we can skip the @@ -938,7 +964,7 @@ def _prune_unused_jumpdests(assembly): used_jumpdests.add(assembly[i]) for item in assembly: - if isinstance(item, list) and isinstance(item[0], _DataHeader): + if isinstance(item, list) and isinstance(item[0], DataHeader): # add symbols used in data sections as they are likely # used for a jumptable. for t in item: @@ -961,6 +987,12 @@ def _stack_peephole_opts(assembly): changed = False i = 0 while i < len(assembly) - 2: + if assembly[i : i + 3] == ["DUP1", "SWAP2", "SWAP1"]: + changed = True + del assembly[i + 2] + assembly[i] = "SWAP1" + assembly[i + 1] = "DUP2" + continue # usually generated by with statements that return their input like # (with x (...x)) if assembly[i : i + 3] == ["DUP1", "SWAP1", "POP"]: @@ -975,16 +1007,22 @@ def _stack_peephole_opts(assembly): changed = True del assembly[i] continue + if assembly[i : i + 2] == ["SWAP1", "SWAP1"]: + changed = True + del assembly[i : i + 2] + if assembly[i] == "SWAP1" and assembly[i + 1].lower() in COMMUTATIVE_OPS: + changed = True + del assembly[i] i += 1 return changed # optimize assembly, in place -def _optimize_assembly(assembly): +def optimize_assembly(assembly): for x in assembly: - if isinstance(x, list) and isinstance(x[0], _RuntimeHeader): - _optimize_assembly(x) + if isinstance(x, list) and isinstance(x[0], RuntimeHeader): + optimize_assembly(x) for _ in range(1024): changed = False @@ -993,6 +1031,7 @@ def _optimize_assembly(assembly): changed |= _merge_iszero(assembly) changed |= _merge_jumpdests(assembly) changed |= _prune_inefficient_jumps(assembly) + changed |= _optimize_inefficient_jumps(assembly) changed |= _prune_unused_jumpdests(assembly) changed |= _stack_peephole_opts(assembly) @@ -1021,7 +1060,7 @@ def adjust_pc_maps(pc_maps, ofst): def _data_to_evm(assembly, symbol_map): ret = bytearray() - assert isinstance(assembly[0], _DataHeader) + assert isinstance(assembly[0], DataHeader) for item in assembly[1:]: if is_symbol(item): symbol = symbol_map[item].to_bytes(SYMBOL_SIZE, "big") @@ -1039,7 +1078,7 @@ def _data_to_evm(assembly, symbol_map): # predict what length of an assembly [data] node will be in bytecode def _length_of_data(assembly): ret = 0 - assert isinstance(assembly[0], _DataHeader) + assert isinstance(assembly[0], DataHeader) for item in assembly[1:]: if is_symbol(item): ret += SYMBOL_SIZE @@ -1055,7 +1094,7 @@ def _length_of_data(assembly): @dataclass -class _RuntimeHeader: +class RuntimeHeader: label: str ctor_mem_size: int immutables_len: int @@ -1065,7 +1104,7 @@ def __repr__(self): @dataclass -class _DataHeader: +class DataHeader: label: str def __repr__(self): @@ -1081,11 +1120,11 @@ def _relocate_segments(assembly): code_segments = [] for t in assembly: if isinstance(t, list): - if isinstance(t[0], _DataHeader): + if isinstance(t[0], DataHeader): data_segments.append(t) else: _relocate_segments(t) # recurse - assert isinstance(t[0], _RuntimeHeader) + assert isinstance(t[0], RuntimeHeader) code_segments.append(t) else: non_data_segments.append(t) @@ -1134,7 +1173,7 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat mem_ofst_size, ctor_mem_size = None, None max_mem_ofst = 0 for i, item in enumerate(assembly): - if isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + if isinstance(item, list) and isinstance(item[0], RuntimeHeader): assert runtime_code is None, "Multiple subcodes" assert ctor_mem_size is None @@ -1184,6 +1223,7 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat if is_symbol_map_indicator(assembly[i + 1]): # Don't increment pc as the symbol itself doesn't go into code if item in symbol_map: + print(assembly) raise CompilerPanic(f"duplicate jumpdest {item}") symbol_map[item] = pc @@ -1198,7 +1238,7 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat # [_OFST, _sym_foo, bar] -> PUSH2 (foo+bar) # [_OFST, _mem_foo, bar] -> PUSHN (foo+bar) pc -= 1 - elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + elif isinstance(item, list) and isinstance(item[0], RuntimeHeader): # we are in initcode symbol_map[item[0].label] = pc # add source map for all items in the runtime map @@ -1209,10 +1249,10 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat pc += len(runtime_code) # grab lengths of data sections from the runtime for t in item: - if isinstance(t, list) and isinstance(t[0], _DataHeader): + if isinstance(t, list) and isinstance(t[0], DataHeader): data_section_lengths.append(_length_of_data(t)) - elif isinstance(item, list) and isinstance(item[0], _DataHeader): + elif isinstance(item, list) and isinstance(item[0], DataHeader): symbol_map[item[0].label] = pc pc += _length_of_data(item) else: @@ -1285,9 +1325,9 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat ret.append(DUP_OFFSET + int(item[3:])) elif item[:4] == "SWAP": ret.append(SWAP_OFFSET + int(item[4:])) - elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + elif isinstance(item, list) and isinstance(item[0], RuntimeHeader): ret.extend(runtime_code) - elif isinstance(item, list) and isinstance(item[0], _DataHeader): + elif isinstance(item, list) and isinstance(item[0], DataHeader): ret.extend(_data_to_evm(item, symbol_map)) else: # pragma: no cover # unreachable diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index 8df4bbac2d..79e02f041d 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -440,6 +440,8 @@ def _optimize(node: IRnode, parent: Optional[IRnode]) -> Tuple[bool, IRnode]: error_msg = node.error_msg annotation = node.annotation add_gas_estimate = node.add_gas_estimate + is_self_call = node.is_self_call + passthrough_metadata = node.passthrough_metadata changed = False @@ -462,6 +464,8 @@ def finalize(val, args): error_msg=error_msg, annotation=annotation, add_gas_estimate=add_gas_estimate, + is_self_call=is_self_call, + passthrough_metadata=passthrough_metadata, ) if should_check_symbols: diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 77b9efb13d..140f73f095 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -93,7 +93,7 @@ def __init__( self.nonreentrant = nonreentrant # a list of internal functions this function calls - self.called_functions = OrderedSet() + self.called_functions = OrderedSet[ContractFunctionT]() # to be populated during codegen self._ir_info: Any = None diff --git a/vyper/utils.py b/vyper/utils.py index 3d9d9cb416..0a2e1f831f 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -6,12 +6,14 @@ import time import traceback import warnings -from typing import List, Union +from typing import Generic, List, TypeVar, Union from vyper.exceptions import DecimalOverrideException, InvalidLiteral +_T = TypeVar("_T") -class OrderedSet(dict): + +class OrderedSet(Generic[_T], dict[_T, None]): """ a minimal "ordered set" class. this is needed in some places because, while dict guarantees you can recover insertion order @@ -20,9 +22,41 @@ class OrderedSet(dict): functionality as needed. """ - def add(self, item): + def __init__(self, iterable=None): + super().__init__() + if iterable is not None: + for item in iterable: + self.add(item) + + def __repr__(self): + keys = ", ".join(repr(k) for k in self.keys()) + return f"{{{keys}}}" + + def get(self, *args, **kwargs): + raise RuntimeError("can't call get() on OrderedSet!") + + def add(self, item: _T) -> None: self[item] = None + def remove(self, item: _T) -> None: + del self[item] + + def difference(self, other): + ret = self.copy() + for k in other.keys(): + if k in ret: + ret.remove(k) + return ret + + def union(self, other): + return self | other + + def __or__(self, other): + return self.__class__(super().__or__(other)) + + def copy(self): + return self.__class__(super().copy()) + class DecimalContextOverride(decimal.Context): def __setattr__(self, name, value): @@ -436,3 +470,25 @@ def annotate_source_code( cleanup_lines += [""] * (num_lines - len(cleanup_lines)) return "\n".join(cleanup_lines) + + +def ir_pass(func): + """ + Decorator for IR passes. This decorator will run the pass repeatedly until + no more changes are made. + """ + + def wrapper(*args, **kwargs): + count = 0 + + while True: + changes = func(*args, **kwargs) or 0 + if isinstance(changes, list) or isinstance(changes, set): + changes = len(changes) + count += changes + if changes == 0: + break + + return count + + return wrapper diff --git a/vyper/venom/README.md b/vyper/venom/README.md new file mode 100644 index 0000000000..a81f6c0582 --- /dev/null +++ b/vyper/venom/README.md @@ -0,0 +1,162 @@ +## Venom - An Intermediate representation language for Vyper + +### Introduction + +Venom serves as the next-gen intermediate representation language specifically tailored for use with the Vyper smart contract compiler. Drawing inspiration from LLVM IR, Venom has been adapted to be simpler, and to be architected towards emitting code for stack-based virtual machines. Designed with a Single Static Assignment (SSA) form, Venom allows for sophisticated analysis and optimizations, while accommodating the idiosyncrasies of the EVM architecture. + +### Venom Form + +In Venom, values are denoted as strings commencing with the `'%'` character, referred to as variables. Variables can only be assigned to at declaration (they remain immutable post-assignment). Constants are represented as decimal numbers (hexadecimal may be added in the future). + +Reserved words include all the instruction opcodes and `'IRFunction'`, `'param'`, `'dbname'` and `'db'`. + +Any content following the `';'` character until the line end is treated as a comment. + +For instance, an example of incrementing a variable by one is as follows: + +```llvm +%sum = add %x, 1 ; Add one to x +``` + +Each instruction is identified by its opcode and a list of input operands. In cases where an instruction produces a result, it is stored in a new variable, as indicated on the left side of the assignment character. + +Code is organized into non-branching instruction blocks, known as _"Basic Blocks"_. Each basic block is defined by a label and contains its set of instructions. The final instruction of a basic block should either be a terminating instruction or a jump (conditional or unconditional) to other block(s). + +Basic blocks are grouped into _functions_ that are named and dictate the first block to execute. + +Venom employs two scopes: global and function level. + +### Example code + +```llvm +IRFunction: global + +global: + %1 = calldataload 0 + %2 = shr 224, %1 + jmp label %selector_bucket_0 + +selector_bucket_0: + %3 = xor %2, 1579456981 + %4 = iszero %3 + jnz label %1, label %2, %4 + +1: IN=[selector_bucket_0] OUT=[9] + jmp label %fallback + +2: + %5 = callvalue + %6 = calldatasize + %7 = lt %6, 164 + %8 = or %5, %7 + %9 = iszero %8 + assert %9 + stop + +fallback: + revert 0, 0 +``` + +### Grammar + +Below is a (not-so-complete) grammar to describe the text format of Venom IR: + +```llvm +program ::= function_declaration* + +function_declaration ::= "IRFunction:" identifier input_list? output_list? "=>" block + +input_list ::= "IN=" "[" (identifier ("," identifier)*)? "]" +output_list ::= "OUT=" "[" (identifier ("," identifier)*)? "]" + +block ::= label ":" input_list? output_list? "=>{" operation* "}" + +operation ::= "%" identifier "=" opcode operand ("," operand)* + | opcode operand ("," operand)* + +opcode ::= "calldataload" | "shr" | "shl" | "and" | "add" | "codecopy" | "mload" | "jmp" | "xor" | "iszero" | "jnz" | "label" | "lt" | "or" | "assert" | "callvalue" | "calldatasize" | "alloca" | "calldatacopy" | "invoke" | "gt" | ... + +operand ::= "%" identifier | label | integer | "label" "%" identifier +label ::= "%" identifier + +identifier ::= [a-zA-Z_][a-zA-Z0-9_]* +integer ::= [0-9]+ +``` + +## Implementation + +In the current implementation the compiler was extended to incorporate a new pass responsible for translating the original s-expr based IR into Venom. Subsequently, the generated Venom code undergoes processing by the actual Venom compiler, ultimately converting it to assembly code. That final assembly code is then passed to the original assembler of Vyper to produce the executable bytecode. + +Currently there is no implementation of the text format (that is, there is no front-end), although this is planned. At this time, Venom IR can only be constructed programmatically. + +## Architecture + +The Venom implementation is composed of several distinct passes that iteratively transform and optimize the Venom IR code until it reaches the assembly emitter, which produces the stack-based EVM assembly. The compiler is designed to be more-or-less pluggable, so passes can be written without too much knowledge of or dependency on other passes. + +These passes encompass generic transformations that streamline the code (such as dead code elimination and normalization), as well as those generating supplementary information about the code, like liveness analysis and control-flow graph (CFG) construction. Some passes may rely on the output of others, requiring a specific execution order. For instance, the code emitter expects the execution of a normalization pass preceding it, and this normalization pass, in turn, requires the augmentation of the Venom IR with code flow information. + +The primary categorization of pass types are: + +- Transformation passes +- Analysis/augmentation passes +- Optimization passes + +## Currently implemented passes + +The Venom compiler currently implements the following passes. + +### Control Flow Graph calculation + +The compiler generates a fundamental data structure known as the Control Flow Graph (CFG). This graph illustrates the interconnections between basic blocks, serving as a foundational data structure upon which many subsequent passes depend. + +### Data Flow Graph calculation + +To enable the compiler to analyze the movement of data through the code during execution, a specialized graph, the Dataflow Graph (DFG), is generated. The compiler inspects the code, determining where each variable is defined (in one location) and all the places where it is utilized. + +### Dataflow Transformation + +This pass depends on the DFG construction, and reorders variable declarations to try to reduce stack traffic during instruction selection. + +### Liveness analysis + +This pass conducts a dataflow analysis, utilizing information from previous passes to identify variables that are live at each instruction in the Venom IR code. A variable is deemed live at a particular instruction if it holds a value necessary for future operations. Variables only alive for their assignment instructions are identified here and then eliminated by the dead code elimination pass. + +### Dead code elimination + +This pass eliminates all basic blocks that are not reachable from any other basic block, leveraging the CFG. + +### Normalization + +A Venom program may feature basic blocks with multiple CFG inputs and outputs. This currently can occur when multiple blocks conditionally direct control to the same target basic block. We define a Venom IR as "normalized" when it contains no basic blocks that have multiple inputs and outputs. The normalization pass is responsible for converting any Venom IR program to its normalized form. EVM assembly emission operates solely on normalized Venom programs, because the stack layout is not well defined for non-normalized basic blocks. + +### Code emission + +This final pass of the compiler aims to emit EVM assembly recognized by Vyper's assembler. It calcluates the desired stack layout for every basic block, schedules items on the stack and selects instructions. It ensures that deploy code, runtime code, and data segments are arranged according to the assembler's expectations. + +## Future planned passes + +A number of passes that are planned to be implemented, or are implemented for immediately after the initial PR merge are below. + +### Constant folding + +### Instruction combination + +### Dead store elimination + +### Scalar evolution + +### Loop invariant code motion + +### Loop unrolling + +### Code sinking + +### Expression reassociation + +### Stack to mem + +### Mem to stack + +### Function inlining + +### Load-store elimination diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py new file mode 100644 index 0000000000..5a09f8378e --- /dev/null +++ b/vyper/venom/__init__.py @@ -0,0 +1,56 @@ +# maybe rename this `main.py` or `venom.py` +# (can have an `__init__.py` which exposes the API). + +from typing import Optional + +from vyper.codegen.ir_node import IRnode +from vyper.compiler.settings import OptimizationLevel +from vyper.venom.analysis import DFG, calculate_cfg, calculate_liveness +from vyper.venom.bb_optimizer import ( + ir_pass_optimize_empty_blocks, + ir_pass_optimize_unused_variables, + ir_pass_remove_unreachable_blocks, +) +from vyper.venom.function import IRFunction +from vyper.venom.ir_node_to_venom import convert_ir_basicblock +from vyper.venom.passes.constant_propagation import ir_pass_constant_propagation +from vyper.venom.passes.dft import DFTPass +from vyper.venom.venom_to_assembly import VenomCompiler + + +def generate_assembly_experimental( + ctx: IRFunction, optimize: Optional[OptimizationLevel] = None +) -> list[str]: + compiler = VenomCompiler(ctx) + return compiler.generate_evm(optimize is OptimizationLevel.NONE) + + +def generate_ir(ir: IRnode, optimize: Optional[OptimizationLevel] = None) -> IRFunction: + # Convert "old" IR to "new" IR + ctx = convert_ir_basicblock(ir) + + # Run passes on "new" IR + # TODO: Add support for optimization levels + while True: + changes = 0 + + changes += ir_pass_optimize_empty_blocks(ctx) + changes += ir_pass_remove_unreachable_blocks(ctx) + + calculate_liveness(ctx) + + changes += ir_pass_optimize_unused_variables(ctx) + + calculate_cfg(ctx) + calculate_liveness(ctx) + + changes += ir_pass_constant_propagation(ctx) + changes += DFTPass.run_pass(ctx) + + calculate_cfg(ctx) + calculate_liveness(ctx) + + if changes == 0: + break + + return ctx diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py new file mode 100644 index 0000000000..5980e21028 --- /dev/null +++ b/vyper/venom/analysis.py @@ -0,0 +1,191 @@ +from vyper.exceptions import CompilerPanic +from vyper.utils import OrderedSet +from vyper.venom.basicblock import ( + BB_TERMINATORS, + CFG_ALTERING_OPS, + IRBasicBlock, + IRInstruction, + IRVariable, +) +from vyper.venom.function import IRFunction + + +def calculate_cfg(ctx: IRFunction) -> None: + """ + Calculate (cfg) inputs for each basic block. + """ + for bb in ctx.basic_blocks: + bb.cfg_in = OrderedSet() + bb.cfg_out = OrderedSet() + bb.out_vars = OrderedSet() + + # TODO: This is a hack to support the old IR format where `deploy` is + # an instruction. in the future we should have two entry points, one + # for the initcode and one for the runtime code. + deploy_bb = None + after_deploy_bb = None + for i, bb in enumerate(ctx.basic_blocks): + if bb.instructions[0].opcode == "deploy": + deploy_bb = bb + after_deploy_bb = ctx.basic_blocks[i + 1] + break + + if deploy_bb is not None: + assert after_deploy_bb is not None, "No block after deploy block" + entry_block = after_deploy_bb + has_constructor = ctx.basic_blocks[0].instructions[0].opcode != "deploy" + if has_constructor: + deploy_bb.add_cfg_in(ctx.basic_blocks[0]) + entry_block.add_cfg_in(deploy_bb) + else: + entry_block = ctx.basic_blocks[0] + + # TODO: Special case for the jump table of selector buckets and fallback. + # this will be cleaner when we introduce an "indirect jump" instruction + # for the selector table (which includes all possible targets). it will + # also clean up the code for normalization because it will not have to + # handle this case specially. + for bb in ctx.basic_blocks: + if "selector_bucket_" in bb.label.value or bb.label.value == "fallback": + bb.add_cfg_in(entry_block) + + for bb in ctx.basic_blocks: + assert len(bb.instructions) > 0, "Basic block should not be empty" + last_inst = bb.instructions[-1] + assert last_inst.opcode in BB_TERMINATORS, f"Last instruction should be a terminator {bb}" + + for inst in bb.instructions: + if inst.opcode in CFG_ALTERING_OPS: + ops = inst.get_label_operands() + for op in ops: + ctx.get_basic_block(op.value).add_cfg_in(bb) + + # Fill in the "out" set for each basic block + for bb in ctx.basic_blocks: + for in_bb in bb.cfg_in: + in_bb.add_cfg_out(bb) + + +def _reset_liveness(ctx: IRFunction) -> None: + for bb in ctx.basic_blocks: + for inst in bb.instructions: + inst.liveness = OrderedSet() + + +def _calculate_liveness_bb(bb: IRBasicBlock) -> None: + """ + Compute liveness of each instruction in the basic block. + """ + liveness = bb.out_vars.copy() + for instruction in reversed(bb.instructions): + ops = instruction.get_inputs() + + for op in ops: + if op in liveness: + instruction.dup_requirements.add(op) + + liveness = liveness.union(OrderedSet.fromkeys(ops)) + out = instruction.get_outputs()[0] if len(instruction.get_outputs()) > 0 else None + if out in liveness: + liveness.remove(out) + instruction.liveness = liveness + + +def _calculate_liveness_r(bb: IRBasicBlock, visited: dict) -> None: + assert isinstance(visited, dict) + for out_bb in bb.cfg_out: + if visited.get(bb) == out_bb: + continue + visited[bb] = out_bb + + # recurse + _calculate_liveness_r(out_bb, visited) + + target_vars = input_vars_from(bb, out_bb) + + # the output stack layout for bb. it produces a stack layout + # which works for all possible cfg_outs from the bb. + bb.out_vars = bb.out_vars.union(target_vars) + + _calculate_liveness_bb(bb) + + +def calculate_liveness(ctx: IRFunction) -> None: + _reset_liveness(ctx) + _calculate_liveness_r(ctx.basic_blocks[0], dict()) + + +# calculate the input variables into self from source +def input_vars_from(source: IRBasicBlock, target: IRBasicBlock) -> OrderedSet[IRVariable]: + liveness = target.instructions[0].liveness.copy() + assert isinstance(liveness, OrderedSet) + + for inst in target.instructions: + if inst.opcode == "phi": + # we arbitrarily choose one of the arguments to be in the + # live variables set (dependent on how we traversed into this + # basic block). the argument will be replaced by the destination + # operand during instruction selection. + # for instance, `%56 = phi %label1 %12 %label2 %14` + # will arbitrarily choose either %12 or %14 to be in the liveness + # set, and then during instruction selection, after this instruction, + # %12 will be replaced by %56 in the liveness set + source1, source2 = inst.operands[0], inst.operands[2] + phi1, phi2 = inst.operands[1], inst.operands[3] + if source.label == source1: + liveness.add(phi1) + if phi2 in liveness: + liveness.remove(phi2) + elif source.label == source2: + liveness.add(phi2) + if phi1 in liveness: + liveness.remove(phi1) + else: + # bad path into this phi node + raise CompilerPanic(f"unreachable: {inst}") + + return liveness + + +# DataFlow Graph +# this could be refactored into its own file, but it's only used here +# for now +class DFG: + _dfg_inputs: dict[IRVariable, list[IRInstruction]] + _dfg_outputs: dict[IRVariable, IRInstruction] + + def __init__(self): + self._dfg_inputs = dict() + self._dfg_outputs = dict() + + # return uses of a given variable + def get_uses(self, op: IRVariable) -> list[IRInstruction]: + return self._dfg_inputs.get(op, []) + + # the instruction which produces this variable. + def get_producing_instruction(self, op: IRVariable) -> IRInstruction: + return self._dfg_outputs[op] + + @classmethod + def build_dfg(cls, ctx: IRFunction) -> "DFG": + dfg = cls() + + # Build DFG + + # %15 = add %13 %14 + # %16 = iszero %15 + # dfg_outputs of %15 is (%15 = add %13 %14) + # dfg_inputs of %15 is all the instructions which *use* %15, ex. [(%16 = iszero %15), ...] + for bb in ctx.basic_blocks: + for inst in bb.instructions: + operands = inst.get_inputs() + res = inst.get_outputs() + + for op in operands: + inputs = dfg._dfg_inputs.setdefault(op, []) + inputs.append(inst) + + for op in res: # type: ignore + dfg._dfg_outputs[op] = inst + + return dfg diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py new file mode 100644 index 0000000000..b95d7416ca --- /dev/null +++ b/vyper/venom/basicblock.py @@ -0,0 +1,345 @@ +from enum import Enum, auto +from typing import TYPE_CHECKING, Any, Iterator, Optional + +from vyper.utils import OrderedSet + +# instructions which can terminate a basic block +BB_TERMINATORS = frozenset(["jmp", "jnz", "ret", "return", "revert", "deploy", "stop"]) + +VOLATILE_INSTRUCTIONS = frozenset( + [ + "param", + "alloca", + "call", + "staticcall", + "invoke", + "sload", + "sstore", + "iload", + "istore", + "assert", + "mstore", + "mload", + "calldatacopy", + "codecopy", + "dloadbytes", + "dload", + "return", + "ret", + "jmp", + "jnz", + ] +) + +CFG_ALTERING_OPS = frozenset(["jmp", "jnz", "call", "staticcall", "invoke", "deploy"]) + + +if TYPE_CHECKING: + from vyper.venom.function import IRFunction + + +class IRDebugInfo: + """ + IRDebugInfo represents debug information in IR, used to annotate IR instructions + with source code information when printing IR. + """ + + line_no: int + src: str + + def __init__(self, line_no: int, src: str) -> None: + self.line_no = line_no + self.src = src + + def __repr__(self) -> str: + src = self.src if self.src else "" + return f"\t# line {self.line_no}: {src}".expandtabs(20) + + +class IROperand: + """ + IROperand represents an operand in IR. An operand is anything that can + be an argument to an IRInstruction + """ + + value: Any + + +class IRValue(IROperand): + """ + IRValue represents a value in IR. A value is anything that can be + operated by non-control flow instructions. That is, IRValues can be + IRVariables or IRLiterals. + """ + + pass + + +class IRLiteral(IRValue): + """ + IRLiteral represents a literal in IR + """ + + value: int + + def __init__(self, value: int) -> None: + assert isinstance(value, str) or isinstance(value, int), "value must be an int" + self.value = value + + def __repr__(self) -> str: + return str(self.value) + + +class MemType(Enum): + OPERAND_STACK = auto() + MEMORY = auto() + + +class IRVariable(IRValue): + """ + IRVariable represents a variable in IR. A variable is a string that starts with a %. + """ + + value: str + offset: int = 0 + + # some variables can be in memory for conversion from legacy IR to venom + mem_type: MemType = MemType.OPERAND_STACK + mem_addr: Optional[int] = None + + def __init__( + self, value: str, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: int = None + ) -> None: + assert isinstance(value, str) + self.value = value + self.offset = 0 + self.mem_type = mem_type + self.mem_addr = mem_addr + + def __repr__(self) -> str: + return self.value + + +class IRLabel(IROperand): + """ + IRLabel represents a label in IR. A label is a string that starts with a %. + """ + + # is_symbol is used to indicate if the label came from upstream + # (like a function name, try to preserve it in optimization passes) + is_symbol: bool = False + value: str + + def __init__(self, value: str, is_symbol: bool = False) -> None: + assert isinstance(value, str), "value must be an str" + self.value = value + self.is_symbol = is_symbol + + def __repr__(self) -> str: + return self.value + + +class IRInstruction: + """ + IRInstruction represents an instruction in IR. Each instruction has an opcode, + operands, and return value. For example, the following IR instruction: + %1 = add %0, 1 + has opcode "add", operands ["%0", "1"], and return value "%1". + + Convention: the rightmost value is the top of the stack. + """ + + opcode: str + volatile: bool + operands: list[IROperand] + output: Optional[IROperand] + # set of live variables at this instruction + liveness: OrderedSet[IRVariable] + dup_requirements: OrderedSet[IRVariable] + parent: Optional["IRBasicBlock"] + fence_id: int + annotation: Optional[str] + + def __init__( + self, + opcode: str, + operands: list[IROperand] | Iterator[IROperand], + output: Optional[IROperand] = None, + ): + assert isinstance(opcode, str), "opcode must be an str" + assert isinstance(operands, list | Iterator), "operands must be a list" + self.opcode = opcode + self.volatile = opcode in VOLATILE_INSTRUCTIONS + self.operands = [op for op in operands] # in case we get an iterator + self.output = output + self.liveness = OrderedSet() + self.dup_requirements = OrderedSet() + self.parent = None + self.fence_id = -1 + self.annotation = None + + def get_label_operands(self) -> list[IRLabel]: + """ + Get all labels in instruction. + """ + return [op for op in self.operands if isinstance(op, IRLabel)] + + def get_non_label_operands(self) -> list[IROperand]: + """ + Get input operands for instruction which are not labels + """ + return [op for op in self.operands if not isinstance(op, IRLabel)] + + def get_inputs(self) -> list[IRVariable]: + """ + Get all input operands for instruction. + """ + return [op for op in self.operands if isinstance(op, IRVariable)] + + def get_outputs(self) -> list[IROperand]: + """ + Get the output item for an instruction. + (Currently all instructions output at most one item, but write + it as a list to be generic for the future) + """ + return [self.output] if self.output else [] + + def replace_operands(self, replacements: dict) -> None: + """ + Update operands with replacements. + replacements are represented using a dict: "key" is replaced by "value". + """ + for i, operand in enumerate(self.operands): + if operand in replacements: + self.operands[i] = replacements[operand] + + def __repr__(self) -> str: + s = "" + if self.output: + s += f"{self.output} = " + opcode = f"{self.opcode} " if self.opcode != "store" else "" + s += opcode + operands = ", ".join( + [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in self.operands] + ) + s += operands + + if self.annotation: + s += f" <{self.annotation}>" + + # if self.liveness: + # return f"{s: <30} # {self.liveness}" + + return s + + +class IRBasicBlock: + """ + IRBasicBlock represents a basic block in IR. Each basic block has a label and + a list of instructions, while belonging to a function. + + The following IR code: + %1 = add %0, 1 + %2 = mul %1, 2 + is represented as: + bb = IRBasicBlock("bb", function) + bb.append_instruction(IRInstruction("add", ["%0", "1"], "%1")) + bb.append_instruction(IRInstruction("mul", ["%1", "2"], "%2")) + + The label of a basic block is used to refer to it from other basic blocks + in order to branch to it. + + The parent of a basic block is the function it belongs to. + + The instructions of a basic block are executed sequentially, and the last + instruction of a basic block is always a terminator instruction, which is + used to branch to other basic blocks. + """ + + label: IRLabel + parent: "IRFunction" + instructions: list[IRInstruction] + # basic blocks which can jump to this basic block + cfg_in: OrderedSet["IRBasicBlock"] + # basic blocks which this basic block can jump to + cfg_out: OrderedSet["IRBasicBlock"] + # stack items which this basic block produces + out_vars: OrderedSet[IRVariable] + + def __init__(self, label: IRLabel, parent: "IRFunction") -> None: + assert isinstance(label, IRLabel), "label must be an IRLabel" + self.label = label + self.parent = parent + self.instructions = [] + self.cfg_in = OrderedSet() + self.cfg_out = OrderedSet() + self.out_vars = OrderedSet() + + def add_cfg_in(self, bb: "IRBasicBlock") -> None: + self.cfg_in.add(bb) + + def remove_cfg_in(self, bb: "IRBasicBlock") -> None: + assert bb in self.cfg_in + self.cfg_in.remove(bb) + + def add_cfg_out(self, bb: "IRBasicBlock") -> None: + # malformed: jnz condition label1 label1 + # (we could handle but it makes a lot of code easier + # if we have this assumption) + self.cfg_out.add(bb) + + def remove_cfg_out(self, bb: "IRBasicBlock") -> None: + assert bb in self.cfg_out + self.cfg_out.remove(bb) + + @property + def is_reachable(self) -> bool: + return len(self.cfg_in) > 0 + + def append_instruction(self, instruction: IRInstruction) -> None: + assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + instruction.parent = self + self.instructions.append(instruction) + + def insert_instruction(self, instruction: IRInstruction, index: int) -> None: + assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + instruction.parent = self + self.instructions.insert(index, instruction) + + def clear_instructions(self) -> None: + self.instructions = [] + + def replace_operands(self, replacements: dict) -> None: + """ + Update operands with replacements. + """ + for instruction in self.instructions: + instruction.replace_operands(replacements) + + @property + def is_terminated(self) -> bool: + """ + Check if the basic block is terminal, i.e. the last instruction is a terminator. + """ + # it's ok to return False here, since we use this to check + # if we can/need to append instructions to the basic block. + if len(self.instructions) == 0: + return False + return self.instructions[-1].opcode in BB_TERMINATORS + + def copy(self): + bb = IRBasicBlock(self.label, self.parent) + bb.instructions = self.instructions.copy() + bb.cfg_in = self.cfg_in.copy() + bb.cfg_out = self.cfg_out.copy() + bb.out_vars = self.out_vars.copy() + return bb + + def __repr__(self) -> str: + s = ( + f"{repr(self.label)}: IN={[bb.label for bb in self.cfg_in]}" + f" OUT={[bb.label for bb in self.cfg_out]} => {self.out_vars} \n" + ) + for instruction in self.instructions: + s += f" {instruction}\n" + return s diff --git a/vyper/venom/bb_optimizer.py b/vyper/venom/bb_optimizer.py new file mode 100644 index 0000000000..620ee66d15 --- /dev/null +++ b/vyper/venom/bb_optimizer.py @@ -0,0 +1,73 @@ +from vyper.utils import ir_pass +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRInstruction, IRLabel +from vyper.venom.function import IRFunction + + +def _optimize_unused_variables(ctx: IRFunction) -> set[IRInstruction]: + """ + Remove unused variables. + """ + removeList = set() + for bb in ctx.basic_blocks: + for i, inst in enumerate(bb.instructions[:-1]): + if inst.volatile: + continue + if inst.output and inst.output not in bb.instructions[i + 1].liveness: + removeList.add(inst) + + bb.instructions = [inst for inst in bb.instructions if inst not in removeList] + + return removeList + + +def _optimize_empty_basicblocks(ctx: IRFunction) -> int: + """ + Remove empty basic blocks. + """ + count = 0 + i = 0 + while i < len(ctx.basic_blocks): + bb = ctx.basic_blocks[i] + i += 1 + if len(bb.instructions) > 0: + continue + + replaced_label = bb.label + replacement_label = ctx.basic_blocks[i].label if i < len(ctx.basic_blocks) else None + if replacement_label is None: + continue + + # Try to preserve symbol labels + if replaced_label.is_symbol: + replaced_label, replacement_label = replacement_label, replaced_label + ctx.basic_blocks[i].label = replacement_label + + for bb2 in ctx.basic_blocks: + for inst in bb2.instructions: + for op in inst.operands: + if isinstance(op, IRLabel) and op.value == replaced_label.value: + op.value = replacement_label.value + + ctx.basic_blocks.remove(bb) + i -= 1 + count += 1 + + return count + + +@ir_pass +def ir_pass_optimize_empty_blocks(ctx: IRFunction) -> int: + changes = _optimize_empty_basicblocks(ctx) + calculate_cfg(ctx) + return changes + + +@ir_pass +def ir_pass_remove_unreachable_blocks(ctx: IRFunction) -> int: + return ctx.remove_unreachable_blocks() + + +@ir_pass +def ir_pass_optimize_unused_variables(ctx: IRFunction) -> int: + return len(_optimize_unused_variables(ctx)) diff --git a/vyper/venom/function.py b/vyper/venom/function.py new file mode 100644 index 0000000000..c14ad77345 --- /dev/null +++ b/vyper/venom/function.py @@ -0,0 +1,170 @@ +from typing import Optional + +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IROperand, + IRVariable, + MemType, +) + +GLOBAL_LABEL = IRLabel("global") + + +class IRFunction: + """ + Function that contains basic blocks. + """ + + name: IRLabel # symbol name + args: list + basic_blocks: list[IRBasicBlock] + data_segment: list[IRInstruction] + last_label: int + last_variable: int + + def __init__(self, name: IRLabel = None) -> None: + if name is None: + name = GLOBAL_LABEL + self.name = name + self.args = [] + self.basic_blocks = [] + self.data_segment = [] + self.last_label = 0 + self.last_variable = 0 + + self.append_basic_block(IRBasicBlock(name, self)) + + def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: + """ + Append basic block to function. + """ + assert isinstance(bb, IRBasicBlock), f"append_basic_block takes IRBasicBlock, got '{bb}'" + self.basic_blocks.append(bb) + + # TODO add sanity check somewhere that basic blocks have unique labels + + return self.basic_blocks[-1] + + def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: + """ + Get basic block by label. + If label is None, return the last basic block. + """ + if label is None: + return self.basic_blocks[-1] + for bb in self.basic_blocks: + if bb.label.value == label: + return bb + raise AssertionError(f"Basic block '{label}' not found") + + def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock: + """ + Get basic block after label. + """ + for i, bb in enumerate(self.basic_blocks[:-1]): + if bb.label.value == label.value: + return self.basic_blocks[i + 1] + raise AssertionError(f"Basic block after '{label}' not found") + + def get_basicblocks_in(self, basic_block: IRBasicBlock) -> list[IRBasicBlock]: + """ + Get basic blocks that contain label. + """ + return [bb for bb in self.basic_blocks if basic_block.label in bb.cfg_in] + + def get_next_label(self) -> IRLabel: + self.last_label += 1 + return IRLabel(f"{self.last_label}") + + def get_next_variable( + self, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: Optional[int] = None + ) -> IRVariable: + self.last_variable += 1 + return IRVariable(f"%{self.last_variable}", mem_type, mem_addr) + + def get_last_variable(self) -> str: + return f"%{self.last_variable}" + + def remove_unreachable_blocks(self) -> int: + removed = 0 + new_basic_blocks = [] + for bb in self.basic_blocks: + if not bb.is_reachable and bb.label.value != "global": + removed += 1 + else: + new_basic_blocks.append(bb) + self.basic_blocks = new_basic_blocks + return removed + + def append_instruction( + self, opcode: str, args: list[IROperand], do_ret: bool = True + ) -> Optional[IRVariable]: + """ + Append instruction to last basic block. + """ + ret = self.get_next_variable() if do_ret else None + inst = IRInstruction(opcode, args, ret) # type: ignore + self.get_basic_block().append_instruction(inst) + return ret + + def append_data(self, opcode: str, args: list[IROperand]) -> None: + """ + Append data + """ + self.data_segment.append(IRInstruction(opcode, args)) # type: ignore + + @property + def normalized(self) -> bool: + """ + Check if function is normalized. A function is normalized if in the + CFG, no basic block simultaneously has multiple inputs and outputs. + That is, a basic block can be jumped to *from* multiple blocks, or it + can jump *to* multiple blocks, but it cannot simultaneously do both. + Having a normalized CFG makes calculation of stack layout easier when + emitting assembly. + """ + for bb in self.basic_blocks: + # Ignore if there are no multiple predecessors + if len(bb.cfg_in) <= 1: + continue + + # Check if there is a conditional jump at the end + # of one of the predecessors + # + # TODO: this check could be: + # `if len(in_bb.cfg_out) > 1: return False` + # but the cfg is currently not calculated "correctly" for + # certain special instructions (deploy instruction and + # selector table indirect jumps). + for in_bb in bb.cfg_in: + jump_inst = in_bb.instructions[-1] + if jump_inst.opcode != "jnz": + continue + if jump_inst.opcode == "jmp" and isinstance(jump_inst.operands[0], IRLabel): + continue + + # The function is not normalized + return False + + # The function is normalized + return True + + def copy(self): + new = IRFunction(self.name) + new.basic_blocks = self.basic_blocks.copy() + new.data_segment = self.data_segment.copy() + new.last_label = self.last_label + new.last_variable = self.last_variable + return new + + def __repr__(self) -> str: + str = f"IRFunction: {self.name}\n" + for bb in self.basic_blocks: + str += f"{bb}\n" + if len(self.data_segment) > 0: + str += "Data segment:\n" + for inst in self.data_segment: + str += f"{inst}\n" + return str diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py new file mode 100644 index 0000000000..19bd5c8b73 --- /dev/null +++ b/vyper/venom/ir_node_to_venom.py @@ -0,0 +1,943 @@ +from typing import Optional + +from vyper.codegen.context import VariableRecord +from vyper.codegen.ir_node import IRnode +from vyper.evm.opcodes import get_opcodes +from vyper.exceptions import CompilerPanic +from vyper.ir.compile_ir import is_mem_sym, is_symbol +from vyper.semantics.types.function import ContractFunctionT +from vyper.utils import MemoryPositions, OrderedSet +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IRLiteral, + IROperand, + IRVariable, + MemType, +) +from vyper.venom.function import IRFunction + +_BINARY_IR_INSTRUCTIONS = frozenset( + [ + "eq", + "gt", + "lt", + "slt", + "sgt", + "shr", + "shl", + "or", + "xor", + "and", + "add", + "sub", + "mul", + "div", + "mod", + "exp", + "sha3", + "sha3_64", + "signextend", + ] +) + +# Instuctions that are mapped to their inverse +INVERSE_MAPPED_IR_INSTRUCTIONS = {"ne": "eq", "le": "gt", "sle": "sgt", "ge": "lt", "sge": "slt"} + +# Instructions that have a direct EVM opcode equivalent and can +# be passed through to the EVM assembly without special handling +PASS_THROUGH_INSTRUCTIONS = [ + "chainid", + "basefee", + "timestamp", + "caller", + "selfbalance", + "calldatasize", + "callvalue", + "address", + "origin", + "codesize", + "gas", + "gasprice", + "gaslimit", + "returndatasize", + "coinbase", + "number", + "iszero", + "ceil32", + "calldataload", + "extcodesize", + "extcodehash", + "balance", +] + +SymbolTable = dict[str, IROperand] + + +def _get_symbols_common(a: dict, b: dict) -> dict: + ret = {} + # preserves the ordering in `a` + for k in a.keys(): + if k not in b: + continue + if a[k] == b[k]: + continue + ret[k] = a[k], b[k] + return ret + + +def convert_ir_basicblock(ir: IRnode) -> IRFunction: + global_function = IRFunction() + _convert_ir_basicblock(global_function, ir, {}, OrderedSet(), {}) + + for i, bb in enumerate(global_function.basic_blocks): + if not bb.is_terminated and i < len(global_function.basic_blocks) - 1: + bb.append_instruction(IRInstruction("jmp", [global_function.basic_blocks[i + 1].label])) + + revert_bb = IRBasicBlock(IRLabel("__revert"), global_function) + revert_bb = global_function.append_basic_block(revert_bb) + revert_bb.append_instruction(IRInstruction("revert", [IRLiteral(0), IRLiteral(0)])) + + return global_function + + +def _convert_binary_op( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], + swap: bool = False, +) -> IRVariable: + ir_args = ir.args[::-1] if swap else ir.args + arg_0 = _convert_ir_basicblock(ctx, ir_args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir_args[1], symbols, variables, allocated_variables) + args = [arg_1, arg_0] + + ret = ctx.get_next_variable() + + inst = IRInstruction(ir.value, args, ret) # type: ignore + ctx.get_basic_block().append_instruction(inst) + return ret + + +def _append_jmp(ctx: IRFunction, label: IRLabel) -> None: + inst = IRInstruction("jmp", [label]) + ctx.get_basic_block().append_instruction(inst) + + label = ctx.get_next_label() + bb = IRBasicBlock(label, ctx) + ctx.append_basic_block(bb) + + +def _new_block(ctx: IRFunction) -> IRBasicBlock: + bb = IRBasicBlock(ctx.get_next_label(), ctx) + bb = ctx.append_basic_block(bb) + return bb + + +def _handle_self_call( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], +) -> Optional[IRVariable]: + func_t = ir.passthrough_metadata.get("func_t", None) + args_ir = ir.passthrough_metadata["args_ir"] + goto_ir = [ir for ir in ir.args if ir.value == "goto"][0] + target_label = goto_ir.args[0].value # goto + return_buf = goto_ir.args[1] # return buffer + ret_args = [IRLabel(target_label)] # type: ignore + + for arg in args_ir: + if arg.is_literal: + sym = symbols.get(f"&{arg.value}", None) + if sym is None: + ret = _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + ret_args.append(ret) + else: + ret_args.append(sym) # type: ignore + else: + ret = _convert_ir_basicblock( + ctx, arg._optimized, symbols, variables, allocated_variables + ) + if arg.location and arg.location.load_op == "calldataload": + ret = ctx.append_instruction(arg.location.load_op, [ret]) + ret_args.append(ret) + + if return_buf.is_literal: + ret_args.append(IRLiteral(return_buf.value)) # type: ignore + + do_ret = func_t.return_type is not None + invoke_ret = ctx.append_instruction("invoke", ret_args, do_ret) # type: ignore + allocated_variables["return_buffer"] = invoke_ret # type: ignore + return invoke_ret + + +def _handle_internal_func( + ctx: IRFunction, ir: IRnode, func_t: ContractFunctionT, symbols: SymbolTable +) -> IRnode: + bb = IRBasicBlock(IRLabel(ir.args[0].args[0].value, True), ctx) # type: ignore + bb = ctx.append_basic_block(bb) + + old_ir_mempos = 0 + old_ir_mempos += 64 + + for arg in func_t.arguments: + new_var = ctx.get_next_variable() + + alloca_inst = IRInstruction("param", [], new_var) + alloca_inst.annotation = arg.name + bb.append_instruction(alloca_inst) + symbols[f"&{old_ir_mempos}"] = new_var + old_ir_mempos += 32 # arg.typ.memory_bytes_required + + # return buffer + if func_t.return_type is not None: + new_var = ctx.get_next_variable() + alloca_inst = IRInstruction("param", [], new_var) + bb.append_instruction(alloca_inst) + alloca_inst.annotation = "return_buffer" + symbols["return_buffer"] = new_var + + # return address + new_var = ctx.get_next_variable() + alloca_inst = IRInstruction("param", [], new_var) + bb.append_instruction(alloca_inst) + alloca_inst.annotation = "return_pc" + symbols["return_pc"] = new_var + + return ir.args[0].args[2] + + +def _convert_ir_simple_node( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], +) -> Optional[IRVariable]: + args = [ + _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) for arg in ir.args + ] + return ctx.append_instruction(ir.value, args) # type: ignore + + +_break_target: Optional[IRBasicBlock] = None +_continue_target: Optional[IRBasicBlock] = None + + +def _get_variable_from_address( + variables: OrderedSet[VariableRecord], addr: int +) -> Optional[VariableRecord]: + assert isinstance(addr, int), "non-int address" + for var in variables.keys(): + if var.location.name != "memory": + continue + if addr >= var.pos and addr < var.pos + var.size: # type: ignore + return var + return None + + +def _get_return_for_stack_operand( + ctx: IRFunction, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable +) -> IRInstruction: + if isinstance(ret_ir, IRLiteral): + sym = symbols.get(f"&{ret_ir.value}", None) + new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_ir]) + ctx.append_instruction("mstore", [sym, new_var], False) # type: ignore + else: + sym = symbols.get(ret_ir.value, None) + if sym is None: + # FIXME: needs real allocations + new_var = ctx.append_instruction("alloca", [IRLiteral(32), IRLiteral(0)]) + ctx.append_instruction("mstore", [ret_ir, new_var], False) # type: ignore + else: + new_var = ret_ir + return IRInstruction("return", [last_ir, new_var]) # type: ignore + + +def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): + assert isinstance(variables, OrderedSet) + global _break_target, _continue_target + + frame_info = ir.passthrough_metadata.get("frame_info", None) + if frame_info is not None: + local_vars = OrderedSet[VariableRecord](frame_info.frame_vars.values()) + variables |= local_vars + + assert isinstance(variables, OrderedSet) + + if ir.value in _BINARY_IR_INSTRUCTIONS: + return _convert_binary_op( + ctx, ir, symbols, variables, allocated_variables, ir.value in ["sha3_64"] + ) + + elif ir.value in INVERSE_MAPPED_IR_INSTRUCTIONS: + org_value = ir.value + ir.value = INVERSE_MAPPED_IR_INSTRUCTIONS[ir.value] + new_var = _convert_binary_op(ctx, ir, symbols, variables, allocated_variables) + ir.value = org_value + return ctx.append_instruction("iszero", [new_var]) + + elif ir.value in PASS_THROUGH_INSTRUCTIONS: + return _convert_ir_simple_node(ctx, ir, symbols, variables, allocated_variables) + + elif ir.value in ["pass", "stop", "return"]: + pass + elif ir.value == "deploy": + memsize = ir.args[0].value + ir_runtime = ir.args[1] + padding = ir.args[2].value + assert isinstance(memsize, int), "non-int memsize" + assert isinstance(padding, int), "non-int padding" + + runtimeLabel = ctx.get_next_label() + + inst = IRInstruction("deploy", [IRLiteral(memsize), runtimeLabel, IRLiteral(padding)]) + ctx.get_basic_block().append_instruction(inst) + + bb = IRBasicBlock(runtimeLabel, ctx) + ctx.append_basic_block(bb) + + _convert_ir_basicblock(ctx, ir_runtime, symbols, variables, allocated_variables) + elif ir.value == "seq": + func_t = ir.passthrough_metadata.get("func_t", None) + if ir.is_self_call: + return _handle_self_call(ctx, ir, symbols, variables, allocated_variables) + elif func_t is not None: + symbols = {} + allocated_variables = {} + variables = OrderedSet( + {v: True for v in ir.passthrough_metadata["frame_info"].frame_vars.values()} + ) + if func_t.is_internal: + ir = _handle_internal_func(ctx, ir, func_t, symbols) + # fallthrough + + ret = None + for ir_node in ir.args: # NOTE: skip the last one + ret = _convert_ir_basicblock(ctx, ir_node, symbols, variables, allocated_variables) + + return ret + elif ir.value in ["staticcall", "call"]: # external call + idx = 0 + gas = _convert_ir_basicblock(ctx, ir.args[idx], symbols, variables, allocated_variables) + address = _convert_ir_basicblock( + ctx, ir.args[idx + 1], symbols, variables, allocated_variables + ) + + value = None + if ir.value == "call": + value = _convert_ir_basicblock( + ctx, ir.args[idx + 2], symbols, variables, allocated_variables + ) + else: + idx -= 1 + + argsOffset = _convert_ir_basicblock( + ctx, ir.args[idx + 3], symbols, variables, allocated_variables + ) + argsSize = _convert_ir_basicblock( + ctx, ir.args[idx + 4], symbols, variables, allocated_variables + ) + retOffset = _convert_ir_basicblock( + ctx, ir.args[idx + 5], symbols, variables, allocated_variables + ) + retSize = _convert_ir_basicblock( + ctx, ir.args[idx + 6], symbols, variables, allocated_variables + ) + + if isinstance(argsOffset, IRLiteral): + offset = int(argsOffset.value) + addr = offset - 32 + 4 if offset > 0 else 0 + argsOffsetVar = symbols.get(f"&{addr}", None) + if argsOffsetVar is None: + argsOffsetVar = argsOffset + elif isinstance(argsOffsetVar, IRVariable): + argsOffsetVar.mem_type = MemType.MEMORY + argsOffsetVar.mem_addr = addr + argsOffsetVar.offset = 32 - 4 if offset > 0 else 0 + else: # pragma: nocover + raise CompilerPanic("unreachable") + else: + argsOffsetVar = argsOffset + + retOffsetValue = int(retOffset.value) if retOffset else 0 + retVar = ctx.get_next_variable(MemType.MEMORY, retOffsetValue) + symbols[f"&{retOffsetValue}"] = retVar + + if ir.value == "call": + args = [retSize, retOffset, argsSize, argsOffsetVar, value, address, gas] + return ctx.append_instruction(ir.value, args) + else: + args = [retSize, retOffset, argsSize, argsOffsetVar, address, gas] + return ctx.append_instruction(ir.value, args) + elif ir.value == "if": + cond = ir.args[0] + current_bb = ctx.get_basic_block() + + # convert the condition + cont_ret = _convert_ir_basicblock(ctx, cond, symbols, variables, allocated_variables) + + else_block = IRBasicBlock(ctx.get_next_label(), ctx) + ctx.append_basic_block(else_block) + + # convert "else" + else_ret_val = None + else_syms = symbols.copy() + if len(ir.args) == 3: + else_ret_val = _convert_ir_basicblock( + ctx, ir.args[2], else_syms, variables, allocated_variables.copy() + ) + if isinstance(else_ret_val, IRLiteral): + assert isinstance(else_ret_val.value, int) # help mypy + else_ret_val = ctx.append_instruction("store", [IRLiteral(else_ret_val.value)]) + after_else_syms = else_syms.copy() + + # convert "then" + then_block = IRBasicBlock(ctx.get_next_label(), ctx) + ctx.append_basic_block(then_block) + + then_ret_val = _convert_ir_basicblock( + ctx, ir.args[1], symbols, variables, allocated_variables + ) + if isinstance(then_ret_val, IRLiteral): + then_ret_val = ctx.append_instruction("store", [IRLiteral(then_ret_val.value)]) + + inst = IRInstruction("jnz", [cont_ret, then_block.label, else_block.label]) + current_bb.append_instruction(inst) + + after_then_syms = symbols.copy() + + # exit bb + exit_label = ctx.get_next_label() + bb = IRBasicBlock(exit_label, ctx) + bb = ctx.append_basic_block(bb) + + if_ret = None + if then_ret_val is not None and else_ret_val is not None: + if_ret = ctx.get_next_variable() + bb.append_instruction( + IRInstruction( + "phi", [then_block.label, then_ret_val, else_block.label, else_ret_val], if_ret + ) + ) + + common_symbols = _get_symbols_common(after_then_syms, after_else_syms) + for sym, val in common_symbols.items(): + ret = ctx.get_next_variable() + old_var = symbols.get(sym, None) + symbols[sym] = ret + if old_var is not None: + for idx, var_rec in allocated_variables.items(): # type: ignore + if var_rec.value == old_var.value: + allocated_variables[idx] = ret # type: ignore + bb.append_instruction( + IRInstruction("phi", [then_block.label, val[0], else_block.label, val[1]], ret) + ) + + if not else_block.is_terminated: + exit_inst = IRInstruction("jmp", [bb.label]) + else_block.append_instruction(exit_inst) + + if not then_block.is_terminated: + exit_inst = IRInstruction("jmp", [bb.label]) + then_block.append_instruction(exit_inst) + + return if_ret + + elif ir.value == "with": + ret = _convert_ir_basicblock( + ctx, ir.args[1], symbols, variables, allocated_variables + ) # initialization + + # Handle with nesting with same symbol + with_symbols = symbols.copy() + + sym = ir.args[0] + if isinstance(ret, IRLiteral): + new_var = ctx.append_instruction("store", [ret]) # type: ignore + with_symbols[sym.value] = new_var + else: + with_symbols[sym.value] = ret # type: ignore + + return _convert_ir_basicblock( + ctx, ir.args[2], with_symbols, variables, allocated_variables + ) # body + elif ir.value == "goto": + _append_jmp(ctx, IRLabel(ir.args[0].value)) + elif ir.value == "jump": + arg_1 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + inst = IRInstruction("jmp", [arg_1]) + ctx.get_basic_block().append_instruction(inst) + _new_block(ctx) + elif ir.value == "set": + sym = ir.args[0] + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + new_var = ctx.append_instruction("store", [arg_1]) # type: ignore + symbols[sym.value] = new_var + + elif ir.value == "calldatacopy": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + new_v = arg_0 + var = ( + _get_variable_from_address(variables, int(arg_0.value)) + if isinstance(arg_0, IRLiteral) + else None + ) + if var is not None: + if allocated_variables.get(var.name, None) is None: + new_v = ctx.append_instruction( + "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] # type: ignore + ) + allocated_variables[var.name] = new_v # type: ignore + ctx.append_instruction("calldatacopy", [size, arg_1, new_v], False) # type: ignore + symbols[f"&{var.pos}"] = new_v # type: ignore + else: + ctx.append_instruction("calldatacopy", [size, arg_1, new_v], False) # type: ignore + + return new_v + elif ir.value == "codecopy": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + ctx.append_instruction("codecopy", [size, arg_1, arg_0], False) # type: ignore + elif ir.value == "symbol": + return IRLabel(ir.args[0].value, True) + elif ir.value == "data": + label = IRLabel(ir.args[0].value) + ctx.append_data("dbname", [label]) + for c in ir.args[1:]: + if isinstance(c, int): + assert 0 <= c <= 255, "data with invalid size" + ctx.append_data("db", [c]) # type: ignore + elif isinstance(c, bytes): + ctx.append_data("db", [c]) # type: ignore + elif isinstance(c, IRnode): + data = _convert_ir_basicblock(ctx, c, symbols, variables, allocated_variables) + ctx.append_data("db", [data]) # type: ignore + elif ir.value == "assert": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + current_bb = ctx.get_basic_block() + inst = IRInstruction("assert", [arg_0]) # type: ignore + current_bb.append_instruction(inst) + elif ir.value == "label": + label = IRLabel(ir.args[0].value, True) + if not ctx.get_basic_block().is_terminated: + inst = IRInstruction("jmp", [label]) + ctx.get_basic_block().append_instruction(inst) + bb = IRBasicBlock(label, ctx) + ctx.append_basic_block(bb) + _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + elif ir.value == "exit_to": + func_t = ir.passthrough_metadata.get("func_t", None) + assert func_t is not None, "exit_to without func_t" + + if func_t.is_external: + # Hardcoded contructor special case + if func_t.name == "__init__": + label = IRLabel(ir.args[0].value, True) + inst = IRInstruction("jmp", [label]) + ctx.get_basic_block().append_instruction(inst) + return None + if func_t.return_type is None: + inst = IRInstruction("stop", []) + ctx.get_basic_block().append_instruction(inst) + return None + else: + last_ir = None + ret_var = ir.args[1] + deleted = None + if ret_var.is_literal and symbols.get(f"&{ret_var.value}", None) is not None: + deleted = symbols[f"&{ret_var.value}"] + del symbols[f"&{ret_var.value}"] + for arg in ir.args[2:]: + last_ir = _convert_ir_basicblock( + ctx, arg, symbols, variables, allocated_variables + ) + if deleted is not None: + symbols[f"&{ret_var.value}"] = deleted + + ret_ir = _convert_ir_basicblock( + ctx, ret_var, symbols, variables, allocated_variables + ) + + var = ( + _get_variable_from_address(variables, int(ret_ir.value)) + if isinstance(ret_ir, IRLiteral) + else None + ) + if var is not None: + allocated_var = allocated_variables.get(var.name, None) + assert allocated_var is not None, "unallocated variable" + new_var = symbols.get(f"&{ret_ir.value}", allocated_var) # type: ignore + + if var.size and int(var.size) > 32: + offset = int(ret_ir.value) - var.pos # type: ignore + if offset > 0: + ptr_var = ctx.append_instruction( + "add", [IRLiteral(var.pos), IRLiteral(offset)] + ) + else: + ptr_var = allocated_var + inst = IRInstruction("return", [last_ir, ptr_var]) + else: + inst = _get_return_for_stack_operand(ctx, symbols, new_var, last_ir) + else: + if isinstance(ret_ir, IRLiteral): + sym = symbols.get(f"&{ret_ir.value}", None) + if sym is None: + inst = IRInstruction("return", [last_ir, ret_ir]) + else: + if func_t.return_type.memory_bytes_required > 32: + new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_ir]) + ctx.append_instruction("mstore", [sym, new_var], False) + inst = IRInstruction("return", [last_ir, new_var]) + else: + inst = IRInstruction("return", [last_ir, ret_ir]) + else: + if last_ir and int(last_ir.value) > 32: + inst = IRInstruction("return", [last_ir, ret_ir]) + else: + ret_buf = IRLiteral(128) # TODO: need allocator + new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_buf]) + ctx.append_instruction("mstore", [ret_ir, new_var], False) + inst = IRInstruction("return", [last_ir, new_var]) + + ctx.get_basic_block().append_instruction(inst) + ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + + if func_t.is_internal: + assert ir.args[1].value == "return_pc", "return_pc not found" + if func_t.return_type is None: + inst = IRInstruction("ret", [symbols["return_pc"]]) + else: + if func_t.return_type.memory_bytes_required > 32: + inst = IRInstruction("ret", [symbols["return_buffer"], symbols["return_pc"]]) + else: + ret_by_value = ctx.append_instruction("mload", [symbols["return_buffer"]]) + inst = IRInstruction("ret", [ret_by_value, symbols["return_pc"]]) + + ctx.get_basic_block().append_instruction(inst) + + elif ir.value == "revert": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + inst = IRInstruction("revert", [arg_1, arg_0]) + ctx.get_basic_block().append_instruction(inst) + + elif ir.value == "dload": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + src = ctx.append_instruction("add", [arg_0, IRLabel("code_end")]) + + ctx.append_instruction( + "dloadbytes", [IRLiteral(32), src, IRLiteral(MemoryPositions.FREE_VAR_SPACE)], False + ) + return ctx.append_instruction("mload", [IRLiteral(MemoryPositions.FREE_VAR_SPACE)]) + elif ir.value == "dloadbytes": + dst = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + src_offset = _convert_ir_basicblock( + ctx, ir.args[1], symbols, variables, allocated_variables + ) + len_ = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + src = ctx.append_instruction("add", [src_offset, IRLabel("code_end")]) + + inst = IRInstruction("dloadbytes", [len_, src, dst]) + ctx.get_basic_block().append_instruction(inst) + return None + elif ir.value == "mload": + sym_ir = ir.args[0] + var = ( + _get_variable_from_address(variables, int(sym_ir.value)) if sym_ir.is_literal else None + ) + if var is not None: + if var.size and var.size > 32: + if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = ctx.append_instruction( + "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] + ) + + offset = int(sym_ir.value) - var.pos + if offset > 0: + ptr_var = ctx.append_instruction("add", [IRLiteral(var.pos), IRLiteral(offset)]) + else: + ptr_var = allocated_variables[var.name] + + return ctx.append_instruction("mload", [ptr_var]) + else: + if sym_ir.is_literal: + sym = symbols.get(f"&{sym_ir.value}", None) + if sym is None: + new_var = ctx.append_instruction("store", [sym_ir]) + symbols[f"&{sym_ir.value}"] = new_var + if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = new_var + return new_var + else: + return sym + + sym = symbols.get(f"&{sym_ir.value}", None) + assert sym is not None, "unallocated variable" + return sym + else: + if sym_ir.is_literal: + new_var = symbols.get(f"&{sym_ir.value}", None) + if new_var is not None: + return ctx.append_instruction("mload", [new_var]) + else: + return ctx.append_instruction("mload", [IRLiteral(sym_ir.value)]) + else: + new_var = _convert_ir_basicblock( + ctx, sym_ir, symbols, variables, allocated_variables + ) + # + # Old IR gets it's return value as a reference in the stack + # New IR gets it's return value in stack in case of 32 bytes or less + # So here we detect ahead of time if this mload leads a self call and + # and we skip the mload + # + if sym_ir.is_self_call: + return new_var + return ctx.append_instruction("mload", [new_var]) + + elif ir.value == "mstore": + sym_ir = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + + var = None + if isinstance(sym_ir, IRLiteral): + var = _get_variable_from_address(variables, int(sym_ir.value)) + + if var is not None and var.size is not None: + if var.size and var.size > 32: + if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = ctx.append_instruction( + "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] + ) + + offset = int(sym_ir.value) - var.pos + if offset > 0: + ptr_var = ctx.append_instruction("add", [IRLiteral(var.pos), IRLiteral(offset)]) + else: + ptr_var = allocated_variables[var.name] + + return ctx.append_instruction("mstore", [arg_1, ptr_var], False) + else: + if isinstance(sym_ir, IRLiteral): + new_var = ctx.append_instruction("store", [arg_1]) + symbols[f"&{sym_ir.value}"] = new_var + # if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = new_var + return new_var + else: + if not isinstance(sym_ir, IRLiteral): + inst = IRInstruction("mstore", [arg_1, sym_ir]) + ctx.get_basic_block().append_instruction(inst) + return None + + sym = symbols.get(f"&{sym_ir.value}", None) + if sym is None: + inst = IRInstruction("mstore", [arg_1, sym_ir]) + ctx.get_basic_block().append_instruction(inst) + if arg_1 and not isinstance(sym_ir, IRLiteral): + symbols[f"&{sym_ir.value}"] = arg_1 + return None + + if isinstance(sym_ir, IRLiteral): + inst = IRInstruction("mstore", [arg_1, sym]) + ctx.get_basic_block().append_instruction(inst) + return None + else: + symbols[sym_ir.value] = arg_1 + return arg_1 + + elif ir.value in ["sload", "iload"]: + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + return ctx.append_instruction(ir.value, [arg_0]) + elif ir.value in ["sstore", "istore"]: + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + inst = IRInstruction(ir.value, [arg_1, arg_0]) + ctx.get_basic_block().append_instruction(inst) + elif ir.value == "unique_symbol": + sym = ir.args[0] + new_var = ctx.get_next_variable() + symbols[f"&{sym.value}"] = new_var + return new_var + elif ir.value == "repeat": + # + # repeat(sym, start, end, bound, body) + # 1) entry block ] + # 2) init counter block ] -> same block + # 3) condition block (exit block, body block) + # 4) body block + # 5) increment block + # 6) exit block + # TODO: Add the extra bounds check after clarify + def emit_body_block(): + global _break_target, _continue_target + old_targets = _break_target, _continue_target + _break_target, _continue_target = exit_block, increment_block + _convert_ir_basicblock(ctx, body, symbols, variables, allocated_variables) + _break_target, _continue_target = old_targets + + sym = ir.args[0] + start = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + end = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + # "bound" is not used + _ = _convert_ir_basicblock(ctx, ir.args[3], symbols, variables, allocated_variables) + body = ir.args[4] + + entry_block = ctx.get_basic_block() + cond_block = IRBasicBlock(ctx.get_next_label(), ctx) + body_block = IRBasicBlock(ctx.get_next_label(), ctx) + jump_up_block = IRBasicBlock(ctx.get_next_label(), ctx) + increment_block = IRBasicBlock(ctx.get_next_label(), ctx) + exit_block = IRBasicBlock(ctx.get_next_label(), ctx) + + counter_var = ctx.get_next_variable() + counter_inc_var = ctx.get_next_variable() + ret = ctx.get_next_variable() + + inst = IRInstruction("store", [start], counter_var) + ctx.get_basic_block().append_instruction(inst) + symbols[sym.value] = counter_var + inst = IRInstruction("jmp", [cond_block.label]) + ctx.get_basic_block().append_instruction(inst) + + symbols[sym.value] = ret + cond_block.append_instruction( + IRInstruction( + "phi", [entry_block.label, counter_var, increment_block.label, counter_inc_var], ret + ) + ) + + xor_ret = ctx.get_next_variable() + cont_ret = ctx.get_next_variable() + inst = IRInstruction("xor", [ret, end], xor_ret) + cond_block.append_instruction(inst) + cond_block.append_instruction(IRInstruction("iszero", [xor_ret], cont_ret)) + ctx.append_basic_block(cond_block) + + # Do a dry run to get the symbols needing phi nodes + start_syms = symbols.copy() + ctx.append_basic_block(body_block) + emit_body_block() + end_syms = symbols.copy() + diff_syms = _get_symbols_common(start_syms, end_syms) + + replacements = {} + for sym, val in diff_syms.items(): + new_var = ctx.get_next_variable() + symbols[sym] = new_var + replacements[val[0]] = new_var + replacements[val[1]] = new_var + cond_block.insert_instruction( + IRInstruction( + "phi", [entry_block.label, val[0], increment_block.label, val[1]], new_var + ), + 1, + ) + + body_block.replace_operands(replacements) + + body_end = ctx.get_basic_block() + if not body_end.is_terminated: + body_end.append_instruction(IRInstruction("jmp", [jump_up_block.label])) + + jump_cond = IRInstruction("jmp", [increment_block.label]) + jump_up_block.append_instruction(jump_cond) + ctx.append_basic_block(jump_up_block) + + increment_block.append_instruction( + IRInstruction("add", [ret, IRLiteral(1)], counter_inc_var) + ) + increment_block.append_instruction(IRInstruction("jmp", [cond_block.label])) + ctx.append_basic_block(increment_block) + + ctx.append_basic_block(exit_block) + + inst = IRInstruction("jnz", [cont_ret, exit_block.label, body_block.label]) + cond_block.append_instruction(inst) + elif ir.value == "break": + assert _break_target is not None, "Break with no break target" + inst = IRInstruction("jmp", [_break_target.label]) + ctx.get_basic_block().append_instruction(inst) + ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + elif ir.value == "continue": + assert _continue_target is not None, "Continue with no contrinue target" + inst = IRInstruction("jmp", [_continue_target.label]) + ctx.get_basic_block().append_instruction(inst) + ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + elif ir.value == "gas": + return ctx.append_instruction("gas", []) + elif ir.value == "returndatasize": + return ctx.append_instruction("returndatasize", []) + elif ir.value == "returndatacopy": + assert len(ir.args) == 3, "returndatacopy with wrong number of arguments" + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + new_var = ctx.append_instruction("returndatacopy", [arg_1, size]) + + symbols[f"&{arg_0.value}"] = new_var + return new_var + elif ir.value == "selfdestruct": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + ctx.append_instruction("selfdestruct", [arg_0], False) + elif isinstance(ir.value, str) and ir.value.startswith("log"): + args = [ + _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + for arg in ir.args + ] + inst = IRInstruction(ir.value, reversed(args)) + ctx.get_basic_block().append_instruction(inst) + elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes(): + _convert_ir_opcode(ctx, ir, symbols, variables, allocated_variables) + elif isinstance(ir.value, str) and ir.value in symbols: + return symbols[ir.value] + elif ir.is_literal: + return IRLiteral(ir.value) + else: + raise Exception(f"Unknown IR node: {ir}") + + return None + + +def _convert_ir_opcode( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], +) -> None: + opcode = ir.value.upper() # type: ignore + inst_args = [] + for arg in ir.args: + if isinstance(arg, IRnode): + inst_args.append( + _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + ) + instruction = IRInstruction(opcode, inst_args) # type: ignore + ctx.get_basic_block().append_instruction(instruction) + + +def _data_ofst_of(sym, ofst, height_): + # e.g. _OFST _sym_foo 32 + assert is_symbol(sym) or is_mem_sym(sym) + if isinstance(ofst.value, int): + # resolve at compile time using magic _OFST op + return ["_OFST", sym, ofst.value] + else: + # if we can't resolve at compile time, resolve at runtime + # ofst = _compile_to_assembly(ofst, withargs, existing_labels, break_dest, height_) + return ofst + [sym, "ADD"] diff --git a/vyper/venom/passes/base_pass.py b/vyper/venom/passes/base_pass.py new file mode 100644 index 0000000000..11da80ac66 --- /dev/null +++ b/vyper/venom/passes/base_pass.py @@ -0,0 +1,21 @@ +class IRPass: + """ + Decorator for IR passes. This decorator will run the pass repeatedly + until no more changes are made. + """ + + @classmethod + def run_pass(cls, *args, **kwargs): + t = cls() + count = 0 + + while True: + changes_count = t._run_pass(*args, **kwargs) or 0 + count += changes_count + if changes_count == 0: + break + + return count + + def _run_pass(self, *args, **kwargs): + raise NotImplementedError(f"Not implemented! {self.__class__}.run_pass()") diff --git a/vyper/venom/passes/constant_propagation.py b/vyper/venom/passes/constant_propagation.py new file mode 100644 index 0000000000..94b556124e --- /dev/null +++ b/vyper/venom/passes/constant_propagation.py @@ -0,0 +1,13 @@ +from vyper.utils import ir_pass +from vyper.venom.basicblock import IRBasicBlock +from vyper.venom.function import IRFunction + + +def _process_basic_block(ctx: IRFunction, bb: IRBasicBlock): + pass + + +@ir_pass +def ir_pass_constant_propagation(ctx: IRFunction): + for bb in ctx.basic_blocks: + _process_basic_block(ctx, bb) diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py new file mode 100644 index 0000000000..26994bd27f --- /dev/null +++ b/vyper/venom/passes/dft.py @@ -0,0 +1,54 @@ +from vyper.utils import OrderedSet +from vyper.venom.analysis import DFG +from vyper.venom.basicblock import IRBasicBlock, IRInstruction +from vyper.venom.function import IRFunction +from vyper.venom.passes.base_pass import IRPass + + +# DataFlow Transformation +class DFTPass(IRPass): + def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction): + if inst in self.visited_instructions: + return + self.visited_instructions.add(inst) + + if inst.opcode == "phi": + # phi instructions stay at the beginning of the basic block + # and no input processing is needed + bb.instructions.append(inst) + return + + for op in inst.get_inputs(): + target = self.dfg.get_producing_instruction(op) + if target.parent != inst.parent or target.fence_id != inst.fence_id: + # don't reorder across basic block or fence boundaries + continue + self._process_instruction_r(bb, target) + + bb.instructions.append(inst) + + def _process_basic_block(self, bb: IRBasicBlock) -> None: + self.ctx.append_basic_block(bb) + + instructions = bb.instructions + bb.instructions = [] + + for inst in instructions: + inst.fence_id = self.fence_id + if inst.volatile: + self.fence_id += 1 + + for inst in instructions: + self._process_instruction_r(bb, inst) + + def _run_pass(self, ctx: IRFunction) -> None: + self.ctx = ctx + self.dfg = DFG.build_dfg(ctx) + self.fence_id = 0 + self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() + + basic_blocks = ctx.basic_blocks + ctx.basic_blocks = [] + + for bb in basic_blocks: + self._process_basic_block(bb) diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py new file mode 100644 index 0000000000..9ee1012f91 --- /dev/null +++ b/vyper/venom/passes/normalization.py @@ -0,0 +1,90 @@ +from vyper.exceptions import CompilerPanic +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRVariable +from vyper.venom.function import IRFunction +from vyper.venom.passes.base_pass import IRPass + + +class NormalizationPass(IRPass): + """ + This pass splits basic blocks when there are multiple conditional predecessors. + The code generator expect a normalized CFG, that has the property that + each basic block has at most one conditional predecessor. + """ + + changes = 0 + + def _split_basic_block(self, bb: IRBasicBlock) -> None: + # Iterate over the predecessors of the basic block + for in_bb in list(bb.cfg_in): + jump_inst = in_bb.instructions[-1] + assert bb in in_bb.cfg_out + + # Handle static and dynamic branching + if jump_inst.opcode == "jnz": + self._split_for_static_branch(bb, in_bb) + elif jump_inst.opcode == "jmp" and isinstance(jump_inst.operands[0], IRVariable): + self._split_for_dynamic_branch(bb, in_bb) + else: + continue + + self.changes += 1 + + def _split_for_static_branch(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> None: + jump_inst = in_bb.instructions[-1] + for i, op in enumerate(jump_inst.operands): + if op == bb.label: + edge = i + break + else: + # none of the edges points to this bb + raise CompilerPanic("bad CFG") + + assert edge in (1, 2) # the arguments which can be labels + + split_bb = self._insert_split_basicblock(bb, in_bb) + + # Redirect the original conditional jump to the intermediary basic block + jump_inst.operands[edge] = split_bb.label + + def _split_for_dynamic_branch(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> None: + split_bb = self._insert_split_basicblock(bb, in_bb) + + # Update any affected labels in the data segment + # TODO: this DESTROYS the cfg! refactor so the translation of the + # selector table produces indirect jumps properly. + for inst in self.ctx.data_segment: + if inst.opcode == "db" and inst.operands[0] == bb.label: + inst.operands[0] = split_bb.label + + def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRBasicBlock: + # Create an intermediary basic block and append it + source = in_bb.label.value + target = bb.label.value + split_bb = IRBasicBlock(IRLabel(f"{target}_split_{source}"), self.ctx) + split_bb.append_instruction(IRInstruction("jmp", [bb.label])) + self.ctx.append_basic_block(split_bb) + + # Rewire the CFG + # TODO: this is cursed code, it is necessary instead of just running + # calculate_cfg() because split_for_dynamic_branch destroys the CFG! + # ideally, remove this rewiring and just re-run calculate_cfg(). + split_bb.add_cfg_in(in_bb) + split_bb.add_cfg_out(bb) + in_bb.remove_cfg_out(bb) + in_bb.add_cfg_out(split_bb) + bb.remove_cfg_in(in_bb) + bb.add_cfg_in(split_bb) + return split_bb + + def _run_pass(self, ctx: IRFunction) -> int: + self.ctx = ctx + self.changes = 0 + + for bb in ctx.basic_blocks: + if len(bb.cfg_in) > 1: + self._split_basic_block(bb) + + # Sanity check + assert ctx.normalized, "Normalization pass failed" + + return self.changes diff --git a/vyper/venom/stack_model.py b/vyper/venom/stack_model.py new file mode 100644 index 0000000000..66c62b74d2 --- /dev/null +++ b/vyper/venom/stack_model.py @@ -0,0 +1,100 @@ +from vyper.venom.basicblock import IROperand, IRVariable + + +class StackModel: + NOT_IN_STACK = object() + _stack: list[IROperand] + + def __init__(self): + self._stack = [] + + def copy(self): + new = StackModel() + new._stack = self._stack.copy() + return new + + @property + def height(self) -> int: + """ + Returns the height of the stack map. + """ + return len(self._stack) + + def push(self, op: IROperand) -> None: + """ + Pushes an operand onto the stack map. + """ + assert isinstance(op, IROperand), f"{type(op)}: {op}" + self._stack.append(op) + + def pop(self, num: int = 1) -> None: + del self._stack[len(self._stack) - num :] + + def get_depth(self, op: IROperand) -> int: + """ + Returns the depth of the first matching operand in the stack map. + If the operand is not in the stack map, returns NOT_IN_STACK. + """ + assert isinstance(op, IROperand), f"{type(op)}: {op}" + + for i, stack_op in enumerate(reversed(self._stack)): + if stack_op.value == op.value: + return -i + + return StackModel.NOT_IN_STACK # type: ignore + + def get_phi_depth(self, phi1: IRVariable, phi2: IRVariable) -> int: + """ + Returns the depth of the first matching phi variable in the stack map. + If the none of the phi operands are in the stack, returns NOT_IN_STACK. + Asserts that exactly one of phi1 and phi2 is found. + """ + assert isinstance(phi1, IRVariable) + assert isinstance(phi2, IRVariable) + + ret = StackModel.NOT_IN_STACK + for i, stack_item in enumerate(reversed(self._stack)): + if stack_item in (phi1, phi2): + assert ( + ret is StackModel.NOT_IN_STACK + ), f"phi argument is not unique! {phi1}, {phi2}, {self._stack}" + ret = -i + + return ret # type: ignore + + def peek(self, depth: int) -> IROperand: + """ + Returns the top of the stack map. + """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot peek non-in-stack depth" + return self._stack[depth - 1] + + def poke(self, depth: int, op: IROperand) -> None: + """ + Pokes an operand at the given depth in the stack map. + """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot poke non-in-stack depth" + assert depth <= 0, "Bad depth" + assert isinstance(op, IROperand), f"{type(op)}: {op}" + self._stack[depth - 1] = op + + def dup(self, depth: int) -> None: + """ + Duplicates the operand at the given depth in the stack map. + """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot dup non-existent operand" + assert depth <= 0, "Cannot dup positive depth" + self._stack.append(self.peek(depth)) + + def swap(self, depth: int) -> None: + """ + Swaps the operand at the given depth in the stack map with the top of the stack. + """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot swap non-existent operand" + assert depth < 0, "Cannot swap positive depth" + top = self._stack[-1] + self._stack[-1] = self._stack[depth - 1] + self._stack[depth - 1] = top + + def __repr__(self) -> str: + return f"" diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py new file mode 100644 index 0000000000..f6ec45440a --- /dev/null +++ b/vyper/venom/venom_to_assembly.py @@ -0,0 +1,461 @@ +from typing import Any + +from vyper.ir.compile_ir import PUSH, DataHeader, RuntimeHeader, optimize_assembly +from vyper.utils import MemoryPositions, OrderedSet +from vyper.venom.analysis import calculate_cfg, calculate_liveness, input_vars_from +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IRLiteral, + IROperand, + IRVariable, + MemType, +) +from vyper.venom.function import IRFunction +from vyper.venom.passes.normalization import NormalizationPass +from vyper.venom.stack_model import StackModel + +# instructions which map one-to-one from venom to EVM +_ONE_TO_ONE_INSTRUCTIONS = frozenset( + [ + "revert", + "coinbase", + "calldatasize", + "calldatacopy", + "calldataload", + "gas", + "gasprice", + "gaslimit", + "address", + "origin", + "number", + "extcodesize", + "extcodehash", + "returndatasize", + "returndatacopy", + "callvalue", + "selfbalance", + "sload", + "sstore", + "mload", + "mstore", + "timestamp", + "caller", + "selfdestruct", + "signextend", + "stop", + "shr", + "shl", + "and", + "xor", + "or", + "add", + "sub", + "mul", + "div", + "mod", + "exp", + "eq", + "iszero", + "lg", + "lt", + "slt", + "sgt", + "log0", + "log1", + "log2", + "log3", + "log4", + ] +) + + +# TODO: "assembly" gets into the recursion due to how the original +# IR was structured recursively in regards with the deploy instruction. +# There, recursing into the deploy instruction was by design, and +# made it easier to make the assembly generated "recursive" (i.e. +# instructions being lists of instructions). We don't have this restriction +# anymore, so we can probably refactor this to be iterative in coordination +# with the assembler. My suggestion is to let this be for now, and we can +# refactor it later when we are finished phasing out the old IR. +class VenomCompiler: + ctx: IRFunction + label_counter = 0 + visited_instructions: OrderedSet # {IRInstruction} + visited_basicblocks: OrderedSet # {IRBasicBlock} + + def __init__(self, ctx: IRFunction): + self.ctx = ctx + self.label_counter = 0 + self.visited_instructions = OrderedSet() + self.visited_basicblocks = OrderedSet() + + def generate_evm(self, no_optimize: bool = False) -> list[str]: + self.visited_instructions = OrderedSet() + self.visited_basicblocks = OrderedSet() + self.label_counter = 0 + + stack = StackModel() + asm: list[str] = [] + + # Before emitting the assembly, we need to make sure that the + # CFG is normalized. Calling calculate_cfg() will denormalize IR (reset) + # so it should not be called after calling NormalizationPass.run_pass(). + # Liveness is then computed for the normalized IR, and we can proceed to + # assembly generation. + # This is a side-effect of how dynamic jumps are temporarily being used + # to support the O(1) dispatcher. -> look into calculate_cfg() + calculate_cfg(self.ctx) + NormalizationPass.run_pass(self.ctx) + calculate_liveness(self.ctx) + + assert self.ctx.normalized, "Non-normalized CFG!" + + self._generate_evm_for_basicblock_r(asm, self.ctx.basic_blocks[0], stack) + + # Append postambles + revert_postamble = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] + runtime = None + if isinstance(asm[-1], list) and isinstance(asm[-1][0], RuntimeHeader): + runtime = asm.pop() + + asm.extend(revert_postamble) + if runtime: + runtime.extend(revert_postamble) + asm.append(runtime) + + # Append data segment + data_segments: dict[Any, list[Any]] = dict() + for inst in self.ctx.data_segment: + if inst.opcode == "dbname": + label = inst.operands[0].value + data_segments[label] = [DataHeader(f"_sym_{label}")] + elif inst.opcode == "db": + data_segments[label].append(f"_sym_{inst.operands[0].value}") + + extent_point = asm if not isinstance(asm[-1], list) else asm[-1] + extent_point.extend([data_segments[label] for label in data_segments]) # type: ignore + + if no_optimize is False: + optimize_assembly(asm) + + return asm + + def _stack_reorder( + self, assembly: list, stack: StackModel, _stack_ops: OrderedSet[IRVariable] + ) -> None: + # make a list so we can index it + stack_ops = [x for x in _stack_ops.keys()] + stack_ops_count = len(_stack_ops) + + for i in range(stack_ops_count): + op = stack_ops[i] + final_stack_depth = -(stack_ops_count - i - 1) + depth = stack.get_depth(op) # type: ignore + + if depth == final_stack_depth: + continue + + self.swap(assembly, stack, depth) + self.swap(assembly, stack, final_stack_depth) + + def _emit_input_operands( + self, assembly: list, inst: IRInstruction, ops: list[IROperand], stack: StackModel + ) -> None: + # PRE: we already have all the items on the stack that have + # been scheduled to be killed. now it's just a matter of emitting + # SWAPs, DUPs and PUSHes until we match the `ops` argument + + # dumb heuristic: if the top of stack is not wanted here, swap + # it with something that is wanted + if ops and stack.height > 0 and stack.peek(0) not in ops: + for op in ops: + if isinstance(op, IRVariable) and op not in inst.dup_requirements: + self.swap_op(assembly, stack, op) + break + + emitted_ops = OrderedSet[IROperand]() + for op in ops: + if isinstance(op, IRLabel): + # invoke emits the actual instruction itself so we don't need to emit it here + # but we need to add it to the stack map + if inst.opcode != "invoke": + assembly.append(f"_sym_{op.value}") + stack.push(op) + continue + + if isinstance(op, IRLiteral): + assembly.extend([*PUSH(op.value)]) + stack.push(op) + continue + + if op in inst.dup_requirements: + self.dup_op(assembly, stack, op) + + if op in emitted_ops: + self.dup_op(assembly, stack, op) + + # REVIEW: this seems like it can be reordered across volatile + # boundaries (which includes memory fences). maybe just + # remove it entirely at this point + if isinstance(op, IRVariable) and op.mem_type == MemType.MEMORY: + assembly.extend([*PUSH(op.mem_addr)]) + assembly.append("MLOAD") + + emitted_ops.add(op) + + def _generate_evm_for_basicblock_r( + self, asm: list, basicblock: IRBasicBlock, stack: StackModel + ) -> None: + if basicblock in self.visited_basicblocks: + return + self.visited_basicblocks.add(basicblock) + + # assembly entry point into the block + asm.append(f"_sym_{basicblock.label}") + asm.append("JUMPDEST") + + self.clean_stack_from_cfg_in(asm, basicblock, stack) + + for inst in basicblock.instructions: + asm = self._generate_evm_for_instruction(asm, inst, stack) + + for bb in basicblock.cfg_out: + self._generate_evm_for_basicblock_r(asm, bb, stack.copy()) + + # pop values from stack at entry to bb + # note this produces the same result(!) no matter which basic block + # we enter from in the CFG. + def clean_stack_from_cfg_in( + self, asm: list, basicblock: IRBasicBlock, stack: StackModel + ) -> None: + if len(basicblock.cfg_in) == 0: + return + + to_pop = OrderedSet[IRVariable]() + for in_bb in basicblock.cfg_in: + # inputs is the input variables we need from in_bb + inputs = input_vars_from(in_bb, basicblock) + + # layout is the output stack layout for in_bb (which works + # for all possible cfg_outs from the in_bb). + layout = in_bb.out_vars + + # pop all the stack items which in_bb produced which we don't need. + to_pop |= layout.difference(inputs) + + for var in to_pop: + depth = stack.get_depth(var) + # don't pop phantom phi inputs + if depth is StackModel.NOT_IN_STACK: + continue + + if depth != 0: + stack.swap(depth) + self.pop(asm, stack) + + def _generate_evm_for_instruction( + self, assembly: list, inst: IRInstruction, stack: StackModel + ) -> list[str]: + opcode = inst.opcode + + # + # generate EVM for op + # + + # Step 1: Apply instruction special stack manipulations + + if opcode in ["jmp", "jnz", "invoke"]: + operands = inst.get_non_label_operands() + elif opcode == "alloca": + operands = inst.operands[1:2] + elif opcode == "iload": + operands = [] + elif opcode == "istore": + operands = inst.operands[0:1] + else: + operands = inst.operands + + if opcode == "phi": + ret = inst.get_outputs()[0] + phi1, phi2 = inst.get_inputs() + depth = stack.get_phi_depth(phi1, phi2) + # collapse the arguments to the phi node in the stack. + # example, for `%56 = %label1 %13 %label2 %14`, we will + # find an instance of %13 *or* %14 in the stack and replace it with %56. + to_be_replaced = stack.peek(depth) + if to_be_replaced in inst.dup_requirements: + # %13/%14 is still live(!), so we make a copy of it + self.dup(assembly, stack, depth) + stack.poke(0, ret) + else: + stack.poke(depth, ret) + return assembly + + # Step 2: Emit instruction's input operands + self._emit_input_operands(assembly, inst, operands, stack) + + # Step 3: Reorder stack + if opcode in ["jnz", "jmp"]: + # prepare stack for jump into another basic block + assert inst.parent and isinstance(inst.parent.cfg_out, OrderedSet) + b = next(iter(inst.parent.cfg_out)) + target_stack = input_vars_from(inst.parent, b) + # TODO optimize stack reordering at entry and exit from basic blocks + self._stack_reorder(assembly, stack, target_stack) + + # final step to get the inputs to this instruction ordered + # correctly on the stack + self._stack_reorder(assembly, stack, OrderedSet(operands)) + + # some instructions (i.e. invoke) need to do stack manipulations + # with the stack model containing the return value(s), so we fiddle + # with the stack model beforehand. + + # Step 4: Push instruction's return value to stack + stack.pop(len(operands)) + if inst.output is not None: + stack.push(inst.output) + + # Step 5: Emit the EVM instruction(s) + if opcode in _ONE_TO_ONE_INSTRUCTIONS: + assembly.append(opcode.upper()) + elif opcode == "alloca": + pass + elif opcode == "param": + pass + elif opcode == "store": + pass + elif opcode == "dbname": + pass + elif opcode in ["codecopy", "dloadbytes"]: + assembly.append("CODECOPY") + elif opcode == "jnz": + # jump if not zero + if_nonzero_label = inst.operands[1] + if_zero_label = inst.operands[2] + assembly.append(f"_sym_{if_nonzero_label.value}") + assembly.append("JUMPI") + + # make sure the if_zero_label will be optimized out + # assert if_zero_label == next(iter(inst.parent.cfg_out)).label + + assembly.append(f"_sym_{if_zero_label.value}") + assembly.append("JUMP") + + elif opcode == "jmp": + if isinstance(inst.operands[0], IRLabel): + assembly.append(f"_sym_{inst.operands[0].value}") + assembly.append("JUMP") + else: + assembly.append("JUMP") + elif opcode == "gt": + assembly.append("GT") + elif opcode == "lt": + assembly.append("LT") + elif opcode == "invoke": + target = inst.operands[0] + assert isinstance(target, IRLabel), "invoke target must be a label" + assembly.extend( + [ + f"_sym_label_ret_{self.label_counter}", + f"_sym_{target.value}", + "JUMP", + f"_sym_label_ret_{self.label_counter}", + "JUMPDEST", + ] + ) + self.label_counter += 1 + if stack.height > 0 and stack.peek(0) in inst.dup_requirements: + self.pop(assembly, stack) + elif opcode == "call": + assembly.append("CALL") + elif opcode == "staticcall": + assembly.append("STATICCALL") + elif opcode == "ret": + assembly.append("JUMP") + elif opcode == "return": + assembly.append("RETURN") + elif opcode == "phi": + pass + elif opcode == "sha3": + assembly.append("SHA3") + elif opcode == "sha3_64": + assembly.extend( + [ + *PUSH(MemoryPositions.FREE_VAR_SPACE2), + "MSTORE", + *PUSH(MemoryPositions.FREE_VAR_SPACE), + "MSTORE", + *PUSH(64), + *PUSH(MemoryPositions.FREE_VAR_SPACE), + "SHA3", + ] + ) + elif opcode == "ceil32": + assembly.extend([*PUSH(31), "ADD", *PUSH(31), "NOT", "AND"]) + elif opcode == "assert": + assembly.extend(["ISZERO", "_sym___revert", "JUMPI"]) + elif opcode == "deploy": + memsize = inst.operands[0].value + padding = inst.operands[2].value + # TODO: fix this by removing deploy opcode altogether me move emition to ir translation + while assembly[-1] != "JUMPDEST": + assembly.pop() + assembly.extend( + ["_sym_subcode_size", "_sym_runtime_begin", "_mem_deploy_start", "CODECOPY"] + ) + assembly.extend(["_OFST", "_sym_subcode_size", padding]) # stack: len + assembly.extend(["_mem_deploy_start"]) # stack: len mem_ofst + assembly.extend(["RETURN"]) + assembly.append([RuntimeHeader("_sym_runtime_begin", memsize, padding)]) # type: ignore + assembly = assembly[-1] + elif opcode == "iload": + loc = inst.operands[0].value + assembly.extend(["_OFST", "_mem_deploy_end", loc, "MLOAD"]) + elif opcode == "istore": + loc = inst.operands[1].value + assembly.extend(["_OFST", "_mem_deploy_end", loc, "MSTORE"]) + else: + raise Exception(f"Unknown opcode: {opcode}") + + # Step 6: Emit instructions output operands (if any) + if inst.output is not None: + assert isinstance(inst.output, IRVariable), "Return value must be a variable" + if inst.output.mem_type == MemType.MEMORY: + assembly.extend([*PUSH(inst.output.mem_addr)]) + + return assembly + + def pop(self, assembly, stack, num=1): + stack.pop(num) + assembly.extend(["POP"] * num) + + def swap(self, assembly, stack, depth): + if depth == 0: + return + stack.swap(depth) + assembly.append(_evm_swap_for(depth)) + + def dup(self, assembly, stack, depth): + stack.dup(depth) + assembly.append(_evm_dup_for(depth)) + + def swap_op(self, assembly, stack, op): + self.swap(assembly, stack, stack.get_depth(op)) + + def dup_op(self, assembly, stack, op): + self.dup(assembly, stack, stack.get_depth(op)) + + +def _evm_swap_for(depth: int) -> str: + swap_idx = -depth + assert 1 <= swap_idx <= 16, "Unsupported swap depth" + return f"SWAP{swap_idx}" + + +def _evm_dup_for(depth: int) -> str: + dup_idx = 1 - depth + assert 1 <= dup_idx <= 16, "Unsupported dup depth" + return f"DUP{dup_idx}"