Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: block mload merging when src and dst overlap #3635

Merged
merged 11 commits into from
Oct 3, 2023
73 changes: 73 additions & 0 deletions tests/compiler/ir/test_optimize_ir.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import pytest

from vyper.codegen.ir_node import IRnode
from vyper.evm.opcodes import EVM_VERSIONS, anchor_evm_version
from vyper.exceptions import StaticAssertionException
from vyper.ir import optimizer

POST_CANCUN = {k: v for k, v in EVM_VERSIONS.items() if v >= EVM_VERSIONS["cancun"]}


optimize_list = [
(["eq", 1, 2], [0]),
(["lt", 1, 2], [1]),
Expand Down Expand Up @@ -272,3 +276,72 @@ def test_operator_set_values():
assert optimizer.COMPARISON_OPS == {"lt", "gt", "le", "ge", "slt", "sgt", "sle", "sge"}
assert optimizer.STRICT_COMPARISON_OPS == {"lt", "gt", "slt", "sgt"}
assert optimizer.UNSTRICT_COMPARISON_OPS == {"le", "ge", "sle", "sge"}


mload_merge_list = [
# copy "backward" with no overlap between src and dst buffers,
# OK to become mcopy
(
["seq", ["mstore", 32, ["mload", 128]], ["mstore", 64, ["mload", 160]]],
["mcopy", 32, 128, 64],
),
# copy with overlap "backwards", OK to become mcopy
(["seq", ["mstore", 32, ["mload", 64]], ["mstore", 64, ["mload", 96]]], ["mcopy", 32, 64, 64]),
# "stationary" overlap (i.e. a no-op mcopy), OK to become mcopy
(["seq", ["mstore", 32, ["mload", 32]], ["mstore", 64, ["mload", 64]]], ["mcopy", 32, 32, 64]),
# copy "forward" with no overlap, OK to become mcopy
(["seq", ["mstore", 64, ["mload", 0]], ["mstore", 96, ["mload", 32]]], ["mcopy", 64, 0, 64]),
# copy "forwards" with overlap by one word, must NOT become mcopy
(["seq", ["mstore", 64, ["mload", 32]], ["mstore", 96, ["mload", 64]]], None),
# check "forward" overlap by one byte, must NOT become mcopy
(["seq", ["mstore", 64, ["mload", 1]], ["mstore", 96, ["mload", 33]]], None),
# check "forward" overlap by one byte again, must NOT become mcopy
(["seq", ["mstore", 63, ["mload", 0]], ["mstore", 95, ["mload", 32]]], None),
# copy 3 words with partial overlap "forwards", partially becomes mcopy
# (2 words are mcopied and 1 word is mload/mstored
(
[
"seq",
["mstore", 96, ["mload", 32]],
["mstore", 128, ["mload", 64]],
["mstore", 160, ["mload", 96]],
],
["seq", ["mcopy", 96, 32, 64], ["mstore", 160, ["mload", 96]]],
),
# copy 4 words with partial overlap "forwards", becomes 2 mcopies of 2 words each
(
[
"seq",
["mstore", 96, ["mload", 32]],
["mstore", 128, ["mload", 64]],
["mstore", 160, ["mload", 96]],
["mstore", 192, ["mload", 128]],
],
["seq", ["mcopy", 96, 32, 64], ["mcopy", 160, 96, 64]],
),
# copy 4 words with 1 byte of overlap, must NOT become mcopy
(
[
"seq",
["mstore", 96, ["mload", 33]],
["mstore", 128, ["mload", 65]],
["mstore", 160, ["mload", 97]],
["mstore", 192, ["mload", 129]],
],
None,
),
]


@pytest.mark.parametrize("ir", mload_merge_list)
@pytest.mark.parametrize("evm_version", list(POST_CANCUN.keys()))
def test_mload_merge(ir, evm_version):
with anchor_evm_version(evm_version):
optimized = optimizer.optimize(IRnode.from_list(ir[0]))
if ir[1] is None:
# no-op, assert optimizer does nothing
expected = IRnode.from_list(ir[0])
else:
expected = IRnode.from_list(ir[1])

assert optimized == expected
60 changes: 60 additions & 0 deletions tests/parser/features/test_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,63 @@ def bug(p: Point) -> Point:
"""
c = get_contract(code)
assert c.bug((1, 2)) == (2, 1)


mload_merge_codes = [
(
"""
@external
def foo() -> uint256[4]:
# copy "backwards"
xs: uint256[4] = [1, 2, 3, 4]

# dst < src
xs[0] = xs[1]
xs[1] = xs[2]
xs[2] = xs[3]

return xs
""",
[2, 3, 4, 4],
),
(
"""
@external
def foo() -> uint256[4]:
# copy "forwards"
xs: uint256[4] = [1, 2, 3, 4]

# src < dst
xs[1] = xs[0]
xs[2] = xs[1]
xs[3] = xs[2]

return xs
""",
[1, 1, 1, 1],
),
(
"""
@external
def foo() -> uint256[5]:
# partial "forward" copy
xs: uint256[5] = [1, 2, 3, 4, 5]

# src < dst
xs[2] = xs[0]
xs[3] = xs[1]
xs[4] = xs[2]

return xs
""",
[1, 2, 1, 2, 1],
),
]


# functional test that mload merging does not occur when source and dest
# buffers overlap. (note: mload merging only applies after cancun)
@pytest.mark.parametrize("code,expected_result", mload_merge_codes)
def test_mcopy_overlap(get_contract, code, expected_result):
c = get_contract(code)
assert c.foo() == expected_result
9 changes: 7 additions & 2 deletions vyper/ir/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,10 +662,10 @@ def _rewrite_mstore_dload(argz):
def _merge_mload(argz):
if not version_check(begin="cancun"):
return False
return _merge_load(argz, "mload", "mcopy")
return _merge_load(argz, "mload", "mcopy", allow_overlap=False)


def _merge_load(argz, _LOAD, _COPY):
def _merge_load(argz, _LOAD, _COPY, allow_overlap=True):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: think generally params should be "safe-by-default", and the "safer" option would be to have no overlap.

Copy link
Member Author

@charles-cooper charles-cooper Oct 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm i see, my thinking was that this only affects where the source and destination address spaces are the same, and there is only one case for this, mload / mstore. but i am open to changing the default.

# look for sequential operations copying from X to Y
# and merge them into a single copy operation
changed = False
Expand All @@ -689,9 +689,14 @@ def _merge_load(argz, _LOAD, _COPY):
initial_dst_offset = dst_offset
initial_src_offset = src_offset
idx = i

# dst and src overlap, discontinue the optimization
has_overlap = initial_src_offset < initial_dst_offset < src_offset + 32

if (
initial_dst_offset + total_length == dst_offset
and initial_src_offset + total_length == src_offset
and (allow_overlap or not has_overlap)
):
mstore_nodes.append(ir_node)
total_length += 32
Expand Down
Loading