diff --git a/tests/functional/builtins/codegen/test_keccak256.py b/tests/functional/builtins/codegen/test_keccak256.py index 90fa8b9e09..3b0b9f2018 100644 --- a/tests/functional/builtins/codegen/test_keccak256.py +++ b/tests/functional/builtins/codegen/test_keccak256.py @@ -1,3 +1,6 @@ +from vyper.utils import hex_to_int + + def test_hash_code(get_contract_with_gas_estimation, keccak): hash_code = """ @external @@ -80,3 +83,31 @@ def try32(inp: bytes32) -> bool: assert c.tryy(b"\x35" * 33) is True print("Passed KECCAK256 hash test") + + +def test_hash_constant_bytes32(get_contract_with_gas_estimation, keccak): + hex_val = "0x1234567890123456789012345678901234567890123456789012345678901234" + code = f""" +FOO: constant(bytes32) = {hex_val} +BAR: constant(bytes32) = keccak256(FOO) +@external +def foo() -> bytes32: + x: bytes32 = BAR + return x + """ + c = get_contract_with_gas_estimation(code) + assert "0x" + c.foo().hex() == keccak(hex_to_int(hex_val).to_bytes(32, "big")).hex() + + +def test_hash_constant_string(get_contract_with_gas_estimation, keccak): + str_val = "0x1234567890123456789012345678901234567890123456789012345678901234" + code = f""" +FOO: constant(String[66]) = "{str_val}" +BAR: constant(bytes32) = keccak256(FOO) +@external +def foo() -> bytes32: + x: bytes32 = BAR + return x + """ + c = get_contract_with_gas_estimation(code) + assert "0x" + c.foo().hex() == keccak(str_val.encode()).hex() diff --git a/tests/functional/builtins/codegen/test_sha256.py b/tests/functional/builtins/codegen/test_sha256.py index 468e684645..8e1b89bd31 100644 --- a/tests/functional/builtins/codegen/test_sha256.py +++ b/tests/functional/builtins/codegen/test_sha256.py @@ -2,6 +2,8 @@ import pytest +from vyper.utils import hex_to_int + pytestmark = pytest.mark.usefixtures("memory_mocker") @@ -77,3 +79,31 @@ def bar() -> bytes32: c.set(test_val, transact={}) assert c.a() == test_val assert c.bar() == hashlib.sha256(test_val).digest() + + +def test_sha256_constant_bytes32(get_contract_with_gas_estimation): + hex_val = "0x1234567890123456789012345678901234567890123456789012345678901234" + code = f""" +FOO: constant(bytes32) = {hex_val} +BAR: constant(bytes32) = sha256(FOO) +@external +def foo() -> bytes32: + x: bytes32 = BAR + return x + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() == hashlib.sha256(hex_to_int(hex_val).to_bytes(32, "big")).digest() + + +def test_sha256_constant_string(get_contract_with_gas_estimation): + str_val = "0x1234567890123456789012345678901234567890123456789012345678901234" + code = f""" +FOO: constant(String[66]) = "{str_val}" +BAR: constant(bytes32) = sha256(FOO) +@external +def foo() -> bytes32: + x: bytes32 = BAR + return x + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() == hashlib.sha256(str_val.encode()).digest() diff --git a/tests/functional/builtins/codegen/test_unary.py b/tests/functional/builtins/codegen/test_unary.py index 33f79be233..2be5c0d33f 100644 --- a/tests/functional/builtins/codegen/test_unary.py +++ b/tests/functional/builtins/codegen/test_unary.py @@ -69,16 +69,11 @@ def bar() -> decimal: def test_negation_int128(get_contract): code = """ -a: constant(int128) = -2**127 - -@external -def foo() -> int128: - return -2**127 +a: constant(int128) = min_value(int128) @external def bar() -> int128: return -(a+1) """ c = get_contract(code) - assert c.foo() == -(2**127) assert c.bar() == 2**127 - 1 diff --git a/tests/functional/builtins/folding/test_abs.py b/tests/functional/builtins/folding/test_abs.py index a91a4f1ad3..7a8e99d4bc 100644 --- a/tests/functional/builtins/folding/test_abs.py +++ b/tests/functional/builtins/folding/test_abs.py @@ -4,7 +4,7 @@ from vyper import ast as vy_ast from vyper.builtins import functions as vy_fn -from vyper.exceptions import OverflowException +from vyper.exceptions import InvalidType @pytest.mark.fuzzing @@ -21,7 +21,7 @@ def foo(a: int256) -> int256: vyper_ast = vy_ast.parse_to_ast(f"abs({a})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["abs"].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE["abs"].fold(old_node) assert contract.foo(a) == new_node.value == abs(a) @@ -35,7 +35,7 @@ def test_abs_upper_bound_folding(get_contract, a): def foo(a: int256) -> int256: return abs({a}) """ - with pytest.raises(OverflowException): + with pytest.raises(InvalidType): get_contract(source) @@ -55,7 +55,7 @@ def test_abs_lower_bound_folded(get_contract, tx_failed): source = """ @external def foo() -> int256: - return abs(-2**255) + return abs(min_value(int256)) """ - with pytest.raises(OverflowException): + with pytest.raises(InvalidType): get_contract(source) diff --git a/tests/functional/builtins/folding/test_addmod_mulmod.py b/tests/functional/builtins/folding/test_addmod_mulmod.py index 33dcc62984..03325240a5 100644 --- a/tests/functional/builtins/folding/test_addmod_mulmod.py +++ b/tests/functional/builtins/folding/test_addmod_mulmod.py @@ -24,6 +24,6 @@ def foo(a: uint256, b: uint256, c: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({a}, {b}, {c})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name].fold(old_node) assert contract.foo(a, b, c) == new_node.value diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index 63e733644f..d5c0f1ac86 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -28,7 +28,7 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.fold() assert contract.foo(a, b) == new_node.value @@ -49,7 +49,7 @@ def foo(a: uint256, b: uint256) -> uint256: old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.fold() # force bounds check, no-op because validate_numeric_bounds # already does this, but leave in for hygiene (in case # more types are added). @@ -79,7 +79,7 @@ def foo(a: int256, b: uint256) -> int256: old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.fold() validate_expected_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. @@ -104,6 +104,6 @@ def foo(a: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"~{value}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.fold() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_epsilon.py b/tests/functional/builtins/folding/test_epsilon.py index 794648cfce..3110d5eae5 100644 --- a/tests/functional/builtins/folding/test_epsilon.py +++ b/tests/functional/builtins/folding/test_epsilon.py @@ -15,6 +15,6 @@ def foo() -> {typ_name}: vyper_ast = vy_ast.parse_to_ast(f"epsilon({typ_name})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["epsilon"].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE["epsilon"].fold(old_node) assert contract.foo() == new_node.value diff --git a/tests/functional/builtins/folding/test_floor_ceil.py b/tests/functional/builtins/folding/test_floor_ceil.py index 87db23889a..8e3f14e9ec 100644 --- a/tests/functional/builtins/folding/test_floor_ceil.py +++ b/tests/functional/builtins/folding/test_floor_ceil.py @@ -30,6 +30,6 @@ def foo(a: decimal) -> int256: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name].fold(old_node) assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_fold_as_wei_value.py b/tests/functional/builtins/folding/test_fold_as_wei_value.py index 210ab51f0d..9af1618bcb 100644 --- a/tests/functional/builtins/folding/test_fold_as_wei_value.py +++ b/tests/functional/builtins/folding/test_fold_as_wei_value.py @@ -32,7 +32,7 @@ def foo(a: decimal) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value:.10f}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue().evaluate(old_node) + new_node = vy_fn.AsWeiValue().fold(old_node) assert contract.foo(value) == new_node.value @@ -51,6 +51,6 @@ def foo(a: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue().evaluate(old_node) + new_node = vy_fn.AsWeiValue().fold(old_node) assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_keccak_sha.py b/tests/functional/builtins/folding/test_keccak_sha.py index a2fe460dd1..768a46a40d 100644 --- a/tests/functional/builtins/folding/test_keccak_sha.py +++ b/tests/functional/builtins/folding/test_keccak_sha.py @@ -22,7 +22,7 @@ def foo(a: String[100]) -> bytes32: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}('''{value}''')") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name].fold(old_node) assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -41,7 +41,7 @@ def foo(a: Bytes[100]) -> bytes32: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name].fold(old_node) assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -62,6 +62,6 @@ def foo(a: Bytes[100]) -> bytes32: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name].fold(old_node) assert f"0x{contract.foo(value).hex()}" == new_node.value diff --git a/tests/functional/builtins/folding/test_len.py b/tests/functional/builtins/folding/test_len.py index edf33120dd..f4d54e202b 100644 --- a/tests/functional/builtins/folding/test_len.py +++ b/tests/functional/builtins/folding/test_len.py @@ -17,7 +17,7 @@ def foo(a: String[1024]) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"len('{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len().evaluate(old_node) + new_node = vy_fn.Len().fold(old_node) assert contract.foo(value) == new_node.value @@ -35,7 +35,7 @@ def foo(a: Bytes[1024]) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"len(b'{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len().evaluate(old_node) + new_node = vy_fn.Len().fold(old_node) assert contract.foo(value.encode()) == new_node.value @@ -53,6 +53,6 @@ def foo(a: Bytes[1024]) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"len({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len().evaluate(old_node) + new_node = vy_fn.Len().fold(old_node) assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_min_max.py b/tests/functional/builtins/folding/test_min_max.py index 309f7519c0..4548e482ca 100644 --- a/tests/functional/builtins/folding/test_min_max.py +++ b/tests/functional/builtins/folding/test_min_max.py @@ -31,7 +31,7 @@ def foo(a: decimal, b: decimal) -> decimal: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name].fold(old_node) assert contract.foo(left, right) == new_node.value @@ -50,7 +50,7 @@ def foo(a: int128, b: int128) -> int128: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name].fold(old_node) assert contract.foo(left, right) == new_node.value @@ -69,6 +69,6 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name].fold(old_node) assert contract.foo(left, right) == new_node.value diff --git a/tests/functional/builtins/folding/test_powmod.py b/tests/functional/builtins/folding/test_powmod.py index 8667ec93fd..f531e77af6 100644 --- a/tests/functional/builtins/folding/test_powmod.py +++ b/tests/functional/builtins/folding/test_powmod.py @@ -21,6 +21,6 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"pow_mod256({a}, {b})") old_node = vyper_ast.body[0].value - new_node = vy_fn.PowMod256().evaluate(old_node) + new_node = vy_fn.PowMod256().fold(old_node) assert contract.foo(a, b) == new_node.value diff --git a/tests/functional/codegen/calling_convention/test_default_parameters.py b/tests/functional/codegen/calling_convention/test_default_parameters.py index 03f5d9fca2..462748a9c7 100644 --- a/tests/functional/codegen/calling_convention/test_default_parameters.py +++ b/tests/functional/codegen/calling_convention/test_default_parameters.py @@ -305,6 +305,48 @@ def foo(a: address = empty(address)): def foo(a: int112 = min_value(int112)): self.A = a """, + """ +struct X: + x: int128 + y: address +BAR: constant(X) = X({x: 1, y: 0x0000000000000000000000000000000000012345}) +@external +def out_literals(a: int128 = BAR.x + 1) -> X: + return BAR + """, + """ +struct X: + x: int128 + y: address +struct Y: + x: X + y: uint256 +BAR: constant(X) = X({x: 1, y: 0x0000000000000000000000000000000000012345}) +FOO: constant(Y) = Y({x: BAR, y: 256}) +@external +def out_literals(a: int128 = FOO.x.x + 1) -> Y: + return FOO + """, + """ +struct Bar: + a: bool + +BAR: constant(Bar) = Bar({a: True}) + +@external +def foo(x: bool = True and not BAR.a): + pass + """, + """ +struct Bar: + a: uint256 + +BAR: constant(Bar) = Bar({ a: 123 }) + +@external +def foo(x: bool = BAR.a + 1 > 456): + pass + """, ] diff --git a/tests/functional/codegen/calling_convention/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py index 0360396f03..0af4f9f937 100644 --- a/tests/functional/codegen/calling_convention/test_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_external_contract_calls.py @@ -388,7 +388,7 @@ def test_int128_too_long(get_contract, tx_failed): contract_1 = """ @external def foo() -> int256: - return (2**255)-1 + return max_value(int256) """ c = get_contract(contract_1) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index bc1a12ae9e..5544b896a2 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -776,7 +776,7 @@ def test_for() -> int128: a = i return a """, - TypeMismatch, + InvalidType, ), ( """ diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 65d2df9038..7d363fadc0 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -435,7 +435,7 @@ def ok() -> {typ}: @external def should_fail() -> int256: - return -2**255 # OOB for all int/uint types with less than 256 bits + return min_value(int256) """ code = f""" diff --git a/tests/functional/codegen/types/numbers/test_constants.py b/tests/functional/codegen/types/numbers/test_constants.py index 8244bc5487..af871983ab 100644 --- a/tests/functional/codegen/types/numbers/test_constants.py +++ b/tests/functional/codegen/types/numbers/test_constants.py @@ -4,7 +4,7 @@ import pytest from vyper.compiler import compile_code -from vyper.exceptions import InvalidType +from vyper.exceptions import TypeMismatch from vyper.utils import MemoryPositions @@ -158,7 +158,7 @@ def test_custom_constants_fail(get_contract, assert_compile_failed, storage_type def foo() -> {return_type}: return MY_CONSTANT """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) def test_constant_address(get_contract): diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 4ef6874ae9..eec79a0a46 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -315,6 +315,23 @@ def test_array(x: int128, y: int128, z: int128, w: int128) -> int128: def test_array_negative_accessor(get_contract_with_gas_estimation, assert_compile_failed): + array_constant_negative_accessor = """ +FOO: constant(int128) = -1 +@external +def test_array(x: int128, y: int128, z: int128, w: int128) -> int128: + a: int128[4] = [0, 0, 0, 0] + a[0] = x + a[1] = y + a[2] = z + a[3] = w + return a[-4] * 1000 + a[-3] * 100 + a[-2] * 10 + a[FOO] + """ + + assert_compile_failed( + lambda: get_contract_with_gas_estimation(array_constant_negative_accessor), + ArrayIndexException, + ) + array_negative_accessor = """ @external def test_array(x: int128, y: int128, z: int128, w: int128) -> int128: @@ -1728,7 +1745,7 @@ def test_constant_list_fail(get_contract, assert_compile_failed, storage_type, r def foo() -> DynArray[{return_type}, 3]: return MY_CONSTANT """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) @@ -1740,7 +1757,7 @@ def test_constant_list_fail_2(get_contract, assert_compile_failed, storage_type, def foo() -> {return_type}: return MY_CONSTANT[0] """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index 657c4ba0b8..b5b9538c20 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -2,7 +2,7 @@ import pytest -from vyper.exceptions import ArrayIndexException, InvalidType, OverflowException, TypeMismatch +from vyper.exceptions import ArrayIndexException, OverflowException, TypeMismatch def test_list_tester_code(get_contract_with_gas_estimation): @@ -705,7 +705,7 @@ def test_constant_list_fail(get_contract, assert_compile_failed, storage_type, r def foo() -> {return_type}[3]: return MY_CONSTANT """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) @@ -717,7 +717,7 @@ def test_constant_list_fail_2(get_contract, assert_compile_failed, storage_type, def foo() -> {return_type}: return MY_CONSTANT[0] """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) @@ -824,7 +824,7 @@ def test_constant_nested_list_fail(get_contract, assert_compile_failed, storage_ def foo() -> {return_type}[2][3]: return MY_CONSTANT """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) @@ -838,4 +838,4 @@ def test_constant_nested_list_fail_2( def foo() -> {return_type}: return MY_CONSTANT[0][0] """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) diff --git a/tests/functional/syntax/exceptions/test_argument_exception.py b/tests/functional/syntax/exceptions/test_argument_exception.py index fc06395015..0b7ec21bdb 100644 --- a/tests/functional/syntax/exceptions/test_argument_exception.py +++ b/tests/functional/syntax/exceptions/test_argument_exception.py @@ -1,13 +1,13 @@ import pytest -from vyper import compiler +from vyper import compile_code from vyper.exceptions import ArgumentException fail_list = [ """ @external def foo(): - x = as_wei_value(5, "vader") + x: uint256 = as_wei_value(5, "vader") """, """ @external @@ -95,4 +95,4 @@ def foo(): @pytest.mark.parametrize("bad_code", fail_list) def test_function_declaration_exception(bad_code): with pytest.raises(ArgumentException): - compiler.compile_code(bad_code) + compile_code(bad_code) diff --git a/tests/functional/syntax/test_abi_decode.py b/tests/functional/syntax/test_abi_decode.py index f05ff429cd..a6665bb84c 100644 --- a/tests/functional/syntax/test_abi_decode.py +++ b/tests/functional/syntax/test_abi_decode.py @@ -26,7 +26,7 @@ def bar(j: String[32]) -> bool: @pytest.mark.parametrize("bad_code,exc", fail_list) -def test_abi_encode_fail(bad_code, exc): +def test_abi_decode_fail(bad_code, exc): with pytest.raises(exc): compiler.compile_code(bad_code) @@ -41,5 +41,5 @@ def foo(x: Bytes[32]) -> uint256: @pytest.mark.parametrize("good_code", valid_list) -def test_abi_encode_success(good_code): +def test_abi_decode_success(good_code): assert compiler.compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_abs.py b/tests/functional/syntax/test_abs.py new file mode 100644 index 0000000000..0841ff05d6 --- /dev/null +++ b/tests/functional/syntax/test_abs.py @@ -0,0 +1,40 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidType + +fail_list = [ + ( + """ +@external +def foo(): + y: int256 = abs( + -57896044618658097711785492504343953926634992332820282019728792003956564819968 + ) + """, + InvalidType, + ) +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_abs_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +valid_list = [ + """ +FOO: constant(int256) = -3 +BAR: constant(int256) = abs(FOO) + +@external +def foo(): + a: int256 = BAR + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_abs_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_addmulmod.py b/tests/functional/syntax/test_addmulmod.py index ddff4d3e01..17c7b3ab8c 100644 --- a/tests/functional/syntax/test_addmulmod.py +++ b/tests/functional/syntax/test_addmulmod.py @@ -1,5 +1,6 @@ import pytest +from vyper import compile_code from vyper.exceptions import InvalidType fail_list = [ @@ -25,3 +26,24 @@ def foo() -> uint256: @pytest.mark.parametrize("code,exc", fail_list) def test_add_mod_fail(assert_compile_failed, get_contract, code, exc): assert_compile_failed(lambda: get_contract(code), exc) + + +valid_list = [ + """ +FOO: constant(uint256) = 3 +BAR: constant(uint256) = 5 +BAZ: constant(uint256) = 19 +BAX: constant(uint256) = uint256_addmod(FOO, BAR, BAZ) + """, + """ +FOO: constant(uint256) = 3 +BAR: constant(uint256) = 5 +BAZ: constant(uint256) = 19 +BAX: constant(uint256) = uint256_mulmod(FOO, BAR, BAZ) + """, +] + + +@pytest.mark.parametrize("code", valid_list) +def test_addmulmod_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_as_wei_value.py b/tests/functional/syntax/test_as_wei_value.py index a5232a5c9a..a68f30424e 100644 --- a/tests/functional/syntax/test_as_wei_value.py +++ b/tests/functional/syntax/test_as_wei_value.py @@ -1,13 +1,29 @@ import pytest -from vyper.exceptions import ArgumentException, InvalidType, StructureException +from vyper import compile_code +from vyper.exceptions import ( + ArgumentException, + InvalidLiteral, + InvalidType, + OverflowException, + StructureException, + UndeclaredDefinition, +) fail_list = [ ( """ @external def foo(): - x: int128 = as_wei_value(5, szabo) + x: uint256 = as_wei_value(5, szabo) + """, + UndeclaredDefinition, + ), + ( + """ +@external +def foo(): + x: uint256 = as_wei_value(5, "szaboo") """, ArgumentException, ), @@ -28,12 +44,50 @@ def foo(): """, InvalidType, ), + ( + """ +@external +def foo() -> uint256: + return as_wei_value( + 115792089237316195423570985008687907853269984665640564039457584007913129639937, + 'milliether' + ) + """, + OverflowException, + ), + ( + """ +@external +def foo(): + x: uint256 = as_wei_value(-1, "szabo") + """, + InvalidLiteral, + ), + ( + """ +FOO: constant(uint256) = as_wei_value(5, szabo) + """, + UndeclaredDefinition, + ), + ( + """ +FOO: constant(uint256) = as_wei_value(5, "szaboo") + """, + ArgumentException, + ), + ( + """ +FOO: constant(uint256) = as_wei_value(-1, "szabo") + """, + InvalidLiteral, + ), ] @pytest.mark.parametrize("bad_code,exc", fail_list) -def test_as_wei_fail(get_contract_with_gas_estimation, bad_code, exc, assert_compile_failed): - assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), exc) +def test_as_wei_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) valid_list = [ @@ -59,6 +113,14 @@ def foo() -> uint256: x: address = 0x1234567890123456789012345678901234567890 return x.balance """, + """ +y: constant(String[5]) = "szabo" +x: constant(uint256) = as_wei_value(5, y) + +@external +def foo(): + a: uint256 = x + """, ] diff --git a/tests/functional/syntax/test_ceil.py b/tests/functional/syntax/test_ceil.py new file mode 100644 index 0000000000..41f4175d01 --- /dev/null +++ b/tests/functional/syntax/test_ceil.py @@ -0,0 +1,19 @@ +import pytest + +from vyper import compile_code + +valid_list = [ + """ +BAR: constant(decimal) = 2.5 +FOO: constant(int256) = ceil(BAR) + +@external +def foo(): + a: int256 = FOO + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_ceil_good(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_dynamic_array.py b/tests/functional/syntax/test_dynamic_array.py index 99a01a17c8..f566a80625 100644 --- a/tests/functional/syntax/test_dynamic_array.py +++ b/tests/functional/syntax/test_dynamic_array.py @@ -1,6 +1,6 @@ import pytest -from vyper import compiler +from vyper import compile_code from vyper.exceptions import StructureException fail_list = [ @@ -24,12 +24,21 @@ def foo(): """, StructureException, ), + ( + """ +@external +def foo(): + a: DynArray[uint256, FOO] = [1, 2, 3] + """, + StructureException, + ), ] @pytest.mark.parametrize("bad_code,exc", fail_list) -def test_block_fail(assert_compile_failed, get_contract, bad_code, exc): - assert_compile_failed(lambda: get_contract(bad_code), exc) +def test_block_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) valid_list = [ @@ -48,4 +57,4 @@ def test_block_fail(assert_compile_failed, get_contract, bad_code, exc): @pytest.mark.parametrize("good_code", valid_list) def test_dynarray_pass(good_code): - assert compiler.compile_code(good_code) is not None + assert compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_epsilon.py b/tests/functional/syntax/test_epsilon.py new file mode 100644 index 0000000000..8fa5bc70ca --- /dev/null +++ b/tests/functional/syntax/test_epsilon.py @@ -0,0 +1,19 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidType + +fail_list = [ + ( + """ +FOO: constant(address) = epsilon(address) + """, + InvalidType, + ) +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_block_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) diff --git a/tests/functional/syntax/test_floor.py b/tests/functional/syntax/test_floor.py new file mode 100644 index 0000000000..5c30aecbe1 --- /dev/null +++ b/tests/functional/syntax/test_floor.py @@ -0,0 +1,19 @@ +import pytest + +from vyper import compile_code + +valid_list = [ + """ +BAR: constant(decimal) = 2.5 +FOO: constant(int256) = floor(BAR) + +@external +def foo(): + a: int256 = FOO + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_floor_good(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index 7c7f9c476d..99db0939fa 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -3,7 +3,12 @@ import pytest from vyper import compiler -from vyper.exceptions import ArgumentException, StateAccessViolation, StructureException +from vyper.exceptions import ( + ArgumentException, + StateAccessViolation, + StructureException, + TypeMismatch, +) fail_list = [ ( @@ -20,6 +25,17 @@ def foo(): ( """ @external +def bar(): + for i in range(1,2,bound=0): + pass + """, + StructureException, + "Bound must be at least 1", + "0", + ), + ( + """ +@external def foo(): x: uint256 = 100 for _ in range(10, bound=x): @@ -181,6 +197,44 @@ def foo(x: int128): "Value must be a literal integer, unless a bound is specified", "x", ), + ( + """ +@external +def bar(x: uint256): + for i in range(3, x): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +FOO: constant(int128) = 3 +BAR: constant(uint256) = 7 + +@external +def foo(): + for i in range(FOO, BAR): + pass + """, + TypeMismatch, + "Iterator values are of different types", + "range(FOO, BAR)", + ), + ( + """ +FOO: constant(int128) = -1 + +@external +def foo(): + for i in range(10, bound=FOO): + pass + """, + StructureException, + "Bound must be at least 1", + "FOO", + ), ] for_code_regex = re.compile(r"for .+ in (.*):") diff --git a/tests/functional/syntax/test_len.py b/tests/functional/syntax/test_len.py index bbde7e4897..b8cc61df1d 100644 --- a/tests/functional/syntax/test_len.py +++ b/tests/functional/syntax/test_len.py @@ -1,7 +1,6 @@ import pytest -from pytest import raises -from vyper import compiler +from vyper import compile_code from vyper.exceptions import TypeMismatch fail_list = [ @@ -21,11 +20,11 @@ def foo(inp: int128) -> uint256: @pytest.mark.parametrize("bad_code", fail_list) def test_block_fail(bad_code): if isinstance(bad_code, tuple): - with raises(bad_code[1]): - compiler.compile_code(bad_code[0]) + with pytest.raises(bad_code[1]): + compile_code(bad_code[0]) else: - with raises(TypeMismatch): - compiler.compile_code(bad_code) + with pytest.raises(TypeMismatch): + compile_code(bad_code) valid_list = [ @@ -39,9 +38,18 @@ def foo(inp: Bytes[10]) -> uint256: def foo(inp: String[10]) -> uint256: return len(inp) """, + """ +BAR: constant(String[5]) = "vyper" +FOO: constant(uint256) = len(BAR) + +@external +def foo() -> uint256: + a: uint256 = FOO + return a + """, ] @pytest.mark.parametrize("good_code", valid_list) def test_list_success(good_code): - assert compiler.compile_code(good_code) is not None + assert compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_method_id.py b/tests/functional/syntax/test_method_id.py new file mode 100644 index 0000000000..849c1b0d55 --- /dev/null +++ b/tests/functional/syntax/test_method_id.py @@ -0,0 +1,50 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidLiteral, InvalidType + +fail_list = [ + ( + """ +@external +def foo(): + a: Bytes[4] = method_id("bar ()") + """, + InvalidLiteral, + ), + ( + """ +FOO: constant(Bytes[4]) = method_id(1) + """, + InvalidType, + ), + ( + """ +FOO: constant(Bytes[4]) = method_id("bar ()") + """, + InvalidLiteral, + ), +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_method_id_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +valid_list = [ + """ +FOO: constant(String[5]) = "foo()" +BAR: constant(Bytes[4]) = method_id(FOO) + +@external +def foo(a: Bytes[4] = BAR): + pass + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_method_id_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_minmax.py b/tests/functional/syntax/test_minmax.py index 2ad3d363f1..78ee74635c 100644 --- a/tests/functional/syntax/test_minmax.py +++ b/tests/functional/syntax/test_minmax.py @@ -1,6 +1,7 @@ import pytest -from vyper.exceptions import InvalidType, TypeMismatch +from vyper import compile_code +from vyper.exceptions import InvalidType, OverflowException, TypeMismatch fail_list = [ ( @@ -19,9 +20,45 @@ def foo(): """, TypeMismatch, ), + ( + """ +@external +def foo(): + a: decimal = min(1.0, 18707220957835557353007165858768422651595.9365500928) + """, + OverflowException, + ), ] @pytest.mark.parametrize("bad_code,exc", fail_list) -def test_block_fail(assert_compile_failed, get_contract_with_gas_estimation, bad_code, exc): - assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), exc) +def test_block_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +valid_list = [ + """ +FOO: constant(uint256) = 123 +BAR: constant(uint256) = 456 +BAZ: constant(uint256) = min(FOO, BAR) + +@external +def foo(): + a: uint256 = BAZ + """, + """ +FOO: constant(uint256) = 123 +BAR: constant(uint256) = 456 +BAZ: constant(uint256) = max(FOO, BAR) + +@external +def foo(): + a: uint256 = BAZ + """, +] + + +@pytest.mark.parametrize("good_code", valid_list) +def test_block_success(good_code): + assert compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_minmax_value.py b/tests/functional/syntax/test_minmax_value.py index e154cad23f..8cc3370b42 100644 --- a/tests/functional/syntax/test_minmax_value.py +++ b/tests/functional/syntax/test_minmax_value.py @@ -1,21 +1,39 @@ import pytest +from vyper import compile_code from vyper.exceptions import InvalidType fail_list = [ - """ + ( + """ @external def foo(): a: address = min_value(address) """, - """ + InvalidType, + ), + ( + """ @external def foo(): a: address = max_value(address) """, + InvalidType, + ), + ( + """ +FOO: constant(address) = min_value(address) + +@external +def foo(): + a: address = FOO + """, + InvalidType, + ), ] -@pytest.mark.parametrize("bad_code", fail_list) -def test_block_fail(assert_compile_failed, get_contract_with_gas_estimation, bad_code): - assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), InvalidType) +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_block_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) diff --git a/tests/functional/syntax/test_powmod.py b/tests/functional/syntax/test_powmod.py new file mode 100644 index 0000000000..12ea23152c --- /dev/null +++ b/tests/functional/syntax/test_powmod.py @@ -0,0 +1,39 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidType + +fail_list = [ + ( + """ +@external +def foo(): + a: uint256 = pow_mod256(-1, -1) + """, + InvalidType, + ) +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_powmod_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +valid_list = [ + """ +FOO: constant(uint256) = 3 +BAR: constant(uint256) = 5 +BAZ: constant(uint256) = pow_mod256(FOO, BAR) + +@external +def foo(): + a: uint256 = BAZ + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_powmod_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_raw_call.py b/tests/functional/syntax/test_raw_call.py index b1286e7a8e..c0b38d1d1e 100644 --- a/tests/functional/syntax/test_raw_call.py +++ b/tests/functional/syntax/test_raw_call.py @@ -1,6 +1,6 @@ import pytest -from vyper import compiler +from vyper import compile_code from vyper.exceptions import ArgumentException, InvalidType, SyntaxException, TypeMismatch fail_list = [ @@ -39,7 +39,7 @@ def foo(): @pytest.mark.parametrize("bad_code,exc", fail_list) def test_raw_call_fail(bad_code, exc): with pytest.raises(exc): - compiler.compile_code(bad_code) + compile_code(bad_code) valid_list = [ @@ -90,9 +90,23 @@ def foo(): value=self.balance - self.balances[0] ) """, + # test constants + """ +OUTSIZE: constant(uint256) = 4 +REVERT_ON_FAILURE: constant(bool) = True +@external +def foo(): + x: Bytes[9] = raw_call( + 0x1234567890123456789012345678901234567890, + b"cow", + max_outsize=OUTSIZE, + gas=595757, + revert_on_failure=REVERT_ON_FAILURE + ) + """, ] @pytest.mark.parametrize("good_code", valid_list) def test_raw_call_success(good_code): - assert compiler.compile_code(good_code) is not None + assert compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_ternary.py b/tests/functional/syntax/test_ternary.py index 325be3e43b..6a2bb9c072 100644 --- a/tests/functional/syntax/test_ternary.py +++ b/tests/functional/syntax/test_ternary.py @@ -1,6 +1,6 @@ import pytest -from vyper.compiler import compile_code +from vyper import compile_code from vyper.exceptions import InvalidType, TypeMismatch good_list = [ @@ -82,7 +82,7 @@ def foo() -> uint256: def foo() -> uint256: return 1 if TEST else 2 """, - InvalidType, + TypeMismatch, ), ( # bad test type: variable """ diff --git a/tests/functional/syntax/test_uint2str.py b/tests/functional/syntax/test_uint2str.py new file mode 100644 index 0000000000..9e6dde30cc --- /dev/null +++ b/tests/functional/syntax/test_uint2str.py @@ -0,0 +1,19 @@ +import pytest + +from vyper import compile_code + +valid_list = [ + """ +FOO: constant(uint256) = 3 +BAR: constant(String[78]) = uint2str(FOO) + +@external +def foo(): + a: String[78] = BAR + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_addmulmod_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_unary.py b/tests/functional/syntax/test_unary.py new file mode 100644 index 0000000000..5942ee15db --- /dev/null +++ b/tests/functional/syntax/test_unary.py @@ -0,0 +1,21 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidType + +fail_list = [ + ( + """ +@external +def foo() -> int128: + return -2**127 + """, + InvalidType, + ) +] + + +@pytest.mark.parametrize("code,exc", fail_list) +def test_unary_fail(code, exc): + with pytest.raises(exc): + compile_code(code) diff --git a/tests/unit/ast/nodes/test_evaluate_binop_decimal.py b/tests/unit/ast/nodes/test_fold_binop_decimal.py similarity index 97% rename from tests/unit/ast/nodes/test_evaluate_binop_decimal.py rename to tests/unit/ast/nodes/test_fold_binop_decimal.py index 44b82e321d..0a586e1704 100644 --- a/tests/unit/ast/nodes/test_evaluate_binop_decimal.py +++ b/tests/unit/ast/nodes/test_fold_binop_decimal.py @@ -31,7 +31,7 @@ def foo(a: decimal, b: decimal) -> decimal: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.fold() is_valid = True except ZeroDivisionException: is_valid = False @@ -49,7 +49,7 @@ def test_binop_pow(): old_node = vyper_ast.body[0].value with pytest.raises(TypeMismatch): - old_node.evaluate() + old_node.fold() @pytest.mark.fuzzing diff --git a/tests/unit/ast/nodes/test_evaluate_binop_int.py b/tests/unit/ast/nodes/test_fold_binop_int.py similarity index 96% rename from tests/unit/ast/nodes/test_evaluate_binop_int.py rename to tests/unit/ast/nodes/test_fold_binop_int.py index 405d557f7d..c603daee46 100644 --- a/tests/unit/ast/nodes/test_evaluate_binop_int.py +++ b/tests/unit/ast/nodes/test_fold_binop_int.py @@ -27,7 +27,7 @@ def foo(a: int128, b: int128) -> int128: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.fold() is_valid = True except ZeroDivisionException: is_valid = False @@ -57,7 +57,7 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.fold() is_valid = new_node.value >= 0 except ZeroDivisionException: is_valid = False @@ -85,7 +85,7 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{left} ** {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.fold() assert contract.foo(left, right) == new_node.value diff --git a/tests/unit/ast/nodes/test_evaluate_boolop.py b/tests/unit/ast/nodes/test_fold_boolop.py similarity index 98% rename from tests/unit/ast/nodes/test_evaluate_boolop.py rename to tests/unit/ast/nodes/test_fold_boolop.py index 8b70537c39..5de4b60bda 100644 --- a/tests/unit/ast/nodes/test_evaluate_boolop.py +++ b/tests/unit/ast/nodes/test_fold_boolop.py @@ -26,7 +26,7 @@ def foo({input_value}) -> bool: vyper_ast = vy_ast.parse_to_ast(literal_op) old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.fold() assert contract.foo(*values) == new_node.value diff --git a/tests/unit/ast/nodes/test_evaluate_compare.py b/tests/unit/ast/nodes/test_fold_compare.py similarity index 95% rename from tests/unit/ast/nodes/test_evaluate_compare.py rename to tests/unit/ast/nodes/test_fold_compare.py index 07f8e70de6..735be43cfd 100644 --- a/tests/unit/ast/nodes/test_evaluate_compare.py +++ b/tests/unit/ast/nodes/test_fold_compare.py @@ -21,7 +21,7 @@ def foo(a: int128, b: int128) -> bool: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.fold() assert contract.foo(left, right) == new_node.value @@ -41,7 +41,7 @@ def foo(a: uint128, b: uint128) -> bool: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.fold() assert contract.foo(left, right) == new_node.value @@ -65,7 +65,7 @@ def bar(a: int128) -> bool: vyper_ast = vy_ast.parse_to_ast(f"{left} in {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.fold() # check runtime == fully folded assert contract.foo(left, right) == new_node.value @@ -94,7 +94,7 @@ def bar(a: int128) -> bool: vyper_ast = vy_ast.parse_to_ast(f"{left} not in {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.fold() # check runtime == fully folded assert contract.foo(left, right) == new_node.value @@ -109,4 +109,4 @@ def test_compare_type_mismatch(op): vyper_ast = vy_ast.parse_to_ast(f"1 {op} 1.0") old_node = vyper_ast.body[0].value with pytest.raises(UnfoldableNode): - old_node.evaluate() + old_node.fold() diff --git a/tests/unit/ast/nodes/test_evaluate_subscript.py b/tests/unit/ast/nodes/test_fold_subscript.py similarity index 94% rename from tests/unit/ast/nodes/test_evaluate_subscript.py rename to tests/unit/ast/nodes/test_fold_subscript.py index ca50a076a5..59a5725b2c 100644 --- a/tests/unit/ast/nodes/test_evaluate_subscript.py +++ b/tests/unit/ast/nodes/test_fold_subscript.py @@ -21,6 +21,6 @@ def foo(array: int128[10], idx: uint256) -> int128: vyper_ast = vy_ast.parse_to_ast(f"{array}[{idx}]") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.fold() assert contract.foo(array, idx) == new_node.value diff --git a/tests/unit/ast/nodes/test_evaluate_unaryop.py b/tests/unit/ast/nodes/test_fold_unaryop.py similarity index 96% rename from tests/unit/ast/nodes/test_evaluate_unaryop.py rename to tests/unit/ast/nodes/test_fold_unaryop.py index 63d7a0b7ff..dc447955ed 100644 --- a/tests/unit/ast/nodes/test_evaluate_unaryop.py +++ b/tests/unit/ast/nodes/test_fold_unaryop.py @@ -14,7 +14,7 @@ def foo(a: bool) -> bool: vyper_ast = vy_ast.parse_to_ast(f"not {bool_cond}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.fold() assert contract.foo(bool_cond) == new_node.value diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index dc49f72561..2fbfb73ccf 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -69,12 +69,14 @@ def test_basic_ast(): "lineno": 2, "node_id": 2, "src": "1:1:0", + "type": "int128", }, "value": None, "is_constant": False, "is_immutable": False, "is_public": False, "is_transient": False, + "type": "int128", } diff --git a/tests/unit/ast/test_folding.py b/tests/unit/ast/test_folding.py index 62a7140e97..8347fa90dd 100644 --- a/tests/unit/ast/test_folding.py +++ b/tests/unit/ast/test_folding.py @@ -2,194 +2,408 @@ from vyper import ast as vy_ast from vyper.ast import folding -from vyper.exceptions import OverflowException +from vyper.exceptions import InvalidType, OverflowException +from vyper.semantics import validate_semantics -def test_integration(): - test_ast = vy_ast.parse_to_ast("[1+2, 6+7][8-8]") - expected_ast = vy_ast.parse_to_ast("3") +def test_integration(dummy_input_bundle): + test = """ +@external +def foo(): + a: uint256 = [1+2, 6+7][8-8] + """ + + expected = """ +@external +def foo(): + a: uint256 = 3 + """ + test_ast = vy_ast.parse_to_ast(test) + expected_ast = vy_ast.parse_to_ast(expected) + + validate_semantics(test_ast, dummy_input_bundle) folding.fold(test_ast) assert vy_ast.compare_nodes(test_ast, expected_ast) -def test_replace_binop_simple(): - test_ast = vy_ast.parse_to_ast("1 + 2") - expected_ast = vy_ast.parse_to_ast("3") +def test_replace_binop_simple(dummy_input_bundle): + test = """ +@external +def foo(): + a: uint256 = 1 + 2 + """ + expected = """ +@external +def foo(): + a: uint256 = 3 + """ + + test_ast = vy_ast.parse_to_ast(test) + expected_ast = vy_ast.parse_to_ast(expected) + + validate_semantics(test_ast, dummy_input_bundle) folding.replace_literal_ops(test_ast) assert vy_ast.compare_nodes(test_ast, expected_ast) -def test_replace_binop_nested(): - test_ast = vy_ast.parse_to_ast("((6 + (2**4)) * 4) / 2") - expected_ast = vy_ast.parse_to_ast("44") +def test_replace_binop_nested(dummy_input_bundle): + test = """ +@external +def foo(): + a: uint256 = ((6 + (2**4)) * 4) / 2 + """ + + expected = """ +@external +def foo(): + a: uint256 = 44 + """ + test_ast = vy_ast.parse_to_ast(test) + expected_ast = vy_ast.parse_to_ast(expected) + validate_semantics(test_ast, dummy_input_bundle) folding.replace_literal_ops(test_ast) assert vy_ast.compare_nodes(test_ast, expected_ast) -def test_replace_binop_nested_intermediate_overflow(): - test_ast = vy_ast.parse_to_ast("2**255 * 2 / 10") +def test_replace_binop_nested_intermediate_overflow(dummy_input_bundle): + test = """ +@external +def foo(): + a: uint256 = 2**255 * 2 / 10 + """ + test_ast = vy_ast.parse_to_ast(test) with pytest.raises(OverflowException): - folding.fold(test_ast) + validate_semantics(test_ast, dummy_input_bundle) -def test_replace_binop_nested_intermediate_underflow(): - test_ast = vy_ast.parse_to_ast("-2**255 * 2 - 10 + 100") - with pytest.raises(OverflowException): - folding.fold(test_ast) +def test_replace_binop_nested_intermediate_underflow(dummy_input_bundle): + test = """ +@external +def foo(): + a: int256 = -2**255 * 2 - 10 + 100 + """ + test_ast = vy_ast.parse_to_ast(test) + with pytest.raises(InvalidType): + validate_semantics(test_ast, dummy_input_bundle) -def test_replace_decimal_nested_intermediate_overflow(): - test_ast = vy_ast.parse_to_ast( - "18707220957835557353007165858768422651595.9365500927 + 1e-10 - 1e-10" - ) +def test_replace_decimal_nested_intermediate_overflow(dummy_input_bundle): + test = """ +@external +def foo(): + a: decimal = 18707220957835557353007165858768422651595.9365500927 + 1e-10 - 1e-10 + """ + test_ast = vy_ast.parse_to_ast(test) with pytest.raises(OverflowException): - folding.fold(test_ast) + validate_semantics(test_ast, dummy_input_bundle) -def test_replace_decimal_nested_intermediate_underflow(): - test_ast = vy_ast.parse_to_ast( - "-18707220957835557353007165858768422651595.9365500928 - 1e-10 + 1e-10" - ) +def test_replace_decimal_nested_intermediate_underflow(dummy_input_bundle): + test = """ +@external +def foo(): + a: decimal = -18707220957835557353007165858768422651595.9365500928 - 1e-10 + 1e-10 + """ + test_ast = vy_ast.parse_to_ast(test) with pytest.raises(OverflowException): - folding.fold(test_ast) + validate_semantics(test_ast, dummy_input_bundle) -def test_replace_literal_ops(): - test_ast = vy_ast.parse_to_ast("[not True, True and False, True or False]") - expected_ast = vy_ast.parse_to_ast("[False, False, True]") +def test_replace_literal_ops(dummy_input_bundle): + test = """ +@external +def foo(): + a: bool[3] = [not True, True and False, True or False] + """ + + expected = """ +@external +def foo(): + a: bool[3] = [False, False, True] + """ + test_ast = vy_ast.parse_to_ast(test) + expected_ast = vy_ast.parse_to_ast(expected) + validate_semantics(test_ast, dummy_input_bundle) folding.replace_literal_ops(test_ast) assert vy_ast.compare_nodes(test_ast, expected_ast) -def test_replace_subscripts_simple(): - test_ast = vy_ast.parse_to_ast("[foo, bar, baz][1]") - expected_ast = vy_ast.parse_to_ast("bar") +def test_replace_subscripts_simple(dummy_input_bundle): + test = """ +@external +def foo(): + a: uint256 = [1, 2, 3][1] + """ + + expected = """ +@external +def foo(): + a: uint256 = 2 + """ + test_ast = vy_ast.parse_to_ast(test) + expected_ast = vy_ast.parse_to_ast(expected) + validate_semantics(test_ast, dummy_input_bundle) folding.replace_subscripts(test_ast) assert vy_ast.compare_nodes(test_ast, expected_ast) -def test_replace_subscripts_nested(): - test_ast = vy_ast.parse_to_ast("[[0, 1], [2, 3], [4, 5]][2][1]") - expected_ast = vy_ast.parse_to_ast("5") +def test_replace_subscripts_nested(dummy_input_bundle): + test = """ +@external +def foo(): + a: uint256 = [[0, 1], [2, 3], [4, 5]][2][1] + """ + + expected = """ +@external +def foo(): + a: uint256 = 5 + """ + test_ast = vy_ast.parse_to_ast(test) + expected_ast = vy_ast.parse_to_ast(expected) + validate_semantics(test_ast, dummy_input_bundle) folding.replace_subscripts(test_ast) assert vy_ast.compare_nodes(test_ast, expected_ast) constants_modified = [ - "bar = FOO", - "bar: int128[FOO]", - "[1, 2, FOO]", - "def bar(a: int128 = FOO): pass", - "log bar(FOO)", - "FOO + 1", - "a: int128[FOO / 2]", - "a[FOO - 1] = 44", + """ +FOO: constant(uint256) = 4 + +@external +def foo(): + bar: uint256 = 1 + bar = FOO + """, + """ +FOO: constant(uint256) = 4 +bar: int128[FOO] + """, + """ +FOO: constant(uint256) = 4 + +@external +def foo(): + a: uint256[3] = [1, 2, FOO] + """, + """ +FOO: constant(uint256) = 4 +@external +def bar(a: uint256 = FOO): + pass + """, + """ +FOO: constant(uint256) = 4 + +event bar: + a: uint256 + +@external +def foo(): + log bar(FOO) + """, + """ +FOO: constant(uint256) = 4 + +@external +def foo(): + a: uint256 = FOO + 1 + """, + """ +FOO: constant(uint256) = 4 + +@external +def foo(): + a: int128[FOO / 2] = [1, 2] + """, + """ +FOO: constant(uint256) = 4 + +@external +def bar(x: DynArray[uint256, 4]): + a: DynArray[uint256, 4] = x + a[FOO - 1] = 44 + """, ] @pytest.mark.parametrize("source", constants_modified) -def test_replace_constant(source): +def test_replace_constant(dummy_input_bundle, source): unmodified_ast = vy_ast.parse_to_ast(source) folded_ast = vy_ast.parse_to_ast(source) - folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), True) + validate_semantics(folded_ast, dummy_input_bundle) + folding.replace_user_defined_constants(folded_ast) assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) constants_unmodified = [ - "FOO = 42", - "self.FOO = 42", - "bar = FOO()", - "FOO()", - "bar = FOO()", - "bar = self.FOO", - "log FOO(bar)", - "[1, 2, FOO()]", - "FOO[42] = 2", -] + """ +FOO: immutable(uint256) +@external +def __init__(): + FOO = 42 + """, + """ +FOO: uint256 -@pytest.mark.parametrize("source", constants_unmodified) -def test_replace_constant_no(source): - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) +@external +def foo(): + self.FOO = 42 + """, + """ +bar: uint256 - folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), True) +@internal +def FOO() -> uint256: + return 123 - assert vy_ast.compare_nodes(unmodified_ast, folded_ast) +@external +def foo(): + bar: uint256 = 456 + bar = self.FOO() + """, + """ +@internal +def FOO(): + pass + +@external +def foo(): + self.FOO() + """, + """ +FOO: uint256 +@external +def foo(): + bar: uint256 = 1 + bar = self.FOO + """, + """ +event FOO: + a: uint256 -userdefined_modified = [ - "FOO", - "foo = FOO", - "foo: int128[FOO] = 42", - "foo = [FOO]", - "foo += FOO", - "def foo(bar: int128 = FOO): pass", - "def foo(): bar = FOO", - "def foo(): return FOO", +@external +def foo(bar: uint256): + log FOO(bar) + """, + """ +@internal +def FOO() -> uint256: + return 3 + +@external +def foo(): + a: uint256[3] = [1, 2, self.FOO()] + """, + """ +@external +def foo(): + FOO: DynArray[uint256, 5] = [1, 2, 3, 4, 5] + FOO[4] = 2 + """, ] -@pytest.mark.parametrize("source", userdefined_modified) -def test_replace_userdefined_constant(source): - source = f"FOO: constant(int128) = 42\n{source}" - +@pytest.mark.parametrize("source", constants_unmodified) +def test_replace_constant_no(dummy_input_bundle, source): unmodified_ast = vy_ast.parse_to_ast(source) folded_ast = vy_ast.parse_to_ast(source) + validate_semantics(folded_ast, dummy_input_bundle) folding.replace_user_defined_constants(folded_ast) - assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) + assert vy_ast.compare_nodes(unmodified_ast, folded_ast) -userdefined_unmodified = [ - "FOO: constant(int128) = 42", - "FOO = 42", - "FOO += 42", - "FOO()", - "def foo(FOO: int128 = 42): pass", - "def foo(): FOO = 42", - "def FOO(): pass", +userdefined_modified = [ + """ +@external +def foo(): + foo: int128 = FOO + """, + """ +@external +def foo(): + foo: DynArray[int128, FOO] = [] + """, + """ +@external +def foo(): + foo: int128[1] = [FOO] + """, + """ +@external +def foo(): + foo: int128 = 3 + foo += FOO + """, + """ +@external +def foo(bar: int128 = FOO): + pass + """, + """ +@external +def foo() -> int128: + return FOO + """, ] -@pytest.mark.parametrize("source", userdefined_unmodified) -def test_replace_userdefined_constant_no(source): +@pytest.mark.parametrize("source", userdefined_modified) +def test_replace_userdefined_constant(dummy_input_bundle, source): source = f"FOO: constant(int128) = 42\n{source}" unmodified_ast = vy_ast.parse_to_ast(source) folded_ast = vy_ast.parse_to_ast(source) + validate_semantics(folded_ast, dummy_input_bundle) folding.replace_user_defined_constants(folded_ast) - assert vy_ast.compare_nodes(unmodified_ast, folded_ast) + assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) dummy_address = "0x000000000000000000000000000000000000dEaD" -userdefined_attributes = [("b: uint256 = ADDR.balance", f"b: uint256 = {dummy_address}.balance")] +userdefined_attributes = [ + ( + """ +@external +def foo(): + b: uint256 = ADDR.balance + """, + f""" +@external +def foo(): + b: uint256 = {dummy_address}.balance + """, + ) +] @pytest.mark.parametrize("source", userdefined_attributes) -def test_replace_userdefined_attribute(source): +def test_replace_userdefined_attribute(dummy_input_bundle, source): preamble = f"ADDR: constant(address) = {dummy_address}" l_source = f"{preamble}\n{source[0]}" r_source = f"{preamble}\n{source[1]}" l_ast = vy_ast.parse_to_ast(l_source) + validate_semantics(l_ast, dummy_input_bundle) folding.replace_user_defined_constants(l_ast) r_ast = vy_ast.parse_to_ast(r_source) @@ -197,11 +411,24 @@ def test_replace_userdefined_attribute(source): assert vy_ast.compare_nodes(l_ast, r_ast) -userdefined_struct = [("b: Foo = FOO", "b: Foo = Foo({a: 123, b: 456})")] +userdefined_struct = [ + ( + """ +@external +def foo(): + b: Foo = FOO + """, + """ +@external +def foo(): + b: Foo = Foo({a: 123, b: 456}) + """, + ) +] @pytest.mark.parametrize("source", userdefined_struct) -def test_replace_userdefined_struct(source): +def test_replace_userdefined_struct(dummy_input_bundle, source): preamble = """ struct Foo: a: uint256 @@ -213,6 +440,7 @@ def test_replace_userdefined_struct(source): r_source = f"{preamble}\n{source[1]}" l_ast = vy_ast.parse_to_ast(l_source) + validate_semantics(l_ast, dummy_input_bundle) folding.replace_user_defined_constants(l_ast) r_ast = vy_ast.parse_to_ast(r_source) @@ -221,12 +449,23 @@ def test_replace_userdefined_struct(source): userdefined_nested_struct = [ - ("b: Foo = FOO", "b: Foo = Foo({f1: Bar({b1: 123, b2: 456}), f2: 789})") + ( + """ +@external +def foo(): + b: Foo = FOO + """, + """ +@external +def foo(): + b: Foo = Foo({f1: Bar({b1: 123, b2: 456}), f2: 789}) + """, + ) ] @pytest.mark.parametrize("source", userdefined_nested_struct) -def test_replace_userdefined_nested_struct(source): +def test_replace_userdefined_nested_struct(dummy_input_bundle, source): preamble = """ struct Bar: b1: uint256 @@ -242,6 +481,7 @@ def test_replace_userdefined_nested_struct(source): r_source = f"{preamble}\n{source[1]}" l_ast = vy_ast.parse_to_ast(l_source) + validate_semantics(l_ast, dummy_input_bundle) folding.replace_user_defined_constants(l_ast) r_ast = vy_ast.parse_to_ast(r_source) @@ -252,21 +492,34 @@ def test_replace_userdefined_nested_struct(source): builtin_folding_functions = [("ceil(4.2)", "5"), ("floor(4.2)", "4")] builtin_folding_sources = [ - "{}", - "foo = {}", - "foo = [{0}, {0}]", - "def foo(): {}", - "def foo(): return {}", - "def foo(bar: {}): pass", + """ +@external +def foo(): + foo: int256 = {} + """, + """ +foo: constant(int256[2]) = [{0}, {0}] + """, + """ +@external +def foo() -> int256: + return {} + """, + """ +@external +def foo(bar: int256 = {}): + pass + """, ] @pytest.mark.parametrize("source", builtin_folding_sources) @pytest.mark.parametrize("original,result", builtin_folding_functions) -def test_replace_builtins(source, original, result): +def test_replace_builtins(dummy_input_bundle, source, original, result): original_ast = vy_ast.parse_to_ast(source.format(original)) target_ast = vy_ast.parse_to_ast(source.format(result)) + validate_semantics(original_ast, dummy_input_bundle) folding.replace_builtin_functions(original_ast) assert vy_ast.compare_nodes(original_ast, target_ast) diff --git a/vyper/ast/README.md b/vyper/ast/README.md index 320c69da0c..9979b60cab 100644 --- a/vyper/ast/README.md +++ b/vyper/ast/README.md @@ -82,8 +82,7 @@ folding include: The process of literal folding includes: -1. Foldable node classes are evaluated via their `evaluate` method, which attempts -to create a new `Constant` from the content of the given node. +1. Foldable node classes are evaluated via their `fold` method, which attempts to create a new `Constant` from the content of the given node. 2. Replacement nodes are generated using the `from_node` class method within the new node class. 3. The modification of the tree is handled by `Module.replace_in_tree`, which locates diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index 087708a356..3eb3e163b1 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -1,10 +1,9 @@ -from typing import Optional, Union +from typing import Union from vyper.ast import nodes as vy_ast from vyper.builtins.functions import DISPATCH_TABLE -from vyper.exceptions import UnfoldableNode, UnknownType +from vyper.exceptions import UnfoldableNode from vyper.semantics.types.base import VyperType -from vyper.semantics.types.utils import type_from_annotation def fold(vyper_module: vy_ast.Module) -> None: @@ -45,10 +44,16 @@ def replace_literal_ops(vyper_module: vy_ast.Module) -> int: node_types = (vy_ast.BoolOp, vy_ast.BinOp, vy_ast.UnaryOp, vy_ast.Compare) for node in vyper_module.get_descendants(node_types, reverse=True): try: - new_node = node.evaluate() + new_node = node.fold() except UnfoldableNode: continue + # type may not be available if it is within a type's annotation + # e.g. DynArray[uint256, 2 ** 8] + typ = node._metadata.get("type") + if typ: + new_node._metadata["type"] = node._metadata["type"] + changed_nodes += 1 vyper_module.replace_in_tree(node, new_node) @@ -74,10 +79,12 @@ def replace_subscripts(vyper_module: vy_ast.Module) -> int: for node in vyper_module.get_descendants(vy_ast.Subscript, reverse=True): try: - new_node = node.evaluate() + new_node = node.fold() except UnfoldableNode: continue + new_node._metadata["type"] = node._metadata["type"] + changed_nodes += 1 vyper_module.replace_in_tree(node, new_node) @@ -107,13 +114,16 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int: name = node.func.id func = DISPATCH_TABLE.get(name) - if func is None or not hasattr(func, "evaluate"): + if func is None or not hasattr(func, "fold"): continue try: - new_node = func.evaluate(node) # type: ignore + new_node = func.fold(node) # type: ignore except UnfoldableNode: continue + if "type" in node._metadata: + new_node._metadata["type"] = node._metadata["type"] + changed_nodes += 1 vyper_module.replace_in_tree(node, new_node) @@ -138,26 +148,14 @@ def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int: changed_nodes = 0 for node in vyper_module.get_children(vy_ast.VariableDecl): - if not isinstance(node.target, vy_ast.Name): - # left-hand-side of assignment is not a variable - continue if not node.is_constant: # annotation is not wrapped in `constant(...)` continue # Extract type definition from propagated annotation - type_ = None - try: - type_ = type_from_annotation(node.annotation) - except UnknownType: - # handle user-defined types e.g. structs - it's OK to not - # propagate the type annotation here because user-defined - # types can be unambiguously inferred at typechecking time - pass - - changed_nodes += replace_constant( - vyper_module, node.target.id, node.value, False, type_=type_ - ) + type_ = node._metadata["type"] + + changed_nodes += replace_constant(vyper_module, node.target.id, node.value, type_, False) return changed_nodes @@ -165,18 +163,16 @@ def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int: # TODO constant folding on log events -def _replace(old_node, new_node, type_=None): +def _replace(old_node, new_node, type_): if isinstance(new_node, vy_ast.Constant): new_node = new_node.from_node(old_node, value=new_node.value) - if type_: - new_node._metadata["type"] = type_ + new_node._metadata["type"] = type_ return new_node elif isinstance(new_node, vy_ast.List): - base_type = type_.value_type if type_ else None + base_type = type_.value_type list_values = [_replace(old_node, i, type_=base_type) for i in new_node.elements] new_node = new_node.from_node(old_node, elements=list_values) - if type_: - new_node._metadata["type"] = type_ + new_node._metadata["type"] = type_ return new_node elif isinstance(new_node, vy_ast.Call): # Replace `Name` node with `Call` node @@ -188,6 +184,7 @@ def _replace(old_node, new_node, type_=None): new_node = new_node.from_node( old_node, func=new_node.func, args=new_node.args, keyword=keyword, keywords=keywords ) + new_node._metadata["type"] = type_ return new_node else: raise UnfoldableNode @@ -197,8 +194,8 @@ def replace_constant( vyper_module: vy_ast.Module, id_: str, replacement_node: Union[vy_ast.Constant, vy_ast.List, vy_ast.Call], + type_: VyperType, raise_on_error: bool, - type_: Optional[VyperType] = None, ) -> int: """ Replace references to a variable name with a literal value. diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index dba9f2a22d..25da0714ee 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -210,23 +210,6 @@ def _raise_syntax_exc(error_msg: str, ast_struct: dict) -> None: ) -def _validate_numeric_bounds( - node: Union["BinOp", "UnaryOp"], value: Union[decimal.Decimal, int] -) -> None: - if isinstance(value, decimal.Decimal): - # this will change if/when we add more decimal types - lower, upper = SizeLimits.MIN_AST_DECIMAL, SizeLimits.MAX_AST_DECIMAL - elif isinstance(value, int): - lower, upper = SizeLimits.MIN_INT256, SizeLimits.MAX_UINT256 - else: - raise CompilerPanic(f"Unexpected return type from {node._op}: {type(value)}") - if not lower <= value <= upper: - raise OverflowException( - f"Result of {node.op.description} ({value}) is outside bounds of all numeric types", - node, - ) - - class VyperNode: """ Base class for all vyper AST nodes. @@ -242,11 +225,13 @@ class VyperNode: _description : str, optional A human-readable description of the node. Used to give more verbose error messages. + _is_prefoldable : str, optional + If `True`, indicates that pre-folding should be attempted on the node. _only_empty_fields : Tuple, optional Field names that, if present, must be set to None or a `SyntaxException` is raised. This attribute is used to exclude syntax that is valid in Python but not in Vyper. - _terminus : bool, optional + _is_terminus : bool, optional If `True`, indicates that execution halts upon reaching this node. _translated_fields : Dict, optional Field names that are reassigned if encountered. Used to normalize fields @@ -390,15 +375,36 @@ def description(self): """ return getattr(self, "_description", type(self).__name__) - def evaluate(self) -> "VyperNode": + def get_folded_value_throwing(self) -> "VyperNode": + """ + Attempt to get the folded value and cache it on `_metadata["folded_value"]`. + Raises UnfoldableNode if not. + """ + if "folded_value" not in self._metadata: + self._metadata["folded_value"] = self.fold() + return self._metadata["folded_value"] + + def get_folded_value_maybe(self) -> Optional["VyperNode"]: + """ + Attempt to get the folded value and cache it on `_metadata["folded_value"]`. + Returns None if not. + """ + if "folded_value" not in self._metadata: + try: + self._metadata["folded_value"] = self.fold() + except (UnfoldableNode, VyperException): + return None + return self._metadata["folded_value"] + + def fold(self) -> "VyperNode": """ Attempt to evaluate the content of a node and generate a new node from it. - If a node cannot be evaluated it should raise `UnfoldableNode`. This base + If a node cannot be evaluated, it should raise `UnfoldableNode`. This base method acts as a catch-all to raise on any inherited classes that do not implement the method. """ - raise UnfoldableNode(f"{type(self)} cannot be evaluated") + raise UnfoldableNode(f"{type(self)} cannot be folded") def validate(self) -> None: """ @@ -769,6 +775,15 @@ class Constant(ExprNode): # inherited class for all simple constant node types __slots__ = ("value",) + def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): + super().__init__(parent, **kwargs) + + def get_folded_value_throwing(self) -> "VyperNode": + return self + + def get_folded_value_maybe(self) -> Optional["VyperNode"]: + return self + class Num(Constant): # inherited class for all numeric constant node types @@ -901,19 +916,65 @@ def s(self): return self.value +def check_literal(node: VyperNode) -> bool: + """ + Check if the given node is a literal value. + """ + if isinstance(node, Constant): + return True + elif isinstance(node, (Tuple, List)): + return all(check_literal(item) for item in node.elements) + + return False + + class List(ExprNode): __slots__ = ("elements",) + _is_prefoldable = True _translated_fields = {"elts": "elements"} + def fold(self) -> Optional[ExprNode]: + elements = [e.get_folded_value_throwing() for e in self.elements] + return type(self).from_node(self, elements=elements) + + def get_folded_value_throwing(self) -> "VyperNode": + if check_literal(self): + return self + + return super().get_folded_value_throwing() + + def get_folded_value_maybe(self) -> Optional["VyperNode"]: + if check_literal(self): + return self + + return super().get_folded_value_maybe() + class Tuple(ExprNode): __slots__ = ("elements",) + _is_prefoldable = True _translated_fields = {"elts": "elements"} def validate(self): if not self.elements: raise InvalidLiteral("Cannot have an empty tuple", self) + def fold(self) -> Optional[ExprNode]: + elements = [e.get_folded_value_throwing() for e in self.elements] + return type(self).from_node(self, elements=elements) + + def get_folded_value_throwing(self) -> "VyperNode": + if check_literal(self): + return self + + return super().get_folded_value_throwing() + + def get_folded_value_maybe(self) -> Optional["VyperNode"]: + if check_literal(self): + return self + + return super().get_folded_value_maybe() + class NameConstant(Constant): __slots__ = () @@ -933,8 +994,9 @@ class Name(ExprNode): class UnaryOp(ExprNode): __slots__ = ("op", "operand") + _is_prefoldable = True - def evaluate(self) -> ExprNode: + def fold(self) -> ExprNode: """ Attempt to evaluate the unary operation. @@ -943,16 +1005,17 @@ def evaluate(self) -> ExprNode: Int | Decimal Node representing the result of the evaluation. """ - if isinstance(self.op, Not) and not isinstance(self.operand, NameConstant): + operand = self.operand.get_folded_value_throwing() + + if isinstance(self.op, Not) and not isinstance(operand, NameConstant): raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if isinstance(self.op, USub) and not isinstance(self.operand, (Int, Decimal)): + if isinstance(self.op, USub) and not isinstance(operand, (Int, Decimal)): raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if isinstance(self.op, Invert) and not isinstance(self.operand, Int): + if isinstance(self.op, Invert) and not isinstance(operand, Int): raise UnfoldableNode("Node contains invalid field(s) for evaluation") - value = self.op._op(self.operand.value) - _validate_numeric_bounds(self, value) - return type(self.operand).from_node(self, value=value) + value = self.op._op(operand.value) + return type(operand).from_node(self, value=value) class Operator(VyperNode): @@ -981,8 +1044,9 @@ def _op(self, value): class BinOp(ExprNode): __slots__ = ("left", "op", "right") + _is_prefoldable = True - def evaluate(self) -> ExprNode: + def fold(self) -> ExprNode: """ Attempt to evaluate the arithmetic operation. @@ -991,7 +1055,7 @@ def evaluate(self) -> ExprNode: Int | Decimal Node representing the result of the evaluation. """ - left, right = self.left, self.right + left, right = [i.get_folded_value_throwing() for i in (self.left, self.right)] if type(left) is not type(right): raise UnfoldableNode("Node contains invalid field(s) for evaluation") if not isinstance(left, (Int, Decimal)): @@ -1001,10 +1065,9 @@ def evaluate(self) -> ExprNode: # on very large shifts and improve the error message for negative # values. if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256): - raise InvalidLiteral("Shift bits must be between 0 and 256", right) + raise InvalidLiteral("Shift bits must be between 0 and 256", self.right) value = self.op._op(left.value, right.value) - _validate_numeric_bounds(self, value) return type(left).from_node(self, value=value) @@ -1131,8 +1194,9 @@ class RShift(Operator): class BoolOp(ExprNode): __slots__ = ("op", "values") + _is_prefoldable = True - def evaluate(self) -> ExprNode: + def fold(self) -> ExprNode: """ Attempt to evaluate the boolean operation. @@ -1141,13 +1205,12 @@ def evaluate(self) -> ExprNode: NameConstant Node representing the result of the evaluation. """ - if next((i for i in self.values if not isinstance(i, NameConstant)), None): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") + values = [i.get_folded_value_throwing() for i in self.values] - values = [i.value for i in self.values] - if None in values: + if any(not isinstance(i, NameConstant) for i in values): raise UnfoldableNode("Node contains invalid field(s) for evaluation") + values = [i.value for i in values] value = self.op._op(values) return NameConstant.from_node(self, value=value) @@ -1179,6 +1242,7 @@ class Compare(ExprNode): """ __slots__ = ("left", "op", "right") + _is_prefoldable = True def __init__(self, *args, **kwargs): if len(kwargs["ops"]) > 1 or len(kwargs["comparators"]) > 1: @@ -1188,7 +1252,7 @@ def __init__(self, *args, **kwargs): kwargs["right"] = kwargs.pop("comparators")[0] super().__init__(*args, **kwargs) - def evaluate(self) -> ExprNode: + def fold(self) -> ExprNode: """ Attempt to evaluate the comparison. @@ -1197,7 +1261,7 @@ def evaluate(self) -> ExprNode: NameConstant Node representing the result of the evaluation. """ - left, right = self.left, self.right + left, right = [i.get_folded_value_throwing() for i in (self.left, self.right)] if not isinstance(left, Constant): raise UnfoldableNode("Node contains invalid field(s) for evaluation") @@ -1289,8 +1353,9 @@ class Attribute(ExprNode): class Subscript(ExprNode): __slots__ = ("slice", "value") + _is_prefoldable = True - def evaluate(self) -> ExprNode: + def fold(self) -> ExprNode: """ Attempt to evaluate the subscript. @@ -1302,12 +1367,18 @@ def evaluate(self) -> ExprNode: ExprNode Node representing the result of the evaluation. """ - if not isinstance(self.value, List): + slice_ = self.slice.value.get_folded_value_throwing() + value = self.value.get_folded_value_throwing() + + if not isinstance(value, List): raise UnfoldableNode("Subscript object is not a literal list") - elements = self.value.elements + elements = value.elements if len(set([type(i) for i in elements])) > 1: raise UnfoldableNode("List contains multiple node types") - idx = self.slice.get("value.value") + + if not isinstance(slice_, Int): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + idx = slice_.value if not isinstance(idx, int) or idx < 0 or idx >= len(elements): raise UnfoldableNode("Invalid index value") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 47856b6021..7531a6d02c 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -26,7 +26,9 @@ class VyperNode: def description(self): ... @classmethod def get_fields(cls: Any) -> set: ... - def evaluate(self) -> VyperNode: ... + def get_folded_value_throwing(self) -> VyperNode: ... + def get_folded_value_maybe(self) -> Optional[VyperNode]: ... + def fold(self) -> VyperNode: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... def to_dict(self) -> dict: ... diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index a5949dfd85..f955296ee0 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -1,12 +1,17 @@ import functools from typing import Any, Optional -from vyper.ast import nodes as vy_ast +from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args from vyper.codegen.expr import Expr from vyper.codegen.ir_node import IRnode -from vyper.exceptions import CompilerPanic, TypeMismatch -from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type +from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode +from vyper.semantics.analysis.base import Modifiability +from vyper.semantics.analysis.utils import ( + check_modifiability, + get_exact_type_from_node, + validate_expected_type, +) from vyper.semantics.types import TYPE_T, KwargSettings, VyperType from vyper.semantics.types.utils import type_from_annotation @@ -29,7 +34,7 @@ def process_arg(arg, expected_arg_type, context): def process_kwarg(kwarg_node, kwarg_settings, expected_kwarg_type, context): if kwarg_settings.require_literal: - return kwarg_node.value + return kwarg_node.get_folded_value_throwing().value return process_arg(kwarg_node, expected_kwarg_type, context) @@ -106,8 +111,10 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: for kwarg in node.keywords: kwarg_settings = self._kwargs[kwarg.arg] - if kwarg_settings.require_literal and not isinstance(kwarg.value, vy_ast.Constant): - raise TypeMismatch("Value for kwarg must be a literal", kwarg.value) + if kwarg_settings.require_literal and not check_modifiability( + kwarg.value, Modifiability.IMMUTABLE + ): + raise TypeMismatch("Value must be literal or environment variable", kwarg.value) self._validate_single(kwarg.value, kwarg_settings.typ) # typecheck varargs. we don't have type info from the signature, @@ -125,7 +132,10 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: return self._return_type - def infer_arg_types(self, node: vy_ast.Call) -> list[VyperType]: + def fold(self, node: vy_ast.Call) -> vy_ast.VyperNode: + raise UnfoldableNode(f"{type(self)} cannot be folded") + + def infer_arg_types(self, node: vy_ast.Call, expected_return_typ=None) -> list[VyperType]: self._validate_arg_types(node) ret = [expected for (_, expected) in self._inputs] diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index d50a31767d..a924a56010 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -1,7 +1,6 @@ import hashlib import math import operator -from decimal import Decimal from vyper import ast as vy_ast from vyper.abi_types import ABI_Tuple @@ -44,7 +43,6 @@ CompilerPanic, InvalidLiteral, InvalidType, - OverflowException, StateAccessViolation, StructureException, TypeMismatch, @@ -88,7 +86,6 @@ EIP_170_LIMIT, SHA3_PER_WORD, MemoryPositions, - SizeLimits, bytes_to_int, ceil32, fourbytes_to_int, @@ -109,7 +106,7 @@ class FoldedFunctionT(BuiltinFunctionT): # Base class for nodes which should always be folded # Since foldable builtin functions are not folded before semantics validation, - # this flag is used for `check_kwargable` in semantics validation. + # this flag is used for `check_modifiability` in semantics validation. _kwargable = True @@ -126,7 +123,7 @@ def fetch_call_return(self, node): type_ = self.infer_arg_types(node)[0].typedef return type_ - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): validate_call_args(node, 1) input_typedef = TYPE_T(type_from_annotation(node.args[0])) return [input_typedef] @@ -138,12 +135,13 @@ class Floor(BuiltinFunctionT): # TODO: maybe use int136? _return_type = INT256_T - def evaluate(self, node): + def fold(self, node): validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Decimal): + value = node.args[0].get_folded_value_throwing() + if not isinstance(value, vy_ast.Decimal): raise UnfoldableNode - value = math.floor(node.args[0].value) + value = math.floor(value.value) return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -168,12 +166,13 @@ class Ceil(BuiltinFunctionT): # TODO: maybe use int136? _return_type = INT256_T - def evaluate(self, node): + def fold(self, node): validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Decimal): + value = node.args[0].get_folded_value_throwing() + if not isinstance(value, vy_ast.Decimal): raise UnfoldableNode - value = math.ceil(node.args[0].value) + value = math.ceil(value.value) return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -202,7 +201,7 @@ def fetch_call_return(self, node): return target_typedef.typedef # TODO: push this down into convert.py for more consistency - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): validate_call_args(node, 2) target_type = type_from_annotation(node.args[1]) @@ -337,7 +336,7 @@ def fetch_call_return(self, node): return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `b` b_type = get_possible_types_from_node(node.args[0]).pop() @@ -461,9 +460,9 @@ class Len(BuiltinFunctionT): _inputs = [("b", (StringT.any(), BytesT.any(), DArrayT.any()))] _return_type = UINT256_T - def evaluate(self, node): + def fold(self, node): validate_call_args(node, 1) - arg = node.args[0] + arg = node.args[0].get_folded_value_throwing() if isinstance(arg, (vy_ast.Str, vy_ast.Bytes)): length = len(arg.value) elif isinstance(arg, vy_ast.Hex): @@ -474,7 +473,7 @@ def evaluate(self, node): return vy_ast.Int.from_node(node, value=length) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type typ = get_possible_types_from_node(node.args[0]).pop() @@ -504,7 +503,7 @@ def fetch_call_return(self, node): return_type.set_length(length) return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): if len(node.args) < 2: raise ArgumentException("Invalid argument count: expected at least 2", node) @@ -598,22 +597,23 @@ class Keccak256(BuiltinFunctionT): _inputs = [("value", (BytesT.any(), BYTES32_T, StringT.any()))] _return_type = BYTES32_T - def evaluate(self, node): + def fold(self, node): validate_call_args(node, 1) - if isinstance(node.args[0], vy_ast.Bytes): - value = node.args[0].value - elif isinstance(node.args[0], vy_ast.Str): - value = node.args[0].value.encode() - elif isinstance(node.args[0], vy_ast.Hex): - length = len(node.args[0].value) // 2 - 1 - value = int(node.args[0].value, 16).to_bytes(length, "big") + value = node.args[0].get_folded_value_throwing() + if isinstance(value, vy_ast.Bytes): + value = value.value + elif isinstance(value, vy_ast.Str): + value = value.value.encode() + elif isinstance(value, vy_ast.Hex): + length = len(value.value) // 2 - 1 + value = int(value.value, 16).to_bytes(length, "big") else: raise UnfoldableNode hash_ = f"0x{keccak256(value).hex()}" return vy_ast.Hex.from_node(node, value=hash_) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `value` value_type = get_possible_types_from_node(node.args[0]).pop() @@ -645,22 +645,23 @@ class Sha256(BuiltinFunctionT): _inputs = [("value", (BYTES32_T, BytesT.any(), StringT.any()))] _return_type = BYTES32_T - def evaluate(self, node): + def fold(self, node): validate_call_args(node, 1) - if isinstance(node.args[0], vy_ast.Bytes): - value = node.args[0].value - elif isinstance(node.args[0], vy_ast.Str): - value = node.args[0].value.encode() - elif isinstance(node.args[0], vy_ast.Hex): - length = len(node.args[0].value) // 2 - 1 - value = int(node.args[0].value, 16).to_bytes(length, "big") + value = node.args[0].get_folded_value_throwing() + if isinstance(value, vy_ast.Bytes): + value = value.value + elif isinstance(value, vy_ast.Str): + value = value.value.encode() + elif isinstance(value, vy_ast.Hex): + length = len(value.value) // 2 - 1 + value = int(value.value, 16).to_bytes(length, "big") else: raise UnfoldableNode hash_ = f"0x{hashlib.sha256(value).hexdigest()}" return vy_ast.Hex.from_node(node, value=hash_) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `value` value_type = get_possible_types_from_node(node.args[0]).pop() @@ -714,18 +715,20 @@ def build_IR(self, expr, args, kwargs, context): class MethodID(FoldedFunctionT): _id = "method_id" + _inputs = [("value", StringT.any())] + _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BytesT(4))} - def evaluate(self, node): + def fold(self, node): validate_call_args(node, 1, ["output_type"]) - args = node.args - if not isinstance(args[0], vy_ast.Str): - raise InvalidType("method id must be given as a literal string", args[0]) - if " " in args[0].value: - raise InvalidLiteral("Invalid function signature - no spaces allowed.") + value = node.args[0].get_folded_value_throwing() + if not isinstance(value, vy_ast.Str): + raise InvalidType("method id must be given as a literal string", node.args[0]) + if " " in value.value: + raise InvalidLiteral("Invalid function signature - no spaces allowed.", node.args[0]) - return_type = self.infer_kwarg_types(node) - value = method_id_int(args[0].value) + return_type = self.infer_kwarg_types(node)["output_type"].typedef + value = method_id_int(value.value) if return_type.compare_type(BYTES4_T): return vy_ast.Hex.from_node(node, value=hex(value)) @@ -735,21 +738,25 @@ def evaluate(self, node): def fetch_call_return(self, node): validate_call_args(node, 1, ["output_type"]) - type_ = self.infer_kwarg_types(node) + type_ = self.infer_kwarg_types(node)["output_type"].typedef return type_ + def infer_arg_types(self, node, expected_return_typ=None): + # call `fold` for its typechecking side effects + self.fold(node) + return [self._inputs[0][1]] + def infer_kwarg_types(self, node): + # If `output_type` is not given, default to `Bytes[4]` + output_typedef = TYPE_T(BytesT(4)) if node.keywords: return_type = type_from_annotation(node.keywords[0].value) if return_type.compare_type(BYTES4_T): - return BYTES4_T - elif isinstance(return_type, BytesT) and return_type.length == 4: - return BytesT(4) - else: + output_typedef = TYPE_T(BYTES4_T) + elif not (isinstance(return_type, BytesT) and return_type.length == 4): raise ArgumentException("output_type must be Bytes[4] or bytes4", node.keywords[0]) - # If `output_type` is not given, default to `Bytes[4]` - return BytesT(4) + return {"output_type": output_typedef} class ECRecover(BuiltinFunctionT): @@ -762,7 +769,7 @@ class ECRecover(BuiltinFunctionT): ] _return_type = AddressT() - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) v_t, r_t, s_t = [get_possible_types_from_node(arg).pop() for arg in node.args[1:]] return [BYTES32_T, v_t, r_t, s_t] @@ -859,7 +866,7 @@ def fetch_call_return(self, node): return_type = self.infer_kwarg_types(node)["output_type"].typedef return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) input_type = get_possible_types_from_node(node.args[0]).pop() return [input_type, UINT256_T] @@ -974,42 +981,43 @@ class AsWeiValue(BuiltinFunctionT): } def get_denomination(self, node): - if not isinstance(node.args[1], vy_ast.Str): + value = node.args[1].get_folded_value_throwing() + if not isinstance(value, vy_ast.Str): raise ArgumentException( "Wei denomination must be given as a literal string", node.args[1] ) try: - denom = next(v for k, v in self.wei_denoms.items() if node.args[1].value in k) + denom = next(v for k, v in self.wei_denoms.items() if value.value in k) except StopIteration: - raise ArgumentException( - f"Unknown denomination: {node.args[1].value}", node.args[1] - ) from None + raise ArgumentException(f"Unknown denomination: {value.value}", node.args[1]) from None return denom - def evaluate(self, node): + def fold(self, node): validate_call_args(node, 2) denom = self.get_denomination(node) - if not isinstance(node.args[0], (vy_ast.Decimal, vy_ast.Int)): + value = node.args[0].get_folded_value_throwing() + if not isinstance(value, (vy_ast.Decimal, vy_ast.Int)): raise UnfoldableNode - value = node.args[0].value + value = value.value if value < 0: raise InvalidLiteral("Negative wei value not allowed", node.args[0]) - if isinstance(value, int) and value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", node.args[0]) - if isinstance(value, Decimal) and value > SizeLimits.MAX_AST_DECIMAL: - raise InvalidLiteral("Value out of range for decimal", node.args[0]) - return vy_ast.Int.from_node(node, value=int(value * denom)) def fetch_call_return(self, node): self.infer_arg_types(node) return self._return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): + # call `fold` for its typechecking side effects` + try: + self.fold(node) + except UnfoldableNode: + pass + self._validate_arg_types(node) # return a concrete type instead of abstract type value_type = get_possible_types_from_node(node.args[0]).pop() @@ -1074,7 +1082,11 @@ def fetch_call_return(self, node): kwargz = {i.arg: i.value for i in node.keywords} outsize = kwargz.get("max_outsize") + if outsize is not None: + outsize = outsize.get_folded_value_throwing() revert_on_failure = kwargz.get("revert_on_failure") + if revert_on_failure is not None: + revert_on_failure = revert_on_failure.get_folded_value_throwing() revert_on_failure = revert_on_failure.value if revert_on_failure is not None else True if outsize is None or outsize.value == 0: @@ -1093,7 +1105,7 @@ def fetch_call_return(self, node): return return_type return TupleT([BoolT(), return_type]) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `data` data_type = get_possible_types_from_node(node.args[1]).pop() @@ -1268,7 +1280,7 @@ class RawRevert(BuiltinFunctionT): def fetch_call_return(self, node): return None - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) data_type = get_possible_types_from_node(node.args[0]).pop() return [data_type] @@ -1288,7 +1300,7 @@ class RawLog(BuiltinFunctionT): def fetch_call_return(self, node): self.infer_arg_types(node) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) if not isinstance(node.args[0], vy_ast.List) or len(node.args[0].elements) > 4: @@ -1338,19 +1350,18 @@ class BitwiseAnd(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def fold(self, node): if not self.__class__._warned: vyper_warn("`bitwise_and()` is deprecated! Please use the & operator instead.") self.__class__._warned = True validate_call_args(node, 2) - for arg in node.args: - if not isinstance(arg, vy_ast.Int): + values = [i.get_folded_value_throwing() for i in node.args] + for val in values: + if not isinstance(val, vy_ast.Int): raise UnfoldableNode - if arg.value < 0 or arg.value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", arg) - value = node.args[0].value & node.args[1].value + value = values[0].value & values[1].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1364,19 +1375,18 @@ class BitwiseOr(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def fold(self, node): if not self.__class__._warned: vyper_warn("`bitwise_or()` is deprecated! Please use the | operator instead.") self.__class__._warned = True validate_call_args(node, 2) - for arg in node.args: - if not isinstance(arg, vy_ast.Int): + values = [i.get_folded_value_throwing() for i in node.args] + for val in values: + if not isinstance(val, vy_ast.Int): raise UnfoldableNode - if arg.value < 0 or arg.value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", arg) - value = node.args[0].value | node.args[1].value + value = values[0].value | values[1].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1390,19 +1400,18 @@ class BitwiseXor(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def fold(self, node): if not self.__class__._warned: vyper_warn("`bitwise_xor()` is deprecated! Please use the ^ operator instead.") self.__class__._warned = True validate_call_args(node, 2) - for arg in node.args: - if not isinstance(arg, vy_ast.Int): + values = [i.get_folded_value_throwing() for i in node.args] + for val in values: + if not isinstance(val, vy_ast.Int): raise UnfoldableNode - if arg.value < 0 or arg.value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", arg) - value = node.args[0].value ^ node.args[1].value + value = values[0].value ^ values[1].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1416,18 +1425,17 @@ class BitwiseNot(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def fold(self, node): if not self.__class__._warned: vyper_warn("`bitwise_not()` is deprecated! Please use the ~ operator instead.") self.__class__._warned = True validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Int): + value = node.args[0].get_folded_value_throwing() + if not isinstance(value, vy_ast.Int): raise UnfoldableNode - value = node.args[0].value - if value < 0 or value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", node.args[0]) + value = value.value value = (2**256 - 1) - value return vy_ast.Int.from_node(node, value=value) @@ -1443,17 +1451,16 @@ class Shift(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def fold(self, node): if not self.__class__._warned: vyper_warn("`shift()` is deprecated! Please use the << or >> operator instead.") self.__class__._warned = True validate_call_args(node, 2) - if [i for i in node.args if not isinstance(i, vy_ast.Int)]: + args = [i.get_folded_value_throwing() for i in node.args] + if any(not isinstance(i, vy_ast.Int) for i in args): raise UnfoldableNode - value, shift = [i.value for i in node.args] - if value < 0 or value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", node.args[0]) + value, shift = [i.value for i in args] if shift < -256 or shift > 256: # this validation is performed to prevent the compiler from hanging # rather than for correctness because the post-folded constant would @@ -1470,7 +1477,7 @@ def fetch_call_return(self, node): # return type is the type of the first argument return self.infer_arg_types(node)[0] - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type instead of SignedIntegerAbstractType arg_ty = get_possible_types_from_node(node.args[0])[0] @@ -1495,17 +1502,16 @@ class _AddMulMod(BuiltinFunctionT): _inputs = [("a", UINT256_T), ("b", UINT256_T), ("c", UINT256_T)] _return_type = UINT256_T - def evaluate(self, node): + def fold(self, node): validate_call_args(node, 3) - if isinstance(node.args[2], vy_ast.Int) and node.args[2].value == 0: + args = [i.get_folded_value_throwing() for i in node.args] + if isinstance(args[2], vy_ast.Int) and args[2].value == 0: raise ZeroDivisionException("Modulo by 0", node.args[2]) - for arg in node.args: + for arg in args: if not isinstance(arg, vy_ast.Int): raise UnfoldableNode - if arg.value < 0 or arg.value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", arg) - value = self._eval_fn(node.args[0].value, node.args[1].value) % node.args[2].value + value = self._eval_fn(args[0].value, args[1].value) % args[2].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1537,15 +1543,13 @@ class PowMod256(BuiltinFunctionT): _inputs = [("a", UINT256_T), ("b", UINT256_T)] _return_type = UINT256_T - def evaluate(self, node): + def fold(self, node): validate_call_args(node, 2) - if next((i for i in node.args if not isinstance(i, vy_ast.Int)), None): - raise UnfoldableNode - - left, right = node.args - if left.value < 0 or right.value < 0: + values = [i.get_folded_value_throwing() for i in node.args] + if any(not isinstance(i, vy_ast.Int) for i in values): raise UnfoldableNode + left, right = values value = pow(left.value, right.value, 2**256) return vy_ast.Int.from_node(node, value=value) @@ -1560,18 +1564,13 @@ class Abs(BuiltinFunctionT): _inputs = [("value", INT256_T)] _return_type = INT256_T - def evaluate(self, node): + def fold(self, node): validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Int): + value = node.args[0].get_folded_value_throwing() + if not isinstance(value, vy_ast.Int): raise UnfoldableNode - value = node.args[0].value - if not SizeLimits.MIN_INT256 <= value <= SizeLimits.MAX_INT256: - raise OverflowException("Literal is outside of allowable range for int256") - value = abs(value) - if not SizeLimits.MIN_INT256 <= value <= SizeLimits.MAX_INT256: - raise OverflowException("Absolute literal value is outside allowable range for int256") - + value = abs(value.value) return vy_ast.Int.from_node(node, value=value) def build_IR(self, expr, context): @@ -1946,7 +1945,7 @@ def fetch_call_return(self, node): return_type = self.infer_arg_types(node).pop() return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) types_list = get_common_types(*node.args, filter_fn=lambda x: isinstance(x, IntegerT)) @@ -2004,34 +2003,26 @@ class UnsafeDiv(_UnsafeMath): class _MinMax(BuiltinFunctionT): _inputs = [("a", (DecimalT(), IntegerT.any())), ("b", (DecimalT(), IntegerT.any()))] - def evaluate(self, node): + def fold(self, node): validate_call_args(node, 2) - if not isinstance(node.args[0], type(node.args[1])): + + left = node.args[0].get_folded_value_throwing() + right = node.args[1].get_folded_value_throwing() + if not isinstance(left, type(right)): raise UnfoldableNode - if not isinstance(node.args[0], (vy_ast.Decimal, vy_ast.Int)): + if not isinstance(left, (vy_ast.Decimal, vy_ast.Int)): raise UnfoldableNode - left, right = (i.value for i in node.args) - if isinstance(left, Decimal) and ( - min(left, right) < SizeLimits.MIN_AST_DECIMAL - or max(left, right) > SizeLimits.MAX_AST_DECIMAL - ): - raise InvalidType("Decimal value is outside of allowable range", node) - types_list = get_common_types( - *node.args, filter_fn=lambda x: isinstance(x, (IntegerT, DecimalT)) + *(left, right), filter_fn=lambda x: isinstance(x, (IntegerT, DecimalT)) ) if not types_list: raise TypeMismatch("Cannot perform action between dislike numeric types", node) - value = self._eval_fn(left, right) - return type(node.args[0]).from_node(node, value=value) + value = self._eval_fn(left.value, right.value) + return type(left).from_node(node, value=value) def fetch_call_return(self, node): - return_type = self.infer_arg_types(node).pop() - return return_type - - def infer_arg_types(self, node): self._validate_arg_types(node) types_list = get_common_types( @@ -2040,8 +2031,13 @@ def infer_arg_types(self, node): if not types_list: raise TypeMismatch("Cannot perform action between dislike numeric types", node) - type_ = types_list.pop() - return [type_, type_] + return types_list + + def infer_arg_types(self, node, expected_return_typ=None): + types_list = self.fetch_call_return(node) + # type mismatch should have been caught in `fetch_call_return` + assert expected_return_typ in types_list + return [expected_return_typ, expected_return_typ] @process_inputs def build_IR(self, expr, args, kwargs, context): @@ -2085,18 +2081,19 @@ def fetch_call_return(self, node): len_needed = math.ceil(bits * math.log(2) / math.log(10)) return StringT(len_needed) - def evaluate(self, node): + def fold(self, node): validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Int): + value = node.args[0].get_folded_value_throwing() + if not isinstance(value, vy_ast.Int): raise UnfoldableNode - value = node.args[0].value + value = value.value if value < 0: raise InvalidType("Only unsigned ints allowed", node) value = str(value) return vy_ast.Str.from_node(node, value=value) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) input_type = get_possible_types_from_node(node.args[0]).pop() return [input_type] @@ -2493,7 +2490,7 @@ def fetch_call_return(self, node): _, output_type = self.infer_arg_types(node) return output_type.typedef - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) validate_call_args(node, 2, ["unwrap_tuple"]) @@ -2572,7 +2569,7 @@ def build_IR(self, expr, args, kwargs, context): class _MinMaxValue(TypenameFoldedFunctionT): - def evaluate(self, node): + def fold(self, node): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) @@ -2590,6 +2587,12 @@ def evaluate(self, node): ret._metadata["type"] = input_type return ret + def infer_arg_types(self, node, expected_return_typ=None): + # call `fold` for its typechecking side effects + self.fold(node) + input_typedef = TYPE_T(type_from_annotation(node.args[0])) + return [input_typedef] + class MinValue(_MinMaxValue): _id = "min_value" @@ -2608,7 +2611,7 @@ def _eval(self, type_): class Epsilon(TypenameFoldedFunctionT): _id = "epsilon" - def evaluate(self, node): + def fold(self, node): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 3063a289ab..3fbff31174 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -33,7 +33,8 @@ devdoc - Natspec developer documentation combined_json - All of the above format options combined as single JSON output layout - Storage layout of a Vyper contract -ast - AST in JSON format +ast - Annotated AST in JSON format +unannotated_ast - Unannotated AST in JSON format interface - Vyper interface of a contract external_interface - External interface of a contract, used for outside contract calls opcodes - List of opcodes as a string @@ -255,7 +256,13 @@ def compile_files( output_formats = combined_json_outputs show_version = True - translate_map = {"abi_python": "abi", "json": "abi", "ast": "ast_dict", "ir_json": "ir_dict"} + translate_map = { + "abi_python": "abi", + "json": "abi", + "ast": "ast_dict", + "unannotated_ast": "unannotated_ast_dict", + "ir_json": "ir_dict", + } final_formats = [translate_map.get(i, i) for i in output_formats] if storage_layout_paths: diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 4c7c3afaed..27266577a0 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -36,6 +36,7 @@ VyperException, tag_exceptions, ) +from vyper.semantics.analysis.base import Modifiability from vyper.semantics.types import ( AddressT, BoolT, @@ -69,6 +70,9 @@ class Expr: # TODO: Once other refactors are made reevaluate all inline imports def __init__(self, node, context): + if isinstance(node, vy_ast.VyperNode): + node = node._metadata.get("folded_value", node) + self.expr = node self.context = context @@ -184,7 +188,15 @@ def parse_Name(self): # TODO: use self.expr._expr_info elif self.expr.id in self.context.globals: varinfo = self.context.globals[self.expr.id] - assert varinfo.is_immutable, "not an immutable!" + if varinfo.modifiability == Modifiability.ALWAYS_CONSTANT: + # non-struct constants should have been dispatched via the `Expr` ctor + # using the folded value metadata + assert isinstance(varinfo.typ, StructT) + value_node = varinfo.decl_node.value + value_node = value_node._metadata.get("folded_value", value_node) + return Expr.parse_value_expr(value_node, self.context) + + assert varinfo.modifiability == Modifiability.IMMUTABLE, "not an immutable!" ofst = varinfo.position.offset diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index c87814ba15..bc7930af82 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -13,7 +13,8 @@ OUTPUT_FORMATS = { # requires vyper_module - "ast_dict": output.build_ast_dict, + "unannotated_ast_dict": output.build_ast_dict, + "ast_dict": output.build_annotated_ast_dict, "layout": output.build_layout_output, # requires global_ctx "devdoc": output.build_devdoc, diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index dc2a43720e..b9dd9b957d 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -23,6 +23,14 @@ def build_ast_dict(compiler_data: CompilerData) -> dict: return ast_dict +def build_annotated_ast_dict(compiler_data: CompilerData) -> dict: + annotated_ast_dict = { + "contract_name": str(compiler_data.contract_path), + "ast": ast_to_dict(compiler_data.vyper_module_annotated), + } + return annotated_ast_dict + + def build_devdoc(compiler_data: CompilerData) -> dict: userdoc, devdoc = parse_natspec(compiler_data.vyper_module_folded) return devdoc diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index b9b2df6ae8..7407c4f281 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -53,8 +53,10 @@ class CompilerData: ---------- vyper_module : vy_ast.Module Top-level Vyper AST node + vyper_module_annotated : vy_ast.Module + Annotated but unfolded Vyper AST vyper_module_folded : vy_ast.Module - Folded Vyper AST + Annotated and folded Vyper AST global_ctx : ModuleT Sorted, contextualized representation of the Vyper AST ir_nodes : IRnode @@ -152,16 +154,13 @@ def vyper_module(self): return self._generate_ast @cached_property - def vyper_module_unfolded(self) -> vy_ast.Module: - # This phase is intended to generate an AST for tooling use, and is not - # used in the compilation process. - - return generate_unfolded_ast(self.vyper_module, self.input_bundle) + def vyper_module_annotated(self) -> vy_ast.Module: + return generate_annotated_ast(self.vyper_module, self.input_bundle) @cached_property def _folded_module(self): return generate_folded_ast( - self.vyper_module, self.input_bundle, self.storage_layout_override + self.vyper_module_annotated, self.input_bundle, self.storage_layout_override ) @property @@ -248,9 +247,8 @@ def blueprint_bytecode(self) -> bytes: # destructive -- mutates module in place! -def generate_unfolded_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: +def generate_annotated_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: vy_ast.validation.validate_literal_nodes(vyper_module) - vy_ast.folding.replace_builtin_functions(vyper_module) with input_bundle.search_path(Path(vyper_module.resolved_path).parent): # note: validate_semantics does type inference on the AST @@ -275,20 +273,14 @@ def generate_folded_ast( Returns ------- vy_ast.Module - Folded Vyper AST + Annotated Vyper AST StorageLayout Layout of variables in storage """ - - vy_ast.validation.validate_literal_nodes(vyper_module) + symbol_tables = set_data_positions(vyper_module, storage_layout_overrides) vyper_module_folded = copy.deepcopy(vyper_module) - vy_ast.folding.fold(vyper_module_folded) - - with input_bundle.search_path(Path(vyper_module.resolved_path).parent): - validate_semantics(vyper_module_folded, input_bundle) - - symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides) + # vy_ast.folding.fold(vyper_module_folded) return vyper_module_folded, symbol_tables diff --git a/vyper/exceptions.py b/vyper/exceptions.py index f216069eab..f625a7d3fb 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -108,7 +108,7 @@ def __str__(self): if isinstance(node, vy_ast.VyperNode): module_node = node.get_ancestor(vy_ast.Module) - if module_node.get("path") not in (None, ""): + if module_node and module_node.get("path") not in (None, ""): node_msg = f'{node_msg}contract "{module_node.path}:{node.lineno}", ' fn_node = node.get_ancestor(vy_ast.FunctionDef) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 4d1b1cdbab..5f17357a0f 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -97,6 +97,14 @@ def from_abi(cls, abi_dict: Dict) -> "StateMutability": # specifying a state mutability modifier at all. Do the same here. +class Modifiability(enum.IntEnum): + MODIFIABLE = enum.auto() + IMMUTABLE = enum.auto() + NOT_MODIFIABLE = enum.auto() + CONSTANT_IN_CURRENT_TX = enum.auto() + ALWAYS_CONSTANT = enum.auto() + + class DataPosition: _location: DataLocation @@ -182,7 +190,7 @@ class ImportInfo(AnalysisResult): class VarInfo: """ VarInfo are objects that represent the type of a variable, - plus associated metadata like location and constancy attributes + plus associated metadata like location and modifiability attributes Object Attributes ----------------- @@ -192,9 +200,8 @@ class VarInfo: typ: VyperType location: DataLocation = DataLocation.UNSET - is_constant: bool = False + modifiability: Modifiability = Modifiability.MODIFIABLE is_public: bool = False - is_immutable: bool = False is_transient: bool = False is_local_var: bool = False decl_node: Optional[vy_ast.VyperNode] = None @@ -225,11 +232,10 @@ class ExprInfo: typ: VyperType var_info: Optional[VarInfo] = None location: DataLocation = DataLocation.UNSET - is_constant: bool = False - is_immutable: bool = False + modifiability: Modifiability = Modifiability.MODIFIABLE def __post_init__(self): - should_match = ("typ", "location", "is_constant", "is_immutable") + should_match = ("typ", "location", "modifiability") if self.var_info is not None: for attr in should_match: if getattr(self.var_info, attr) != getattr(self, attr): @@ -241,8 +247,7 @@ def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": var_info.typ, var_info=var_info, location=var_info.location, - is_constant=var_info.is_constant, - is_immutable=var_info.is_immutable, + modifiability=var_info.modifiability, ) @classmethod @@ -253,7 +258,7 @@ def copy_with_type(self, typ: VyperType) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ - to_copy = ("location", "is_constant", "is_immutable") + to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} return self.__class__(typ=typ, **fields) @@ -277,9 +282,9 @@ def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutabil if self.location == DataLocation.CALLDATA: raise ImmutableViolation("Cannot write to calldata", node) - if self.is_constant: + if self.modifiability == Modifiability.ALWAYS_CONSTANT: raise ImmutableViolation("Constant value cannot be written to", node) - if self.is_immutable: + if self.modifiability == Modifiability.IMMUTABLE: if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": raise ImmutableViolation("Immutable value cannot be written to", node) # TODO: we probably want to remove this restriction. diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index a3ebf85fa2..417e9e7018 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -18,7 +18,7 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -186,16 +186,18 @@ def __init__( self.fn_node = fn_node self.namespace = namespace self.func = fn_node._metadata["func_type"] - self.expr_visitor = _ExprVisitor(self.func) + self.expr_visitor = ExprVisitor(self.func) def analyze(self): # allow internal function params to be mutable - location, is_immutable = ( - (DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True) + location, modifiability = ( + (DataLocation.MEMORY, Modifiability.MODIFIABLE) + if self.func.is_internal + else (DataLocation.CALLDATA, Modifiability.NOT_MODIFIABLE) ) for arg in self.func.arguments: self.namespace[arg.name] = VarInfo( - arg.typ, location=location, is_immutable=is_immutable + arg.typ, location=location, modifiability=modifiability ) for node in self.fn_node.body: @@ -358,7 +360,8 @@ def visit_For(self, node): else: # iteration over a variable or literal list - if isinstance(node.iter, vy_ast.List) and len(node.iter.elements) == 0: + iter_val = node.iter.get_folded_value_maybe() + if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) type_list = [ @@ -421,7 +424,9 @@ def visit_For(self, node): # type check the for loop body using each possible type for iterator value with self.namespace.enter_scope(): - self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) + self.namespace[iter_name] = VarInfo( + possible_target_type, modifiability=Modifiability.ALWAYS_CONSTANT + ) try: with NodeMetadata.enter_typechecker_speculation(): @@ -523,10 +528,10 @@ def visit_Return(self, node): self.expr_visitor.visit(node.value, self.func.return_type) -class _ExprVisitor(VyperNodeVisitorBase): +class ExprVisitor(VyperNodeVisitorBase): scope_name = "function" - def __init__(self, fn_node: ContractFunctionT): + def __init__(self, fn_node: Optional[ContractFunctionT] = None): self.func = fn_node def visit(self, node, typ): @@ -543,6 +548,12 @@ def visit(self, node, typ): # annotate node._metadata["type"] = typ + # validate and annotate folded value + folded_value = node._metadata.get("folded_value") + if folded_value: + validate_expected_type(folded_value, typ) + folded_value._metadata["type"] = typ + def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) @@ -551,10 +562,10 @@ def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: # if self.func.mutability < expr_info.mutability: # raise ... - if self.func.mutability != StateMutability.PAYABLE: + if self.func and self.func.mutability != StateMutability.PAYABLE: _validate_msg_value_access(node) - if self.func.mutability == StateMutability.PURE: + if self.func and self.func.mutability == StateMutability.PURE: _validate_pure_access(node, typ) value_type = get_exact_type_from_node(node.value) @@ -589,7 +600,7 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: if isinstance(call_type, ContractFunctionT): # function calls - if call_type.is_internal: + if self.func and call_type.is_internal: self.func.called_functions.add(call_type) for arg, typ in zip(node.args, call_type.argument_types): self.visit(arg, typ) @@ -615,7 +626,7 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: self.visit(arg, arg_type) else: # builtin functions - arg_types = call_type.infer_arg_types(node) + arg_types = call_type.infer_arg_types(node, typ) # `infer_arg_types` already calls `validate_expected_type` for arg, arg_type in zip(node.args, arg_types): self.visit(arg, arg_type) @@ -680,7 +691,7 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: self.visit(element, typ.value_type) def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: - if self.func.mutability == StateMutability.PURE: + if self.func and self.func.mutability == StateMutability.PURE: _validate_self_reference(node) if not isinstance(typ, TYPE_T): @@ -691,7 +702,7 @@ def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: # don't recurse; can't annotate AST children of type definition return - if isinstance(node.value, vy_ast.List): + if isinstance(node.value, (vy_ast.List, vy_ast.Subscript)): possible_base_types = get_possible_types_from_node(node.value) for possible_type in possible_base_types: @@ -746,7 +757,9 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: """ validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} - start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args + start, end = ( + (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else [i for i in node.args] + ) all_args = (start, end, *kwargs.values()) for arg1 in all_args: @@ -756,21 +769,23 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: if not type_list: raise TypeMismatch("Iterator values are of different types", node) + folded_start, folded_end = [i.get_folded_value_maybe() for i in (start, end)] if "bound" in kwargs: bound = kwargs["bound"] - if not isinstance(bound, vy_ast.Num): + folded_bound = bound.get_folded_value_maybe() + if not isinstance(folded_bound, vy_ast.Num): raise StateAccessViolation("Bound must be a literal", bound) - if bound.value <= 0: + if folded_bound.value <= 0: raise StructureException("Bound must be at least 1", bound) if isinstance(start, vy_ast.Num) and isinstance(end, vy_ast.Num): error = "Please remove the `bound=` kwarg when using range with constants" raise StructureException(error, bound) else: - for arg in (start, end): - if not isinstance(arg, vy_ast.Num): + for original_arg, folded_arg in zip([start, end], [folded_start, folded_end]): + if not isinstance(folded_arg, vy_ast.Num): error = "Value must be a literal integer, unless a bound is specified" - raise StateAccessViolation(error, arg) - if end.value <= start.value: + raise StateAccessViolation(error, original_arg) + if folded_end.value <= folded_start.value: raise StructureException("End must be greater than start", end) return type_list diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index fb536b7ab7..9d828eaa2d 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -20,12 +20,13 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import ImportInfo, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.import_graph import ImportGraph -from vyper.semantics.analysis.local import validate_functions +from vyper.semantics.analysis.local import ExprVisitor, validate_functions +from vyper.semantics.analysis.pre_typecheck import pre_typecheck from vyper.semantics.analysis.utils import ( - check_constant, + check_modifiability, get_exact_type_from_node, validate_expected_type, ) @@ -55,6 +56,7 @@ def validate_semantics_r( namespace = get_namespace() with namespace.enter_scope(), import_graph.enter_path(module_ast): + pre_typecheck(module_ast) analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace, import_graph, is_interface) ret = analyzer.analyze() @@ -260,6 +262,14 @@ def visit_VariableDecl(self, node): else DataLocation.STORAGE ) + modifiability = ( + Modifiability.IMMUTABLE + if node.is_immutable + else Modifiability.ALWAYS_CONSTANT + if node.is_constant + else Modifiability.MODIFIABLE + ) + type_ = type_from_annotation(node.annotation, data_loc) if node.is_transient and not version_check(begin="cancun"): @@ -269,9 +279,8 @@ def visit_VariableDecl(self, node): type_, decl_node=node, location=data_loc, - is_constant=node.is_constant, + modifiability=modifiability, is_public=node.is_public, - is_immutable=node.is_immutable, is_transient=node.is_transient, ) node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace @@ -304,7 +313,10 @@ def _validate_self_namespace(): if node.is_constant: if not node.value: raise VariableDeclarationException("Constant must be declared with a value", node) - if not check_constant(node.value): + + ExprVisitor().visit(node.value, type_) + + if not check_modifiability(node.value, Modifiability.ALWAYS_CONSTANT): raise StateAccessViolation("Value must be a literal", node.value) validate_expected_type(node.value, type_) @@ -483,7 +495,7 @@ def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: resolved_path=str(file.resolved_path), ) vy_ast.validation.validate_literal_nodes(ret) - vy_ast.folding.fold(ret) + # vy_ast.folding.fold(ret) return ret diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py new file mode 100644 index 0000000000..b89c1c6759 --- /dev/null +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -0,0 +1,78 @@ +from vyper import ast as vy_ast +from vyper.exceptions import UnfoldableNode + + +def get_constants(node: vy_ast.Module) -> dict: + constants: dict[str, vy_ast.VyperNode] = {} + module_nodes = node.body.copy() + const_var_decls = [ + n for n in module_nodes if isinstance(n, vy_ast.VariableDecl) and n.is_constant + ] + + while const_var_decls: + derived_nodes = 0 + + for c in const_var_decls: + name = c.get("target.id") + # Handle syntax errors downstream + if c.value is None: + continue + + for n in c.value.get_descendants(include_self=True, reverse=True): + prefold(n, constants) + + try: + val = c.value.get_folded_value_throwing() + + # note that if a constant is redefined, its value will be overwritten, + # but it is okay because the syntax error is handled downstream + constants[name] = val + derived_nodes += 1 + const_var_decls.remove(c) + except UnfoldableNode: + pass + + if not derived_nodes: + break + + return constants + + +def pre_typecheck(node: vy_ast.Module) -> None: + constants = get_constants(node) + + for n in node.get_descendants(reverse=True): + if isinstance(n, vy_ast.VariableDecl): + continue + + prefold(n, constants) + + +def prefold(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]): + if isinstance(node, vy_ast.Name): + var_name = node.id + if var_name in constants: + node._metadata["folded_value"] = constants[var_name] + return + + if isinstance(node, vy_ast.Call): + if isinstance(node.func, vy_ast.Name): + from vyper.builtins.functions import DISPATCH_TABLE + + func_name = node.func.id + + call_type = DISPATCH_TABLE.get(func_name) + if call_type and hasattr(call_type, "fold"): + try: + node._metadata["folded_value"] = call_type.fold(node) + return + except UnfoldableNode: + pass + + if getattr(node, "_is_prefoldable", None): + # call `get_folded_value_throwing` for its side effects and allow all + # exceptions other than `UnfoldableNode` to raise + try: + node.get_folded_value_throwing() + except UnfoldableNode: + pass diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 20ebb0f093..dcf81b4d6e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -17,7 +17,7 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -98,12 +98,9 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: # kludge! for validate_modification in local analysis of Assign types = [self.get_expr_info(n) for n in node.elements] location = sorted((i.location for i in types), key=lambda k: k.value)[-1] - is_constant = any((getattr(i, "is_constant", False) for i in types)) - is_immutable = any((getattr(i, "is_immutable", False) for i in types)) + modifiability = sorted((i.modifiability for i in types), key=lambda k: k.value)[-1] - return ExprInfo( - t, location=location, is_constant=is_constant, is_immutable=is_immutable - ) + return ExprInfo(t, location=location, modifiability=modifiability) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): @@ -203,7 +200,7 @@ def _raise_invalid_reference(name, node): if isinstance(s, (VyperType, TYPE_T)): # ex. foo.bar(). bar() is a ContractFunctionT return [s] - if is_self_reference and (s.is_constant or s.is_immutable): + if is_self_reference and s.modifiability >= Modifiability.IMMUTABLE: _raise_invalid_reference(name, node) # general case. s is a VarInfo, e.g. self.foo return [s.typ] @@ -282,6 +279,8 @@ def types_from_Call(self, node): var = self.get_exact_type_from_node(node.func, include_type_exprs=True) return_value = var.fetch_call_return(node) if return_value: + if isinstance(return_value, list): + return return_value return [return_value] raise InvalidType(f"{var} did not return a value", node) @@ -378,7 +377,7 @@ def types_from_Name(self, node): def types_from_Subscript(self, node): # index access, e.g. `foo[1]` - if isinstance(node.value, vy_ast.List): + if isinstance(node.value, (vy_ast.List, vy_ast.Subscript)): types_list = self.get_possible_types_from_node(node.value) ret = [] for t in types_list: @@ -625,54 +624,33 @@ def validate_unique_method_ids(functions: List) -> None: seen.add(method_id) -def check_kwargable(node: vy_ast.VyperNode) -> bool: +def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> bool: """ - Check if the given node can be used as a default arg + Check if the given node is not more modifiable than the given modifiability. """ - if _check_literal(node): + if node.get_folded_value_maybe(): return True - if isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(check_kwargable(item) for item in node.elements) - if isinstance(node, vy_ast.Call): - args = node.args - if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return all(check_kwargable(v) for v in args[0].values) - call_type = get_exact_type_from_node(node.func) - if getattr(call_type, "_kwargable", False): - return True + if isinstance(node, (vy_ast.BinOp, vy_ast.Compare)): + return all(check_modifiability(i, modifiability) for i in (node.left, node.right)) - value_type = get_expr_info(node) - # is_constant here actually means not_assignable, and is to be renamed - return value_type.is_constant + if isinstance(node, vy_ast.BoolOp): + return all(check_modifiability(i, modifiability) for i in node.values) + if isinstance(node, vy_ast.UnaryOp): + return check_modifiability(node.operand, modifiability) -def _check_literal(node: vy_ast.VyperNode) -> bool: - """ - Check if the given node is a literal value. - """ - if isinstance(node, vy_ast.Constant): - return True - elif isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(_check_literal(item) for item in node.elements) - return False - - -def check_constant(node: vy_ast.VyperNode) -> bool: - """ - Check if the given node is a literal or constant value. - """ - if _check_literal(node): - return True if isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(check_constant(item) for item in node.elements) + return all(check_modifiability(item, modifiability) for item in node.elements) + if isinstance(node, vy_ast.Call): args = node.args if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return all(check_constant(v) for v in args[0].values) + return all(check_modifiability(v, modifiability) for v in args[0].values) call_type = get_exact_type_from_node(node.func) if getattr(call_type, "_kwargable", False): return True - return False + value_type = get_expr_info(node) + return value_type.modifiability >= modifiability diff --git a/vyper/semantics/environment.py b/vyper/semantics/environment.py index ad68f1103e..295f40029e 100644 --- a/vyper/semantics/environment.py +++ b/vyper/semantics/environment.py @@ -1,6 +1,6 @@ from typing import Dict -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.types import AddressT, BytesT, VyperType from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T @@ -52,7 +52,7 @@ def get_constant_vars() -> Dict: """ result = {} for k, v in CONSTANT_ENVIRONMENT_VARS.items(): - result[k] = VarInfo(v, is_constant=True) + result[k] = VarInfo(v, modifiability=Modifiability.CONSTANT_IN_CURRENT_TX) return result diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 6ecfe78be3..429ba807e1 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -340,7 +340,7 @@ def fetch_call_return(self, node): return self.typedef._ctor_call_return(node) raise StructureException("Value is not callable", node) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): if hasattr(self.typedef, "_ctor_arg_types"): return self.typedef._ctor_arg_types(node) raise StructureException("Value is not callable", node) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 34206546fd..c71308e8fc 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -16,9 +16,14 @@ StateAccessViolation, StructureException, ) -from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot +from vyper.semantics.analysis.base import ( + FunctionVisibility, + Modifiability, + StateMutability, + StorageSlot, +) from vyper.semantics.analysis.utils import ( - check_kwargable, + check_modifiability, get_exact_type_from_node, validate_expected_type, ) @@ -39,12 +44,14 @@ class _FunctionArg: @dataclass class PositionalArg(_FunctionArg): + # unfolded ast ast_source: Optional[vy_ast.VyperNode] = None @dataclass class KeywordArg(_FunctionArg): default_value: vy_ast.VyperNode + # unfolded ast ast_source: Optional[vy_ast.VyperNode] = None @@ -128,7 +135,7 @@ def __repr__(self): def __str__(self): ret_sig = "" if not self.return_type else f" -> {self.return_type}" args_sig = ",".join([str(t) for t in self.argument_types]) - return f"def {self.name} {args_sig}{ret_sig}:" + return f"def {self.name}({args_sig}){ret_sig}:" # override parent implementation. function type equality does not # make too much sense. @@ -696,7 +703,7 @@ def _parse_args( positional_args.append(PositionalArg(argname, type_, ast_source=arg)) else: value = funcdef.args.defaults[i - n_positional_args] - if not check_kwargable(value): + if not check_modifiability(value, Modifiability.IMMUTABLE): raise StateAccessViolation("Value must be literal or environment variable", value) validate_expected_type(value, type_) keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index b0d7800011..ccf9604bd4 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -5,7 +5,7 @@ from vyper.abi_types import ABI_Address, ABIType from vyper.ast.validation import validate_call_args from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -324,7 +324,7 @@ def variables(self): @cached_property def immutables(self): - return [t for t in self.variables.values() if t.is_immutable] + return [t for t in self.variables.values() if t.modifiability == Modifiability.IMMUTABLE] @cached_property def immutable_section_bytes(self): diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 0c8e9fddd8..5e1154416a 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -287,7 +287,7 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": ): raise StructureException(err_msg, node.slice) - length_node = node.slice.value.elements[1] + length_node = node.slice.value.elements[1].get_folded_value_maybe() if not isinstance(length_node, vy_ast.Int): raise StructureException(err_msg, length_node) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 8d68a9fa01..137d8d56f5 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -179,7 +179,8 @@ def get_index_value(node: vy_ast.Index) -> int: # TODO: revisit this! from vyper.semantics.analysis.utils import get_possible_types_from_node - if not isinstance(node.get("value"), vy_ast.Int): + value = node.value.get_folded_value_maybe() + if not isinstance(value, vy_ast.Int): if hasattr(node, "value"): # even though the subscript is an invalid type, first check if it's a valid _something_ # this gives a more accurate error in case of e.g. a typo in a constant variable name @@ -191,7 +192,7 @@ def get_index_value(node: vy_ast.Index) -> int: raise InvalidType("Subscript must be a literal integer", node) - if node.value.value <= 0: + if value.value <= 0: raise ArrayIndexException("Subscript must be greater than 0", node) - return node.value.value + return value.value