From 1f374d41fa2eecbdcd97574bc7d0965c88ed5742 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 9 Jul 2021 16:27:59 -0700 Subject: [PATCH 1/2] Reduce code size by sharing code for clamps All clamps have the following code (assume condition is on the stack) ISZERO _revert JUMPI # if not condition goto _revert PUSH1 0 DUP REVERT _revert JUMPDEST Since clamps are super common this code gets repeated. This reduces code size by adding the PUSH1 0 DUP REVERT to each program's postamble, and then clamps are simplified to the following (assume condition is on the stack) _revert JUMPI # if not condition goto _revert -- cool now 5 opcodes are skipped This commit also optimizes the following common "truthy" sequences: ISZERO ISZERO ISZERO -> ISZERO ISZERI ISZERO _jumpdest JUMPI -> _jumpdest JUMPI --- vyper/compile_lll.py | 203 +++++++++++++++++++++++++++++-------------- 1 file changed, 136 insertions(+), 67 deletions(-) diff --git a/vyper/compile_lll.py b/vyper/compile_lll.py index 41ed685ae6..8c79f91141 100644 --- a/vyper/compile_lll.py +++ b/vyper/compile_lll.py @@ -47,16 +47,31 @@ def is_symbol(i): return isinstance(i, str) and i[:5] == "_sym_" -def get_revert(mem_start=None, mem_len=None): - o = [] - end_symbol = mksymbol() - o.extend([end_symbol, "JUMPI"]) - if (mem_start, mem_len) == (None, None): - o.extend(["PUSH1", 0, "DUP1", "REVERT"]) - else: - o.extend([mem_len, mem_start, "REVERT"]) - o.extend([end_symbol, "JUMPDEST"]) - return o +def _assert_false(): + # use a shared failure block for common case of assert(x). + # in the future we might want to change the code + # at _sym_revert0 to: INVALID + return ["_sym_revert0", "JUMPI"] + + +def _add_postambles(asm_ops): + to_append = [] + + if "_sym_revert0" in asm_ops: + # shared failure block + to_append.extend(["_sym_revert0", "JUMPDEST", "PUSH1", 0, "DUP1", "REVERT"]) + + if len(to_append) > 0: + # for some reason there might not be a STOP at the end of asm_ops. + # (generally vyper programs will have it but raw LLL might not). + asm_ops.append("STOP") + asm_ops.extend(to_append) + + # need to do this recursively since every sublist is basically + # treated as its own program (there are no global labels.) + for t in asm_ops: + if isinstance(t, list): + _add_postambles(t) class instruction(str): @@ -85,9 +100,16 @@ def apply_line_no_wrapper(*args, **kwargs): return apply_line_no_wrapper +def compile_to_assembly(code): + res = _compile_to_assembly(code) + + _add_postambles(res) + return res + + # Compiles LLL to assembly @apply_line_numbers -def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=None, height=0): +def _compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=None, height=0): if withargs is None: withargs = {} if not isinstance(withargs, dict): @@ -102,7 +124,7 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No if isinstance(code.value, str) and code.value.upper() in get_opcodes(): o = [] for i, c in enumerate(code.args[::-1]): - o.extend(compile_to_assembly(c, withargs, existing_labels, break_dest, height + i)) + o.extend(_compile_to_assembly(c, withargs, existing_labels, break_dest, height + i)) o.append(code.value.upper()) return o # Numbers @@ -124,7 +146,7 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No raise Exception("Set expects two arguments, the first being a stack variable") if height - withargs[code.args[0].value] > 16: raise Exception("With statement too deep") - return compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height) + [ + return _compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height) + [ "SWAP" + str(height - withargs[code.args[0].value]), "POP", ] @@ -136,7 +158,7 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No return ["_sym_codeend"] # Calldataload equivalent for code elif code.value == "codeload": - return compile_to_assembly( + return _compile_to_assembly( LLLnode.from_list( [ "seq", @@ -152,22 +174,22 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No # If statements (2 arguments, ie. if x: y) elif code.value in ("if", "if_unchecked") and len(code.args) == 2: o = [] - o.extend(compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height)) + o.extend(_compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height)) end_symbol = mksymbol() o.extend(["ISZERO", end_symbol, "JUMPI"]) - o.extend(compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) + o.extend(_compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) o.extend([end_symbol, "JUMPDEST"]) return o # If statements (3 arguments, ie. if x: y, else: z) elif code.value == "if" and len(code.args) == 3: o = [] - o.extend(compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height)) + o.extend(_compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height)) mid_symbol = mksymbol() end_symbol = mksymbol() o.extend(["ISZERO", mid_symbol, "JUMPI"]) - o.extend(compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) + o.extend(_compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) o.extend([end_symbol, "JUMP", mid_symbol, "JUMPDEST"]) - o.extend(compile_to_assembly(code.args[2], withargs, existing_labels, break_dest, height)) + o.extend(_compile_to_assembly(code.args[2], withargs, existing_labels, break_dest, height)) o.extend([end_symbol, "JUMPDEST"]) return o # Repeat statements (compiled from for loops) @@ -176,16 +198,16 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No o = [] loops = num_to_bytearray(code.args[2].value) start, continue_dest, end = mksymbol(), mksymbol(), mksymbol() - o.extend(compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height)) + o.extend(_compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height)) o.extend( - compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height + 1,) + _compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height + 1,) ) o.extend(["PUSH" + str(len(loops))] + loops) # stack: memloc, startvalue, rounds o.extend(["DUP2", "DUP4", "MSTORE", "ADD", start, "JUMPDEST"]) # stack: memloc, exit_index o.extend( - compile_to_assembly( + _compile_to_assembly( code.args[3], withargs, existing_labels, @@ -232,11 +254,11 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No # With statements elif code.value == "with": o = [] - o.extend(compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) + o.extend(_compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) old = withargs.get(code.args[0].value, None) withargs[code.args[0].value] = height o.extend( - compile_to_assembly(code.args[2], withargs, existing_labels, break_dest, height + 1,) + _compile_to_assembly(code.args[2], withargs, existing_labels, break_dest, height + 1,) ) if code.args[2].valency: o.extend(["SWAP1", "POP"]) @@ -253,17 +275,29 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No begincode = mksymbol() endcode = mksymbol() o.extend([endcode, "JUMP", begincode, "BLANK"]) - # The `append(...)` call here is intentional - o.append(compile_to_assembly(code.args[0], {}, existing_labels, None, 0)) + + lll = _compile_to_assembly(code.args[0], {}, existing_labels, None, 0) + + # `append(...)` call here is intentional. + # each sublist is essentially its own program with its + # own symbols. + # in the later step when the "lll" block compiled to EVM, + # compile_to_evm has logic to resolve symbols in "lll" to + # position from start of runtime-code (instead of position + # from start of bytecode). + o.append(lll) + o.extend([endcode, "JUMPDEST", begincode, endcode, "SUB", begincode]) - o.extend(compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) + o.extend(_compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) + + # COPY the code to memory for deploy o.extend(["CODECOPY", begincode, endcode, "SUB"]) return o # Seq (used to piece together multiple statements) elif code.value == "seq": o = [] for arg in code.args: - o.extend(compile_to_assembly(arg, withargs, existing_labels, break_dest, height)) + o.extend(_compile_to_assembly(arg, withargs, existing_labels, break_dest, height)) if arg.valency == 1 and arg != code.args[-1]: o.append("POP") return o @@ -271,20 +305,21 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No elif code.value == "seq_unchecked": o = [] for arg in code.args: - o.extend(compile_to_assembly(arg, withargs, existing_labels, break_dest, height)) + o.extend(_compile_to_assembly(arg, withargs, existing_labels, break_dest, height)) # if arg.valency == 1 and arg != code.args[-1]: # o.append('POP') return o # Assure (if false, invalid opcode) elif code.value == "assert_unreachable": - o = compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) + o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) end_symbol = mksymbol() o.extend([end_symbol, "JUMPI", "INVALID", end_symbol, "JUMPDEST"]) return o # Assert (if false, exit) elif code.value == "assert": - o = compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) - o.extend(get_revert()) + o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) + o.extend(["ISZERO"]) + o.extend(_assert_false()) return o # Unsigned/signed clamp, check less-than elif code.value in CLAMP_OP_NAMES: @@ -301,63 +336,63 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No ) ) if is_free_of_clamp_errors: - return compile_to_assembly( + return _compile_to_assembly( code.args[0], withargs, existing_labels, break_dest, height, ) else: raise Exception( f"Invalid {code.value} with values {code.args[0]} and {code.args[1]}" ) - o = compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) + o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) o.extend( - compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height + 1,) + _compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height + 1,) ) o.extend(["DUP2"]) # Stack: num num bound if code.value == "uclamplt": - o.extend(["LT"]) + o.extend(["LT", "ISZERO"]) elif code.value == "clamplt": - o.extend(["SLT"]) + o.extend(["SLT", "ISZERO"]) elif code.value == "uclample": - o.extend(["GT", "ISZERO"]) + o.extend(["GT"]) elif code.value == "clample": - o.extend(["SGT", "ISZERO"]) + o.extend(["SGT"]) elif code.value == "uclampgt": - o.extend(["GT"]) + o.extend(["GT", "ISZERO"]) elif code.value == "clampgt": - o.extend(["SGT"]) + o.extend(["SGT", "ISZERO"]) elif code.value == "uclampge": - o.extend(["LT", "ISZERO"]) + o.extend(["LT"]) elif code.value == "clampge": - o.extend(["SLT", "ISZERO"]) - o.extend(get_revert()) + o.extend(["SLT"]) + o.extend(_assert_false()) return o # Signed clamp, check against upper and lower bounds elif code.value in ("clamp", "uclamp"): comp1 = "SGT" if code.value == "clamp" else "GT" comp2 = "SLT" if code.value == "clamp" else "LT" - o = compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) + o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) o.extend( - compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height + 1,) + _compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height + 1,) ) o.extend(["DUP1"]) o.extend( - compile_to_assembly(code.args[2], withargs, existing_labels, break_dest, height + 3,) + _compile_to_assembly(code.args[2], withargs, existing_labels, break_dest, height + 3,) ) - o.extend(["SWAP1", comp1, "ISZERO"]) - o.extend(get_revert()) - o.extend(["DUP1", "SWAP2", "SWAP1", comp2, "ISZERO"]) - o.extend(get_revert()) + o.extend(["SWAP1", comp1]) + o.extend(_assert_false()) + o.extend(["DUP1", "SWAP2", "SWAP1", comp2]) + o.extend(_assert_false()) return o # Checks that a value is nonzero elif code.value == "clamp_nonzero": - o = compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) - o.extend(["DUP1"]) - o.extend(get_revert()) + o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) + o.extend(["DUP1", "ISZERO"]) + o.extend(_assert_false()) return o # SHA3 a single value elif code.value == "sha3_32": - o = compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) + o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) o.extend( [ "PUSH1", @@ -373,8 +408,8 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No return o # SHA3 a 64 byte value elif code.value == "sha3_64": - o = compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) - o.extend(compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) + o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) + o.extend(_compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) o.extend( [ "PUSH1", @@ -393,7 +428,7 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No return o # <= operator elif code.value == "le": - return compile_to_assembly( + return _compile_to_assembly( LLLnode.from_list(["iszero", ["gt", code.args[0], code.args[1]]]), withargs, existing_labels, @@ -402,7 +437,7 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No ) # >= operator elif code.value == "ge": - return compile_to_assembly( + return _compile_to_assembly( LLLnode.from_list(["iszero", ["lt", code.args[0], code.args[1]]]), withargs, existing_labels, @@ -411,7 +446,7 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No ) # <= operator elif code.value == "sle": - return compile_to_assembly( + return _compile_to_assembly( LLLnode.from_list(["iszero", ["sgt", code.args[0], code.args[1]]]), withargs, existing_labels, @@ -420,7 +455,7 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No ) # >= operator elif code.value == "sge": - return compile_to_assembly( + return _compile_to_assembly( LLLnode.from_list(["iszero", ["slt", code.args[0], code.args[1]]]), withargs, existing_labels, @@ -429,7 +464,7 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No ) # != operator elif code.value == "ne": - return compile_to_assembly( + return _compile_to_assembly( LLLnode.from_list(["iszero", ["eq", code.args[0], code.args[1]]]), withargs, existing_labels, @@ -438,7 +473,7 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No ) # e.g. 95 -> 96, 96 -> 96, 97 -> 128 elif code.value == "ceil32": - return compile_to_assembly( + return _compile_to_assembly( LLLnode.from_list( [ "with", @@ -500,9 +535,7 @@ def note_breakpoint(line_number_map, item, pos): line_number_map["breakpoints"].add(item.lineno + 1) -# Assembles assembly into EVM -def assembly_to_evm(assembly, start_pos=0): - +def _prune_unreachable_code(assembly): # In converting LLL to assembly we sometimes end up with unreachable # instructions - POPing to clear the stack or STOPing execution at the # end of a function that has already returned or reverted. This should @@ -517,6 +550,8 @@ def assembly_to_evm(assembly, start_pos=0): else: i += 1 + +def _merge_jumpdests(assembly): # When a nested subroutine finishes and is the final action within it's # parent subroutine, we end up with multiple simultaneous JUMPDEST # instructions that can be merged to reduce the bytecode size. @@ -530,16 +565,50 @@ def assembly_to_evm(assembly, start_pos=0): continue i += 1 + +def _merge_iszero(assembly): + i = 0 + while i < len(assembly) - 2: + if assembly[i : i + 3] == ["ISZERO", "ISZERO", "ISZERO"]: # noqa: E203 + del assembly[i : i + 2] # noqa: E203 + else: + i += 1 + i = 0 + while i < len(assembly) - 3: + # ISZERO ISZERO could map truthy to 1, + # but it could also just be a no-op before JUMPI. + if ( + assembly[i : i + 2] == ["ISZERO", "ISZERO"] # noqa: E203 + and is_symbol(assembly[i + 2]) + and assembly[i + 3] == "JUMPI" + ): + del assembly[i : i + 2] # noqa: E203 + else: + i += 1 + + +# Assembles assembly into EVM +def assembly_to_evm(assembly, start_pos=0): + _prune_unreachable_code(assembly) + + _merge_iszero(assembly) + + _merge_jumpdests(assembly) + line_number_map = { "breakpoints": set(), "pc_breakpoints": set(), "pc_jump_map": {0: "-"}, "pc_pos_map": {}, } + posmap = {} sub_assemblies = [] codes = [] pos = start_pos + + # go through the code, resolving symbolic locations + # (i.e. JUMPDEST locations) to actual code locations for i, item in enumerate(assembly): note_line_num(line_number_map, item, pos) if item == "DEBUG": @@ -600,7 +669,7 @@ def assembly_to_evm(assembly, start_pos=0): o += codes[j] break else: - # Should never reach because, assembly is create in compile_to_assembly. + # Should never reach because, assembly is create in _compile_to_assembly. raise Exception("Weird symbol in assembly: " + str(item)) # pragma: no cover assert len(o) == pos - start_pos From 82c564195579e6ae56bd598b413f6b36d5c05d34 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 18 Jul 2021 16:21:39 -0700 Subject: [PATCH 2/2] Ensure line numbers are applied to asm --- vyper/compile_lll.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vyper/compile_lll.py b/vyper/compile_lll.py index 8c79f91141..fd07371667 100644 --- a/vyper/compile_lll.py +++ b/vyper/compile_lll.py @@ -38,7 +38,7 @@ def mksymbol(): def mkdebug(pc_debugger, pos): - i = instruction("DEBUG", pos) + i = Instruction("DEBUG", pos) i.pc_debugger = pc_debugger return [i] @@ -74,7 +74,7 @@ def _add_postambles(asm_ops): _add_postambles(t) -class instruction(str): +class Instruction(str): def __new__(cls, sstr, *args, **kwargs): return super().__new__(cls, sstr) @@ -92,7 +92,7 @@ def apply_line_no_wrapper(*args, **kwargs): code = args[0] ret = func(*args, **kwargs) new_ret = [ - instruction(i, code.pos) if isinstance(i, str) and not isinstance(i, instruction) else i + Instruction(i, code.pos) if isinstance(i, str) and not isinstance(i, Instruction) else i for i in ret ] return new_ret @@ -100,6 +100,7 @@ def apply_line_no_wrapper(*args, **kwargs): return apply_line_no_wrapper +@apply_line_numbers def compile_to_assembly(code): res = _compile_to_assembly(code) @@ -514,7 +515,7 @@ def _compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=N def note_line_num(line_number_map, item, pos): # Record line number attached to pos. - if isinstance(item, instruction): + if isinstance(item, Instruction): if item.lineno is not None: offsets = (item.lineno, item.col_offset, item.end_lineno, item.end_col_offset) else: