From 78fa8dd8f91ba0cb26277eeffb585c68c83e7daa Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 5 Sep 2023 08:26:37 -0400 Subject: [PATCH] fix: order of evaluation for some builtins (#3583) ecadd, ecmul, addmod, mulmod in the case that the arguments have side effects, they could be evaluated out of order chainsec june 2023 review 5.1 --------- Co-authored-by: tserg <8017125+tserg@users.noreply.github.com> Co-authored-by: trocher <43437004+trocher@users.noreply.github.com> --- tests/parser/functions/test_addmod.py | 32 ++++++++++ tests/parser/functions/test_ec.py | 40 ++++++++++++ tests/parser/functions/test_mulmod.py | 32 ++++++++++ vyper/builtins/functions.py | 92 ++++++++++++--------------- 4 files changed, 143 insertions(+), 53 deletions(-) diff --git a/tests/parser/functions/test_addmod.py b/tests/parser/functions/test_addmod.py index 67a7e9b101..b3135660bb 100644 --- a/tests/parser/functions/test_addmod.py +++ b/tests/parser/functions/test_addmod.py @@ -55,3 +55,35 @@ def c() -> uint256: c = get_contract_with_gas_estimation(code) assert c.foo() == 2 + + +def test_uint256_addmod_evaluation_order(get_contract_with_gas_estimation): + code = """ +a: uint256 + +@external +def foo1() -> uint256: + self.a = 0 + return uint256_addmod(self.a, 1, self.bar()) + +@external +def foo2() -> uint256: + self.a = 0 + return uint256_addmod(self.a, self.bar(), 3) + +@external +def foo3() -> uint256: + self.a = 0 + return uint256_addmod(1, self.a, self.bar()) + +@internal +def bar() -> uint256: + self.a = 1 + return 2 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo1() == 1 + assert c.foo2() == 2 + assert c.foo3() == 1 diff --git a/tests/parser/functions/test_ec.py b/tests/parser/functions/test_ec.py index 9ce37d0721..e1d9e3d2ee 100644 --- a/tests/parser/functions/test_ec.py +++ b/tests/parser/functions/test_ec.py @@ -76,6 +76,26 @@ def foo(a: Foo) -> uint256[2]: assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) +def test_ecadd_evaluation_order(get_contract_with_gas_estimation): + code = """ +x: uint256[2] + +@internal +def bar() -> uint256[2]: + self.x = ecadd([1, 2], [1, 2]) + return [1, 2] + +@external +def foo() -> bool: + self.x = [1, 2] + a: uint256[2] = ecadd([1, 2], [1, 2]) + b: uint256[2] = ecadd(self.x, self.bar()) + return a[0] == b[0] and a[1] == b[1] + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() is True + + def test_ecmul(get_contract_with_gas_estimation): ecmuller = """ x3: uint256[2] @@ -136,3 +156,23 @@ def foo(a: Foo) -> uint256[2]: assert c2.foo(c1.address) == G1_times_three assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_ecmul_evaluation_order(get_contract_with_gas_estimation): + code = """ +x: uint256[2] + +@internal +def bar() -> uint256: + self.x = ecmul([1, 2], 3) + return 3 + +@external +def foo() -> bool: + self.x = [1, 2] + a: uint256[2] = ecmul([1, 2], 3) + b: uint256[2] = ecmul(self.x, self.bar()) + return a[0] == b[0] and a[1] == b[1] + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() is True diff --git a/tests/parser/functions/test_mulmod.py b/tests/parser/functions/test_mulmod.py index 1ea7a3f8e8..96477897b9 100644 --- a/tests/parser/functions/test_mulmod.py +++ b/tests/parser/functions/test_mulmod.py @@ -73,3 +73,35 @@ def c() -> uint256: c = get_contract_with_gas_estimation(code) assert c.foo() == 600 + + +def test_uint256_mulmod_evaluation_order(get_contract_with_gas_estimation): + code = """ +a: uint256 + +@external +def foo1() -> uint256: + self.a = 1 + return uint256_mulmod(self.a, 2, self.bar()) + +@external +def foo2() -> uint256: + self.a = 1 + return uint256_mulmod(self.bar(), self.a, 2) + +@external +def foo3() -> uint256: + self.a = 1 + return uint256_mulmod(2, self.a, self.bar()) + +@internal +def bar() -> uint256: + self.a = 7 + return 5 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo1() == 2 + assert c.foo2() == 1 + assert c.foo3() == 2 diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index e8e001306c..053ee512dc 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -25,9 +25,9 @@ eval_once_check, eval_seq, get_bytearray_length, - get_element_ptr, get_type_for_exact_size, ir_tuple_from_args, + make_setter, needs_external_call_wrap, promote_signed_int, sar, @@ -782,10 +782,6 @@ def build_IR(self, expr, args, kwargs, context): ) -def _getelem(arg, ind): - return unwrap_location(get_element_ptr(arg, IRnode.from_list(ind, typ=INT128_T))) - - class ECAdd(BuiltinFunction): _id = "ecadd" _inputs = [("a", SArrayT(UINT256_T, 2)), ("b", SArrayT(UINT256_T, 2))] @@ -793,28 +789,22 @@ class ECAdd(BuiltinFunction): @process_inputs def build_IR(self, expr, args, kwargs, context): - placeholder_node = IRnode.from_list( - context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY - ) + buf_t = get_type_for_exact_size(128) - with args[0].cache_when_complex("a") as (b1, a), args[1].cache_when_complex("b") as (b2, b): - o = IRnode.from_list( - [ - "seq", - ["mstore", placeholder_node, _getelem(a, 0)], - ["mstore", ["add", placeholder_node, 32], _getelem(a, 1)], - ["mstore", ["add", placeholder_node, 64], _getelem(b, 0)], - ["mstore", ["add", placeholder_node, 96], _getelem(b, 1)], - [ - "assert", - ["staticcall", ["gas"], 6, placeholder_node, 128, placeholder_node, 64], - ], - placeholder_node, - ], - typ=SArrayT(UINT256_T, 2), - location=MEMORY, - ) - return b2.resolve(b1.resolve(o)) + buf = context.new_internal_variable(buf_t) + + ret = ["seq"] + + dst0 = IRnode.from_list(buf, typ=SArrayT(UINT256_T, 2), location=MEMORY) + ret.append(make_setter(dst0, args[0])) + + dst1 = IRnode.from_list(buf + 64, typ=SArrayT(UINT256_T, 2), location=MEMORY) + ret.append(make_setter(dst1, args[1])) + + ret.append(["assert", ["staticcall", ["gas"], 6, buf, 128, buf, 64]]) + ret.append(buf) + + return IRnode.from_list(ret, typ=SArrayT(UINT256_T, 2), location=MEMORY) class ECMul(BuiltinFunction): @@ -824,27 +814,22 @@ class ECMul(BuiltinFunction): @process_inputs def build_IR(self, expr, args, kwargs, context): - placeholder_node = IRnode.from_list( - context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY - ) + buf_t = get_type_for_exact_size(96) - with args[0].cache_when_complex("a") as (b1, a), args[1].cache_when_complex("b") as (b2, b): - o = IRnode.from_list( - [ - "seq", - ["mstore", placeholder_node, _getelem(a, 0)], - ["mstore", ["add", placeholder_node, 32], _getelem(a, 1)], - ["mstore", ["add", placeholder_node, 64], b], - [ - "assert", - ["staticcall", ["gas"], 7, placeholder_node, 96, placeholder_node, 64], - ], - placeholder_node, - ], - typ=SArrayT(UINT256_T, 2), - location=MEMORY, - ) - return b2.resolve(b1.resolve(o)) + buf = context.new_internal_variable(buf_t) + + ret = ["seq"] + + dst0 = IRnode.from_list(buf, typ=SArrayT(UINT256_T, 2), location=MEMORY) + ret.append(make_setter(dst0, args[0])) + + dst1 = IRnode.from_list(buf + 64, typ=UINT256_T, location=MEMORY) + ret.append(make_setter(dst1, args[1])) + + ret.append(["assert", ["staticcall", ["gas"], 7, buf, 96, buf, 64]]) + ret.append(buf) + + return IRnode.from_list(ret, typ=SArrayT(UINT256_T, 2), location=MEMORY) def _generic_element_getter(op): @@ -1525,13 +1510,14 @@ def evaluate(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): - c = args[2] - - with c.cache_when_complex("c") as (b1, c): - ret = IRnode.from_list( - ["seq", ["assert", c], [self._opcode, args[0], args[1], c]], typ=UINT256_T - ) - return b1.resolve(ret) + x, y, z = args + with x.cache_when_complex("x") as (b1, x): + with y.cache_when_complex("y") as (b2, y): + with z.cache_when_complex("z") as (b3, z): + ret = IRnode.from_list( + ["seq", ["assert", z], [self._opcode, x, y, z]], typ=UINT256_T + ) + return b1.resolve(b2.resolve(b3.resolve(ret))) class AddMod(_AddMulMod):