diff --git a/tests/functional/builtins/codegen/test_extract32.py b/tests/functional/builtins/codegen/test_extract32.py index 8a92adbb07..f8db51ee36 100644 --- a/tests/functional/builtins/codegen/test_extract32.py +++ b/tests/functional/builtins/codegen/test_extract32.py @@ -1,6 +1,7 @@ import pytest from vyper.evm.opcodes import version_check +from vyper.exceptions import CompilerPanic @pytest.mark.parametrize("location", ["storage", "transient"]) @@ -98,3 +99,50 @@ def foq(inp: Bytes[32]) -> address: with tx_failed(): c.foq(b"crow" * 8) + + +# to fix in future release +@pytest.mark.xfail(raises=CompilerPanic, reason="risky overlap") +def test_extract32_order_of_eval(get_contract): + extract32_code = """ +var:DynArray[Bytes[96], 1] + +@internal +def bar() -> uint256: + self.var[0] = b'hellohellohellohellohellohellohello' + self.var.pop() + return 3 + +@external +def foo() -> bytes32: + self.var = [b'abcdefghijklmnopqrstuvwxyz123456789'] + return extract32(self.var[0], self.bar(), output_type=bytes32) + """ + + c = get_contract(extract32_code) + assert c.foo() == b"defghijklmnopqrstuvwxyz123456789" + + +# to fix in future release +@pytest.mark.xfail(raises=CompilerPanic, reason="risky overlap") +def test_extract32_order_of_eval_extcall(get_contract): + slice_code = """ +var:DynArray[Bytes[96], 1] + +interface Bar: + def bar() -> uint256: payable + +@external +def bar() -> uint256: + self.var[0] = b'hellohellohellohellohellohellohello' + self.var.pop() + return 3 + +@external +def foo() -> bytes32: + self.var = [b'abcdefghijklmnopqrstuvwxyz123456789'] + return extract32(self.var[0], extcall Bar(self).bar(), output_type=bytes32) + """ + + c = get_contract(slice_code) + assert c.foo() == b"defghijklmnopqrstuvwxyz123456789" diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index 08800e7a8c..d5d1efca0f 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -5,7 +5,7 @@ from vyper.compiler import compile_code from vyper.compiler.settings import OptimizationLevel, Settings from vyper.evm.opcodes import version_check -from vyper.exceptions import ArgumentException, TypeMismatch +from vyper.exceptions import ArgumentException, CompilerPanic, TypeMismatch _fun_bytes32_bounds = [(0, 32), (3, 29), (27, 5), (0, 5), (5, 3), (30, 2)] @@ -562,3 +562,53 @@ def foo(cs: String[64]) -> uint256: c = get_contract(code) # ensure that counter was incremented only once assert c.foo(arg) == 1 + + +# to fix in future release +@pytest.mark.xfail(raises=CompilerPanic, reason="risky overlap") +def test_slice_order_of_eval(get_contract): + slice_code = """ +var:DynArray[Bytes[96], 1] + +interface Bar: + def bar() -> uint256: payable + +@external +def bar() -> uint256: + self.var[0] = b'hellohellohellohellohellohellohello' + self.var.pop() + return 32 + +@external +def foo() -> Bytes[96]: + self.var = [b'abcdefghijklmnopqrstuvwxyz123456789'] + return slice(self.var[0], 3, extcall Bar(self).bar()) + """ + + c = get_contract(slice_code) + assert c.foo() == b"defghijklmnopqrstuvwxyz123456789" + + +# to fix in future release +@pytest.mark.xfail(raises=CompilerPanic, reason="risky overlap") +def test_slice_order_of_eval2(get_contract): + slice_code = """ +var:DynArray[Bytes[96], 1] + +interface Bar: + def bar() -> uint256: payable + +@external +def bar() -> uint256: + self.var[0] = b'hellohellohellohellohellohellohello' + self.var.pop() + return 3 + +@external +def foo() -> Bytes[96]: + self.var = [b'abcdefghijklmnopqrstuvwxyz123456789'] + return slice(self.var[0], extcall Bar(self).bar(), 32) + """ + + c = get_contract(slice_code) + assert c.foo() == b"defghijklmnopqrstuvwxyz123456789" diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 2564329b65..672d978455 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -29,6 +29,7 @@ get_type_for_exact_size, ir_tuple_from_args, make_setter, + potential_overlap, promote_signed_int, sar, shl, @@ -357,6 +358,9 @@ def build_IR(self, expr, args, kwargs, context): assert is_bytes32, src src = ensure_in_memory(src, context) + if potential_overlap(src, start) or potential_overlap(src, length): + raise CompilerPanic("risky overlap") + with src.cache_when_complex("src") as (b1, src), start.cache_when_complex("start") as ( b2, start, @@ -862,6 +866,9 @@ def build_IR(self, expr, args, kwargs, context): bytez, index = args ret_type = kwargs["output_type"] + if potential_overlap(bytez, index): + raise CompilerPanic("risky overlap") + def finalize(ret): annotation = "extract32" ret = IRnode.from_list(ret, typ=ret_type, annotation=annotation)