diff --git a/tests/compiler/ir/test_optimize_ir.py b/tests/compiler/ir/test_optimize_ir.py index 1466166501..cb46ba238d 100644 --- a/tests/compiler/ir/test_optimize_ir.py +++ b/tests/compiler/ir/test_optimize_ir.py @@ -1,9 +1,13 @@ import pytest from vyper.codegen.ir_node import IRnode +from vyper.evm.opcodes import EVM_VERSIONS, anchor_evm_version from vyper.exceptions import StaticAssertionException from vyper.ir import optimizer +POST_CANCUN = {k: v for k, v in EVM_VERSIONS.items() if v >= EVM_VERSIONS["cancun"]} + + optimize_list = [ (["eq", 1, 2], [0]), (["lt", 1, 2], [1]), @@ -272,3 +276,106 @@ def test_operator_set_values(): assert optimizer.COMPARISON_OPS == {"lt", "gt", "le", "ge", "slt", "sgt", "sle", "sge"} assert optimizer.STRICT_COMPARISON_OPS == {"lt", "gt", "slt", "sgt"} assert optimizer.UNSTRICT_COMPARISON_OPS == {"le", "ge", "sle", "sge"} + + +mload_merge_list = [ + # copy "backward" with no overlap between src and dst buffers, + # OK to become mcopy + ( + ["seq", ["mstore", 32, ["mload", 128]], ["mstore", 64, ["mload", 160]]], + ["mcopy", 32, 128, 64], + ), + # copy with overlap "backwards", OK to become mcopy + (["seq", ["mstore", 32, ["mload", 64]], ["mstore", 64, ["mload", 96]]], ["mcopy", 32, 64, 64]), + # "stationary" overlap (i.e. a no-op mcopy), OK to become mcopy + (["seq", ["mstore", 32, ["mload", 32]], ["mstore", 64, ["mload", 64]]], ["mcopy", 32, 32, 64]), + # copy "forward" with no overlap, OK to become mcopy + (["seq", ["mstore", 64, ["mload", 0]], ["mstore", 96, ["mload", 32]]], ["mcopy", 64, 0, 64]), + # copy "forwards" with overlap by one word, must NOT become mcopy + (["seq", ["mstore", 64, ["mload", 32]], ["mstore", 96, ["mload", 64]]], None), + # check "forward" overlap by one byte, must NOT become mcopy + (["seq", ["mstore", 64, ["mload", 1]], ["mstore", 96, ["mload", 33]]], None), + # check "forward" overlap by one byte again, must NOT become mcopy + (["seq", ["mstore", 63, ["mload", 0]], ["mstore", 95, ["mload", 32]]], None), + # copy 3 words with partial overlap "forwards", partially becomes mcopy + # (2 words are mcopied and 1 word is mload/mstored + ( + [ + "seq", + ["mstore", 96, ["mload", 32]], + ["mstore", 128, ["mload", 64]], + ["mstore", 160, ["mload", 96]], + ], + ["seq", ["mcopy", 96, 32, 64], ["mstore", 160, ["mload", 96]]], + ), + # copy 4 words with partial overlap "forwards", becomes 2 mcopies of 2 words each + ( + [ + "seq", + ["mstore", 96, ["mload", 32]], + ["mstore", 128, ["mload", 64]], + ["mstore", 160, ["mload", 96]], + ["mstore", 192, ["mload", 128]], + ], + ["seq", ["mcopy", 96, 32, 64], ["mcopy", 160, 96, 64]], + ), + # copy 4 words with 1 byte of overlap, must NOT become mcopy + ( + [ + "seq", + ["mstore", 96, ["mload", 33]], + ["mstore", 128, ["mload", 65]], + ["mstore", 160, ["mload", 97]], + ["mstore", 192, ["mload", 129]], + ], + None, + ), + # Ensure only sequential mstore + mload sequences are optimized + ( + [ + "seq", + ["mstore", 0, ["mload", 32]], + ["sstore", 0, ["calldataload", 4]], + ["mstore", 32, ["mload", 64]], + ], + None, + ), + # not-word aligned optimizations (not overlap) + (["seq", ["mstore", 0, ["mload", 1]], ["mstore", 32, ["mload", 33]]], ["mcopy", 0, 1, 64]), + # not-word aligned optimizations (overlap) + (["seq", ["mstore", 1, ["mload", 0]], ["mstore", 33, ["mload", 32]]], None), + # not-word aligned optimizations (overlap and not-overlap) + ( + [ + "seq", + ["mstore", 0, ["mload", 1]], + ["mstore", 32, ["mload", 33]], + ["mstore", 1, ["mload", 0]], + ["mstore", 33, ["mload", 32]], + ], + ["seq", ["mcopy", 0, 1, 64], ["mstore", 1, ["mload", 0]], ["mstore", 33, ["mload", 32]]], + ), + # overflow test + ( + [ + "seq", + ["mstore", 2**256 - 1 - 31 - 32, ["mload", 0]], + ["mstore", 2**256 - 1 - 31, ["mload", 32]], + ], + ["mcopy", 2**256 - 1 - 31 - 32, 0, 64], + ), +] + + +@pytest.mark.parametrize("ir", mload_merge_list) +@pytest.mark.parametrize("evm_version", list(POST_CANCUN.keys())) +def test_mload_merge(ir, evm_version): + with anchor_evm_version(evm_version): + optimized = optimizer.optimize(IRnode.from_list(ir[0])) + if ir[1] is None: + # no-op, assert optimizer does nothing + expected = IRnode.from_list(ir[0]) + else: + expected = IRnode.from_list(ir[1]) + + assert optimized == expected diff --git a/tests/parser/features/test_assignment.py b/tests/parser/features/test_assignment.py index e550f60541..35b008a8ba 100644 --- a/tests/parser/features/test_assignment.py +++ b/tests/parser/features/test_assignment.py @@ -442,3 +442,63 @@ def bug(p: Point) -> Point: """ c = get_contract(code) assert c.bug((1, 2)) == (2, 1) + + +mload_merge_codes = [ + ( + """ +@external +def foo() -> uint256[4]: + # copy "backwards" + xs: uint256[4] = [1, 2, 3, 4] + +# dst < src + xs[0] = xs[1] + xs[1] = xs[2] + xs[2] = xs[3] + + return xs + """, + [2, 3, 4, 4], + ), + ( + """ +@external +def foo() -> uint256[4]: + # copy "forwards" + xs: uint256[4] = [1, 2, 3, 4] + +# src < dst + xs[1] = xs[0] + xs[2] = xs[1] + xs[3] = xs[2] + + return xs + """, + [1, 1, 1, 1], + ), + ( + """ +@external +def foo() -> uint256[5]: + # partial "forward" copy + xs: uint256[5] = [1, 2, 3, 4, 5] + +# src < dst + xs[2] = xs[0] + xs[3] = xs[1] + xs[4] = xs[2] + + return xs + """, + [1, 2, 1, 2, 1], + ), +] + + +# functional test that mload merging does not occur when source and dest +# buffers overlap. (note: mload merging only applies after cancun) +@pytest.mark.parametrize("code,expected_result", mload_merge_codes) +def test_mcopy_overlap(get_contract, code, expected_result): + c = get_contract(code) + assert c.foo() == expected_result diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index 08c2168381..8df4bbac2d 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -662,10 +662,10 @@ def _rewrite_mstore_dload(argz): def _merge_mload(argz): if not version_check(begin="cancun"): return False - return _merge_load(argz, "mload", "mcopy") + return _merge_load(argz, "mload", "mcopy", allow_overlap=False) -def _merge_load(argz, _LOAD, _COPY): +def _merge_load(argz, _LOAD, _COPY, allow_overlap=True): # look for sequential operations copying from X to Y # and merge them into a single copy operation changed = False @@ -689,9 +689,14 @@ def _merge_load(argz, _LOAD, _COPY): initial_dst_offset = dst_offset initial_src_offset = src_offset idx = i + + # dst and src overlap, discontinue the optimization + has_overlap = initial_src_offset < initial_dst_offset < src_offset + 32 + if ( initial_dst_offset + total_length == dst_offset and initial_src_offset + total_length == src_offset + and (allow_overlap or not has_overlap) ): mstore_nodes.append(ir_node) total_length += 32