Skip to content

Commit

Permalink
fix: order of evaluation for some builtins (#3583)
Browse files Browse the repository at this point in the history
ecadd, ecmul, addmod, mulmod

in the case that the arguments have side effects, they could be
evaluated out of order

chainsec june 2023 review 5.1

---------

Co-authored-by: tserg <[email protected]>
Co-authored-by: trocher <[email protected]>
  • Loading branch information
3 people authored Sep 5, 2023
1 parent 2c21eab commit 78fa8dd
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 53 deletions.
32 changes: 32 additions & 0 deletions tests/parser/functions/test_addmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,35 @@ def c() -> uint256:
c = get_contract_with_gas_estimation(code)

assert c.foo() == 2


def test_uint256_addmod_evaluation_order(get_contract_with_gas_estimation):
code = """
a: uint256
@external
def foo1() -> uint256:
self.a = 0
return uint256_addmod(self.a, 1, self.bar())
@external
def foo2() -> uint256:
self.a = 0
return uint256_addmod(self.a, self.bar(), 3)
@external
def foo3() -> uint256:
self.a = 0
return uint256_addmod(1, self.a, self.bar())
@internal
def bar() -> uint256:
self.a = 1
return 2
"""

c = get_contract_with_gas_estimation(code)

assert c.foo1() == 1
assert c.foo2() == 2
assert c.foo3() == 1
40 changes: 40 additions & 0 deletions tests/parser/functions/test_ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,26 @@ def foo(a: Foo) -> uint256[2]:
assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))


def test_ecadd_evaluation_order(get_contract_with_gas_estimation):
code = """
x: uint256[2]
@internal
def bar() -> uint256[2]:
self.x = ecadd([1, 2], [1, 2])
return [1, 2]
@external
def foo() -> bool:
self.x = [1, 2]
a: uint256[2] = ecadd([1, 2], [1, 2])
b: uint256[2] = ecadd(self.x, self.bar())
return a[0] == b[0] and a[1] == b[1]
"""
c = get_contract_with_gas_estimation(code)
assert c.foo() is True


def test_ecmul(get_contract_with_gas_estimation):
ecmuller = """
x3: uint256[2]
Expand Down Expand Up @@ -136,3 +156,23 @@ def foo(a: Foo) -> uint256[2]:
assert c2.foo(c1.address) == G1_times_three

assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))


def test_ecmul_evaluation_order(get_contract_with_gas_estimation):
code = """
x: uint256[2]
@internal
def bar() -> uint256:
self.x = ecmul([1, 2], 3)
return 3
@external
def foo() -> bool:
self.x = [1, 2]
a: uint256[2] = ecmul([1, 2], 3)
b: uint256[2] = ecmul(self.x, self.bar())
return a[0] == b[0] and a[1] == b[1]
"""
c = get_contract_with_gas_estimation(code)
assert c.foo() is True
32 changes: 32 additions & 0 deletions tests/parser/functions/test_mulmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,35 @@ def c() -> uint256:
c = get_contract_with_gas_estimation(code)

assert c.foo() == 600


def test_uint256_mulmod_evaluation_order(get_contract_with_gas_estimation):
code = """
a: uint256
@external
def foo1() -> uint256:
self.a = 1
return uint256_mulmod(self.a, 2, self.bar())
@external
def foo2() -> uint256:
self.a = 1
return uint256_mulmod(self.bar(), self.a, 2)
@external
def foo3() -> uint256:
self.a = 1
return uint256_mulmod(2, self.a, self.bar())
@internal
def bar() -> uint256:
self.a = 7
return 5
"""

c = get_contract_with_gas_estimation(code)

assert c.foo1() == 2
assert c.foo2() == 1
assert c.foo3() == 2
92 changes: 39 additions & 53 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
eval_once_check,
eval_seq,
get_bytearray_length,
get_element_ptr,
get_type_for_exact_size,
ir_tuple_from_args,
make_setter,
needs_external_call_wrap,
promote_signed_int,
sar,
Expand Down Expand Up @@ -782,39 +782,29 @@ def build_IR(self, expr, args, kwargs, context):
)


def _getelem(arg, ind):
return unwrap_location(get_element_ptr(arg, IRnode.from_list(ind, typ=INT128_T)))


class ECAdd(BuiltinFunction):
_id = "ecadd"
_inputs = [("a", SArrayT(UINT256_T, 2)), ("b", SArrayT(UINT256_T, 2))]
_return_type = SArrayT(UINT256_T, 2)

@process_inputs
def build_IR(self, expr, args, kwargs, context):
placeholder_node = IRnode.from_list(
context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY
)
buf_t = get_type_for_exact_size(128)

with args[0].cache_when_complex("a") as (b1, a), args[1].cache_when_complex("b") as (b2, b):
o = IRnode.from_list(
[
"seq",
["mstore", placeholder_node, _getelem(a, 0)],
["mstore", ["add", placeholder_node, 32], _getelem(a, 1)],
["mstore", ["add", placeholder_node, 64], _getelem(b, 0)],
["mstore", ["add", placeholder_node, 96], _getelem(b, 1)],
[
"assert",
["staticcall", ["gas"], 6, placeholder_node, 128, placeholder_node, 64],
],
placeholder_node,
],
typ=SArrayT(UINT256_T, 2),
location=MEMORY,
)
return b2.resolve(b1.resolve(o))
buf = context.new_internal_variable(buf_t)

ret = ["seq"]

dst0 = IRnode.from_list(buf, typ=SArrayT(UINT256_T, 2), location=MEMORY)
ret.append(make_setter(dst0, args[0]))

dst1 = IRnode.from_list(buf + 64, typ=SArrayT(UINT256_T, 2), location=MEMORY)
ret.append(make_setter(dst1, args[1]))

ret.append(["assert", ["staticcall", ["gas"], 6, buf, 128, buf, 64]])
ret.append(buf)

return IRnode.from_list(ret, typ=SArrayT(UINT256_T, 2), location=MEMORY)


class ECMul(BuiltinFunction):
Expand All @@ -824,27 +814,22 @@ class ECMul(BuiltinFunction):

@process_inputs
def build_IR(self, expr, args, kwargs, context):
placeholder_node = IRnode.from_list(
context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY
)
buf_t = get_type_for_exact_size(96)

with args[0].cache_when_complex("a") as (b1, a), args[1].cache_when_complex("b") as (b2, b):
o = IRnode.from_list(
[
"seq",
["mstore", placeholder_node, _getelem(a, 0)],
["mstore", ["add", placeholder_node, 32], _getelem(a, 1)],
["mstore", ["add", placeholder_node, 64], b],
[
"assert",
["staticcall", ["gas"], 7, placeholder_node, 96, placeholder_node, 64],
],
placeholder_node,
],
typ=SArrayT(UINT256_T, 2),
location=MEMORY,
)
return b2.resolve(b1.resolve(o))
buf = context.new_internal_variable(buf_t)

ret = ["seq"]

dst0 = IRnode.from_list(buf, typ=SArrayT(UINT256_T, 2), location=MEMORY)
ret.append(make_setter(dst0, args[0]))

dst1 = IRnode.from_list(buf + 64, typ=UINT256_T, location=MEMORY)
ret.append(make_setter(dst1, args[1]))

ret.append(["assert", ["staticcall", ["gas"], 7, buf, 96, buf, 64]])
ret.append(buf)

return IRnode.from_list(ret, typ=SArrayT(UINT256_T, 2), location=MEMORY)


def _generic_element_getter(op):
Expand Down Expand Up @@ -1525,13 +1510,14 @@ def evaluate(self, node):

@process_inputs
def build_IR(self, expr, args, kwargs, context):
c = args[2]

with c.cache_when_complex("c") as (b1, c):
ret = IRnode.from_list(
["seq", ["assert", c], [self._opcode, args[0], args[1], c]], typ=UINT256_T
)
return b1.resolve(ret)
x, y, z = args
with x.cache_when_complex("x") as (b1, x):
with y.cache_when_complex("y") as (b2, y):
with z.cache_when_complex("z") as (b3, z):
ret = IRnode.from_list(
["seq", ["assert", z], [self._opcode, x, y, z]], typ=UINT256_T
)
return b1.resolve(b2.resolve(b3.resolve(ret)))


class AddMod(_AddMulMod):
Expand Down

0 comments on commit 78fa8dd

Please sign in to comment.