Skip to content

Commit

Permalink
add fixes for create memory cleanliness
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Sep 19, 2023
1 parent 9c5c780 commit bedb9fd
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 21 deletions.
50 changes: 32 additions & 18 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
clamp_basetype,
clamp_nonzero,
copy_bytes,
dummy_node_for_type,
ensure_in_memory,
eval_once_check,
eval_seq,
Expand Down Expand Up @@ -1592,13 +1593,15 @@ def build_IR(self, expr, context):

# CREATE* functions

CREATE2_SENTINEL = dummy_node_for_type(BYTES32_T)


# create helper functions
# generates CREATE op sequence + zero check for result
def _create_ir(value, buf, length, salt=None, checked=True):
def _create_ir(value, buf, length, salt, checked=True):
args = [value, buf, length]
create_op = "create"
if salt is not None:
if salt is not CREATE2_SENTINEL:
create_op = "create2"
args.append(salt)

Expand Down Expand Up @@ -1716,8 +1719,9 @@ def build_IR(self, expr, args, kwargs, context):
context.check_is_not_constant("use {self._id}", expr)

should_use_create2 = "salt" in [kwarg.arg for kwarg in expr.keywords]

if not should_use_create2:
kwargs["salt"] = None
kwargs["salt"] = CREATE2_SENTINEL

ir_builder = self._build_create_IR(expr, args, context, **kwargs)

Expand Down Expand Up @@ -1797,13 +1801,17 @@ def _add_gas_estimate(self, args, should_use_create2):
def _build_create_IR(self, expr, args, context, value, salt):
target = args[0]

with target.cache_when_complex("create_target") as (b1, target):
# something we can pass to scope_multi
with scope_multi((target, value), ("create_target", "create_value")) as (
b1,
(target, value),
), salt.cache_when_complex("create_salt") as (b2, salt):
codesize = IRnode.from_list(["extcodesize", target])
msize = IRnode.from_list(["msize"])
with codesize.cache_when_complex("target_codesize") as (
b2,
codesize,
), msize.cache_when_complex("mem_ofst") as (b3, mem_ofst):
with scope_multi((codesize, msize), ("target_codesize", "mem_ofst")) as (
b3,
(codesize, mem_ofst),
):
ir = ["seq"]

# make sure there is actually code at the target
Expand Down Expand Up @@ -1880,17 +1888,23 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar
# (since the abi encoder could write to fresh memory).
# it would be good to not require the memory copy, but need
# to evaluate memory safety.
with target.cache_when_complex("create_target") as (b1, target), argslen.cache_when_complex(
"encoded_args_len"
) as (b2, encoded_args_len), code_offset.cache_when_complex("code_ofst") as (b3, codeofst):
codesize = IRnode.from_list(["sub", ["extcodesize", target], codeofst])
with scope_multi(
(target, value, argslen, code_offset),
("create_target", "create_value", "encoded_args_len", "code_ofst"),
) as (b1, (target, value, encoded_args_len, code_offset)), salt.cache_when_complex(
"create_salt"
) as (
b2,
salt,
):
codesize = IRnode.from_list(["sub", ["extcodesize", target], code_offset])
# copy code to memory starting from msize. we are clobbering
# unused memory so it's safe.
msize = IRnode.from_list(["msize"], location=MEMORY)
with codesize.cache_when_complex("target_codesize") as (
b4,
codesize,
), msize.cache_when_complex("mem_ofst") as (b5, mem_ofst):
with scope_multi((codesize, msize), ("target_codesize", "mem_ofst")) as (
b3,
(codesize, mem_ofst),
):
ir = ["seq"]

# make sure there is code at the target, and that
Expand All @@ -1910,7 +1924,7 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar
# copy the target code into memory.
# layout starting from mem_ofst:
# 00...00 (22 0's) | preamble | bytecode
ir.append(["extcodecopy", target, mem_ofst, codeofst, codesize])
ir.append(["extcodecopy", target, mem_ofst, code_offset, codesize])

ir.append(copy_bytes(add_ofst(mem_ofst, codesize), argbuf, encoded_args_len, bufsz))

Expand All @@ -1925,7 +1939,7 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar

ir.append(_create_ir(value, mem_ofst, length, salt))

return b1.resolve(b2.resolve(b3.resolve(b4.resolve(b5.resolve(ir)))))
return b1.resolve(b2.resolve(b3.resolve(ir)))


class _UnsafeMath(BuiltinFunction):
Expand Down
36 changes: 33 additions & 3 deletions vyper/codegen/ir_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,44 @@ class Encoding(Enum):
# future: packed


# create multiple with scopes if any of the items are complex, to force
# ordering of side effects.
# CMC 2023-08-10 this is horrible! remove this _as soon as_ we have
# shortcut for chaining multiple cache_when_complex calls
# CMC 2023-08-10 remove this and scope_together _as soon as_ we have
# real variables in IR (that we can declare without explicit scoping -
# needs liveness analysis).
@contextlib.contextmanager
def scope_multi(ir_nodes, names):
assert len(ir_nodes) == len(names)

builders = []
scoped_ir_nodes = []

class _MultiBuilder:
def resolve(self, body):
# sanity check that it's initialized properly
assert len(builders) == len(ir_nodes)
ret = body
for b in builders:
ret = b.resolve(ret)
return ret

mb = _MultiBuilder()

with contextlib.ExitStack() as stack:
for arg, name in zip(ir_nodes, names):
b, ir_node = stack.enter_context(arg.cache_when_complex(name))

builders.append(b)
scoped_ir_nodes.append(ir_node)

yield mb, scoped_ir_nodes


# create multiple with scopes if any of the items are complex, to force
# ordering of side effects.
@contextlib.contextmanager
def scope_together(ir_nodes, names):
assert len(ir_nodes) == len(names)

should_scope = any(s._optimized.is_complex_ir for s in ir_nodes)

class _Builder:
Expand All @@ -77,6 +106,7 @@ def resolve(self, body):
return ret

b = _Builder()

if should_scope:
ir_vars = tuple(
IRnode.from_list(name, typ=arg.typ, location=arg.location, encoding=arg.encoding)
Expand Down

0 comments on commit bedb9fd

Please sign in to comment.