Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: memory allocation in certain builtins using msize #3610

Merged
merged 16 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tests/compiler/ir/test_optimize_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@
(["sub", "x", 0], ["x"]),
(["sub", "x", "x"], [0]),
(["sub", ["sload", 0], ["sload", 0]], None),
(["sub", ["callvalue"], ["callvalue"]], None),
(["sub", ["callvalue"], ["callvalue"]], [0]),
(["sub", ["msize"], ["msize"]], None),
(["sub", ["gas"], ["gas"]], None),
(["sub", -1, ["sload", 0]], ["not", ["sload", 0]]),
(["mul", "x", 1], ["x"]),
(["div", "x", 1], ["x"]),
Expand Down Expand Up @@ -210,7 +212,9 @@
(["eq", -1, ["add", -(2**255), 2**255 - 1]], [1]), # test compile-time wrapping
(["eq", -2, ["add", 2**256 - 1, 2**256 - 1]], [1]), # test compile-time wrapping
(["eq", "x", "x"], [1]),
(["eq", "callvalue", "callvalue"], None),
(["eq", "gas", "gas"], None),
(["eq", "msize", "msize"], None),
(["eq", "callvalue", "callvalue"], [1]),
(["ne", "x", "x"], [0]),
]

Expand Down
209 changes: 209 additions & 0 deletions tests/parser/functions/test_create_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,212 @@ def test2(target: address, salt: bytes32) -> address:
# test2 = c.test2(b"\x01", salt)
# assert HexBytes(test2) == create2_address_of(c.address, salt, vyper_initcode(b"\x01"))
# assert_tx_failed(lambda: c.test2(bytecode, salt))


# XXX: these various tests to check the msize allocator for
# create_copy_of and create_from_blueprint depend on calling convention
# and variables writing to memory. think of ways to make more robust to
# changes in calling convention and memory layout
@pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", b"\xfe\71\x00"])
def test_create_from_blueprint_complex_value(
get_contract, deploy_blueprint_for, w3, blueprint_prefix
):
# check msize allocator does not get trampled by value= kwarg
code = """
var: uint256

@external
@payable
def __init__(x: uint256):
self.var = x

@external
def foo()-> uint256:
return self.var
"""

prefix_len = len(blueprint_prefix)

some_constant = b"\00" * 31 + b"\x0c"

deployer_code = f"""
created_address: public(address)
x: constant(Bytes[32]) = {some_constant}

@internal
def foo() -> uint256:
g:uint256 = 42
return 3

@external
@payable
def test(target: address):
self.created_address = create_from_blueprint(
target,
x,
code_offset={prefix_len},
value=self.foo(),
raw_args=True
)
"""

foo_contract = get_contract(code, 12)
expected_runtime_code = w3.eth.get_code(foo_contract.address)

f, FooContract = deploy_blueprint_for(code, initcode_prefix=blueprint_prefix)

d = get_contract(deployer_code)

d.test(f.address, transact={"value": 3})

test = FooContract(d.created_address())
assert w3.eth.get_code(test.address) == expected_runtime_code
assert test.foo() == 12


@pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", b"\xfe\71\x00"])
def test_create_from_blueprint_complex_salt_raw_args(
get_contract, deploy_blueprint_for, w3, blueprint_prefix
):
# test msize allocator does not get trampled by salt= kwarg
code = """
var: uint256

@external
@payable
def __init__(x: uint256):
self.var = x

@external
def foo()-> uint256:
return self.var
"""

some_constant = b"\00" * 31 + b"\x0c"
prefix_len = len(blueprint_prefix)

deployer_code = f"""
created_address: public(address)

x: constant(Bytes[32]) = {some_constant}
salt: constant(bytes32) = keccak256("kebab")

@internal
def foo() -> bytes32:
g:uint256 = 42
return salt

@external
@payable
def test(target: address):
self.created_address = create_from_blueprint(
target,
x,
code_offset={prefix_len},
salt=self.foo(),
raw_args= True
)
"""

foo_contract = get_contract(code, 12)
expected_runtime_code = w3.eth.get_code(foo_contract.address)

f, FooContract = deploy_blueprint_for(code, initcode_prefix=blueprint_prefix)

d = get_contract(deployer_code)

d.test(f.address, transact={})

test = FooContract(d.created_address())
assert w3.eth.get_code(test.address) == expected_runtime_code
assert test.foo() == 12


@pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", b"\xfe\71\x00"])
def test_create_from_blueprint_complex_salt_no_constructor_args(
get_contract, deploy_blueprint_for, w3, blueprint_prefix
):
# test msize allocator does not get trampled by salt= kwarg
code = """
var: uint256

@external
@payable
def __init__():
self.var = 12

@external
def foo()-> uint256:
return self.var
"""

prefix_len = len(blueprint_prefix)
deployer_code = f"""
created_address: public(address)

salt: constant(bytes32) = keccak256("kebab")

@external
@payable
def test(target: address):
self.created_address = create_from_blueprint(
target,
code_offset={prefix_len},
salt=keccak256(_abi_encode(target))
)
"""

foo_contract = get_contract(code)
expected_runtime_code = w3.eth.get_code(foo_contract.address)

f, FooContract = deploy_blueprint_for(code, initcode_prefix=blueprint_prefix)

d = get_contract(deployer_code)

d.test(f.address, transact={})

test = FooContract(d.created_address())
assert w3.eth.get_code(test.address) == expected_runtime_code
assert test.foo() == 12


def test_create_copy_of_complex_kwargs(get_contract, w3):
# test msize allocator does not get trampled by salt= kwarg
complex_salt = """
created_address: public(address)

@external
def test(target: address) -> address:
self.created_address = create_copy_of(
target,
salt=keccak256(_abi_encode(target))
)
return self.created_address

"""

c = get_contract(complex_salt)
bytecode = w3.eth.get_code(c.address)
c.test(c.address, transact={})
test1 = c.created_address()
assert w3.eth.get_code(test1) == bytecode

# test msize allocator does not get trampled by value= kwarg
complex_value = """
created_address: public(address)

@external
@payable
def test(target: address) -> address:
value: uint256 = 2
self.created_address = create_copy_of(target, value = [2,2,2][value])
return self.created_address

"""

c = get_contract(complex_value)
bytecode = w3.eth.get_code(c.address)

c.test(c.address, transact={"value": 2})
test1 = c.created_address()
assert w3.eth.get_code(test1) == bytecode
158 changes: 158 additions & 0 deletions tests/parser/functions/test_raw_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,164 @@ def baz(_addr: address, should_raise: bool) -> uint256:
assert caller.baz(target.address, False) == 3


# XXX: these test_raw_call_clean_mem* tests depend on variables and
# calling convention writing to memory. think of ways to make more
# robust to changes to calling convention and memory layout.


def test_raw_call_msg_data_clean_mem(get_contract):
# test msize uses clean memory and does not get overwritten by
# any raw_call() arguments
code = """
identity: constant(address) = 0x0000000000000000000000000000000000000004

@external
def foo():
pass

@internal
@view
def get_address()->address:
a:uint256 = 121 # 0x79
return identity
@external
def bar(f: uint256, u: uint256) -> Bytes[100]:
# embed an internal call in the calculation of address
a: Bytes[100] = raw_call(self.get_address(), msg.data, max_outsize=100)
return a
"""

c = get_contract(code)
assert (
c.bar(1, 2).hex() == "ae42e951"
"0000000000000000000000000000000000000000000000000000000000000001"
"0000000000000000000000000000000000000000000000000000000000000002"
)


def test_raw_call_clean_mem2(get_contract):
# test msize uses clean memory and does not get overwritten by
# any raw_call() arguments, another way
code = """
buf: Bytes[100]

@external
def bar(f: uint256, g: uint256, h: uint256) -> Bytes[100]:
# embed a memory modifying expression in the calculation of address
self.buf = raw_call(
[0x0000000000000000000000000000000000000004,][f-1],
msg.data,
max_outsize=100
)
return self.buf
"""
c = get_contract(code)

assert (
c.bar(1, 2, 3).hex() == "9309b76e"
"0000000000000000000000000000000000000000000000000000000000000001"
"0000000000000000000000000000000000000000000000000000000000000002"
"0000000000000000000000000000000000000000000000000000000000000003"
)


def test_raw_call_clean_mem3(get_contract):
# test msize uses clean memory and does not get overwritten by
# any raw_call() arguments, and also test order of evaluation for
# scope_multi
code = """
buf: Bytes[100]
canary: String[32]

@internal
def bar() -> address:
self.canary = "bar"
return 0x0000000000000000000000000000000000000004

@internal
def goo() -> uint256:
self.canary = "goo"
return 0

@external
def foo() -> String[32]:
self.buf = raw_call(self.bar(), msg.data, value = self.goo(), max_outsize=100)
return self.canary
"""
c = get_contract(code)
assert c.foo() == "goo"


def test_raw_call_clean_mem_kwargs_value(get_contract):
# test msize uses clean memory and does not get overwritten by
# any raw_call() kwargs
code = """
buf: Bytes[100]

# add a dummy function to trigger memory expansion in the selector table routine
@external
def foo():
pass

@internal
def _value() -> uint256:
x: uint256 = 1
return x

@external
def bar(f: uint256) -> Bytes[100]:
# embed a memory modifying expression in the calculation of address
self.buf = raw_call(
0x0000000000000000000000000000000000000004,
msg.data,
max_outsize=100,
value=self._value()
)
return self.buf
"""
c = get_contract(code, value=1)

assert (
c.bar(13).hex() == "0423a132"
"000000000000000000000000000000000000000000000000000000000000000d"
)


def test_raw_call_clean_mem_kwargs_gas(get_contract):
# test msize uses clean memory and does not get overwritten by
# any raw_call() kwargs
code = """
buf: Bytes[100]

# add a dummy function to trigger memory expansion in the selector table routine
@external
def foo():
pass

@internal
def _gas() -> uint256:
x: uint256 = msg.gas
return x

@external
def bar(f: uint256) -> Bytes[100]:
# embed a memory modifying expression in the calculation of address
self.buf = raw_call(
0x0000000000000000000000000000000000000004,
msg.data,
max_outsize=100,
gas=self._gas()
)
return self.buf
"""
c = get_contract(code, value=1)

assert (
c.bar(15).hex() == "0423a132"
"000000000000000000000000000000000000000000000000000000000000000f"
)


uncompilable_code = [
(
"""
Expand Down
Loading