From 56c4c9dbc09d6310bf132cfde3fdbe1431189a9b Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 1 Jan 2024 06:19:47 +0800 Subject: [PATCH] refactor: reimplement AST folding (#3669) this commit reimplements AST folding. fundamentally, it changes AST folding from a mutating pass to be an annotation pass. this brings several benefits: - typechecking is easier, because folding does not have to reason at all about types. type checking happens on both the folded and unfolded nodes, so intermediate values are type-checked. - correctness in general is easier, because the AST is not mutated. there is also some incidental performance benefit, although that is not necessarily the focus here. - the vyper frontend is now nearly mutation-free. only the getter AST expansion pass remains. note that we cannot push folding past the typechecking stage entirely, because some type checking operations depend on having folded values (e.g., `range()` expressions, or type expressions with integer parameters). the approach taken in this commit is to change constant folding to be annotating, rather than mutating. this way, type-checking can operate on the original AST (and check for the folded values where needed). intermediate values are also type-checked, so expressions like `x: uint128 = 2**128 + 1 - 1` are caught by the typechecker. summary of changes: - `evaluate()` is renamed to `_try_fold()`. a new utility function called `get_folded_value()` caches folded values and is threaded through the codebase. - `pre_typecheck` is added, which extracts `constant` variables and runs `get_folded_value()` on all nodes. - a new `Modifiability` enum replaces the old (confusing) `is_constant` and `is_immutable` attributes on ExprInfo. - `ExprInfo.is_transient` is removed, and handled by adding `TRANSIENT` to the `DataLocation` enum. - the old `check_literal` and `check_kwargable` utility functions are replaced with a more general (and more correct) `check_modifiability` function - several utility functions (ex. `_validate_numeric_bounds()`) related to ad-hoc type-checking (which would happen during constant folding) are removed. - `CompilerData.vyper_module_folded` is renamed to `annotated_vyper_module` - the AST output options are now `ast` and `annotated_ast`. - `None` literals are now banned in AST validation instead of during analysis. --------- Co-authored-by: Charles Cooper --- .../builtins/codegen/test_keccak256.py | 31 ++ .../builtins/codegen/test_sha256.py | 30 ++ .../functional/builtins/codegen/test_unary.py | 7 +- tests/functional/builtins/folding/test_abs.py | 10 +- .../builtins/folding/test_addmod_mulmod.py | 2 +- .../builtins/folding/test_bitwise.py | 11 +- .../builtins/folding/test_epsilon.py | 2 +- .../builtins/folding/test_floor_ceil.py | 2 +- .../folding/test_fold_as_wei_value.py | 4 +- .../builtins/folding/test_keccak_sha.py | 6 +- tests/functional/builtins/folding/test_len.py | 6 +- .../builtins/folding/test_min_max.py | 6 +- .../builtins/folding/test_powmod.py | 2 +- .../test_default_parameters.py | 42 +++ .../test_external_contract_calls.py | 2 +- .../codegen/test_call_graph_stability.py | 2 +- tests/functional/codegen/test_interfaces.py | 2 +- .../codegen/types/numbers/test_constants.py | 4 +- .../codegen/types/numbers/test_decimals.py | 43 ++- .../codegen/types/numbers/test_signed_ints.py | 40 ++- .../types/numbers/test_unsigned_ints.py | 30 +- .../codegen/types/test_dynamic_array.py | 20 +- tests/functional/codegen/types/test_lists.py | 10 +- .../exceptions/test_argument_exception.py | 6 +- tests/functional/syntax/test_abi_decode.py | 4 +- tests/functional/syntax/test_abs.py | 40 +++ tests/functional/syntax/test_addmulmod.py | 22 ++ tests/functional/syntax/test_as_wei_value.py | 72 ++++- tests/functional/syntax/test_ceil.py | 19 ++ tests/functional/syntax/test_dynamic_array.py | 17 +- tests/functional/syntax/test_epsilon.py | 20 ++ tests/functional/syntax/test_floor.py | 19 ++ tests/functional/syntax/test_for_range.py | 56 +++- tests/functional/syntax/test_len.py | 22 +- tests/functional/syntax/test_method_id.py | 50 +++ tests/functional/syntax/test_minmax.py | 43 ++- tests/functional/syntax/test_minmax_value.py | 28 +- tests/functional/syntax/test_powmod.py | 39 +++ tests/functional/syntax/test_raw_call.py | 20 +- tests/functional/syntax/test_ternary.py | 4 +- tests/functional/syntax/test_uint2str.py | 19 ++ tests/functional/syntax/test_unary.py | 21 ++ ..._decimal.py => test_fold_binop_decimal.py} | 8 +- ...te_binop_int.py => test_fold_binop_int.py} | 10 +- ...evaluate_boolop.py => test_fold_boolop.py} | 6 +- ...aluate_compare.py => test_fold_compare.py} | 10 +- ...te_subscript.py => test_fold_subscript.py} | 2 +- ...aluate_unaryop.py => test_fold_unaryop.py} | 6 +- tests/unit/ast/nodes/test_replace_in_tree.py | 70 ----- tests/unit/ast/test_ast_dict.py | 6 +- tests/unit/ast/test_folding.py | 272 ---------------- tests/unit/ast/test_natspec.py | 2 +- vyper/ast/README.md | 21 -- vyper/ast/__init__.py | 2 +- vyper/ast/__init__.pyi | 2 +- vyper/ast/folding.py | 263 ---------------- vyper/ast/natspec.py | 10 +- vyper/ast/nodes.py | 265 ++++++++++------ vyper/ast/nodes.pyi | 13 +- vyper/ast/parse.py | 1 + vyper/ast/validation.py | 11 +- vyper/builtins/_signatures.py | 25 +- vyper/builtins/functions.py | 297 +++++++++--------- vyper/cli/vyper_compile.py | 11 +- vyper/codegen/expr.py | 17 +- vyper/compiler/README.md | 2 - vyper/compiler/__init__.py | 2 + vyper/compiler/output.py | 20 +- vyper/compiler/phases.py | 61 ++-- vyper/semantics/README.md | 29 +- vyper/semantics/analysis/base.py | 87 +++-- vyper/semantics/analysis/local.py | 81 +++-- vyper/semantics/analysis/module.py | 46 +-- vyper/semantics/analysis/pre_typecheck.py | 94 ++++++ vyper/semantics/analysis/utils.py | 73 ++--- vyper/semantics/data_locations.py | 3 +- vyper/semantics/environment.py | 4 +- vyper/semantics/types/base.py | 2 +- vyper/semantics/types/function.py | 13 +- vyper/semantics/types/subscriptable.py | 2 + vyper/semantics/types/utils.py | 10 +- 81 files changed, 1464 insertions(+), 1230 deletions(-) create mode 100644 tests/functional/syntax/test_abs.py create mode 100644 tests/functional/syntax/test_ceil.py create mode 100644 tests/functional/syntax/test_epsilon.py create mode 100644 tests/functional/syntax/test_floor.py create mode 100644 tests/functional/syntax/test_method_id.py create mode 100644 tests/functional/syntax/test_powmod.py create mode 100644 tests/functional/syntax/test_uint2str.py create mode 100644 tests/functional/syntax/test_unary.py rename tests/unit/ast/nodes/{test_evaluate_binop_decimal.py => test_fold_binop_decimal.py} (93%) rename tests/unit/ast/nodes/{test_evaluate_binop_int.py => test_fold_binop_int.py} (93%) rename tests/unit/ast/nodes/{test_evaluate_boolop.py => test_fold_boolop.py} (92%) rename tests/unit/ast/nodes/{test_evaluate_compare.py => test_fold_compare.py} (94%) rename tests/unit/ast/nodes/{test_evaluate_subscript.py => test_fold_subscript.py} (93%) rename tests/unit/ast/nodes/{test_evaluate_unaryop.py => test_fold_unaryop.py} (86%) delete mode 100644 tests/unit/ast/nodes/test_replace_in_tree.py delete mode 100644 tests/unit/ast/test_folding.py delete mode 100644 vyper/ast/folding.py create mode 100644 vyper/semantics/analysis/pre_typecheck.py diff --git a/tests/functional/builtins/codegen/test_keccak256.py b/tests/functional/builtins/codegen/test_keccak256.py index 90fa8b9e09..3b0b9f2018 100644 --- a/tests/functional/builtins/codegen/test_keccak256.py +++ b/tests/functional/builtins/codegen/test_keccak256.py @@ -1,3 +1,6 @@ +from vyper.utils import hex_to_int + + def test_hash_code(get_contract_with_gas_estimation, keccak): hash_code = """ @external @@ -80,3 +83,31 @@ def try32(inp: bytes32) -> bool: assert c.tryy(b"\x35" * 33) is True print("Passed KECCAK256 hash test") + + +def test_hash_constant_bytes32(get_contract_with_gas_estimation, keccak): + hex_val = "0x1234567890123456789012345678901234567890123456789012345678901234" + code = f""" +FOO: constant(bytes32) = {hex_val} +BAR: constant(bytes32) = keccak256(FOO) +@external +def foo() -> bytes32: + x: bytes32 = BAR + return x + """ + c = get_contract_with_gas_estimation(code) + assert "0x" + c.foo().hex() == keccak(hex_to_int(hex_val).to_bytes(32, "big")).hex() + + +def test_hash_constant_string(get_contract_with_gas_estimation, keccak): + str_val = "0x1234567890123456789012345678901234567890123456789012345678901234" + code = f""" +FOO: constant(String[66]) = "{str_val}" +BAR: constant(bytes32) = keccak256(FOO) +@external +def foo() -> bytes32: + x: bytes32 = BAR + return x + """ + c = get_contract_with_gas_estimation(code) + assert "0x" + c.foo().hex() == keccak(str_val.encode()).hex() diff --git a/tests/functional/builtins/codegen/test_sha256.py b/tests/functional/builtins/codegen/test_sha256.py index 468e684645..8e1b89bd31 100644 --- a/tests/functional/builtins/codegen/test_sha256.py +++ b/tests/functional/builtins/codegen/test_sha256.py @@ -2,6 +2,8 @@ import pytest +from vyper.utils import hex_to_int + pytestmark = pytest.mark.usefixtures("memory_mocker") @@ -77,3 +79,31 @@ def bar() -> bytes32: c.set(test_val, transact={}) assert c.a() == test_val assert c.bar() == hashlib.sha256(test_val).digest() + + +def test_sha256_constant_bytes32(get_contract_with_gas_estimation): + hex_val = "0x1234567890123456789012345678901234567890123456789012345678901234" + code = f""" +FOO: constant(bytes32) = {hex_val} +BAR: constant(bytes32) = sha256(FOO) +@external +def foo() -> bytes32: + x: bytes32 = BAR + return x + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() == hashlib.sha256(hex_to_int(hex_val).to_bytes(32, "big")).digest() + + +def test_sha256_constant_string(get_contract_with_gas_estimation): + str_val = "0x1234567890123456789012345678901234567890123456789012345678901234" + code = f""" +FOO: constant(String[66]) = "{str_val}" +BAR: constant(bytes32) = sha256(FOO) +@external +def foo() -> bytes32: + x: bytes32 = BAR + return x + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() == hashlib.sha256(str_val.encode()).digest() diff --git a/tests/functional/builtins/codegen/test_unary.py b/tests/functional/builtins/codegen/test_unary.py index 33f79be233..2be5c0d33f 100644 --- a/tests/functional/builtins/codegen/test_unary.py +++ b/tests/functional/builtins/codegen/test_unary.py @@ -69,16 +69,11 @@ def bar() -> decimal: def test_negation_int128(get_contract): code = """ -a: constant(int128) = -2**127 - -@external -def foo() -> int128: - return -2**127 +a: constant(int128) = min_value(int128) @external def bar() -> int128: return -(a+1) """ c = get_contract(code) - assert c.foo() == -(2**127) assert c.bar() == 2**127 - 1 diff --git a/tests/functional/builtins/folding/test_abs.py b/tests/functional/builtins/folding/test_abs.py index a91a4f1ad3..68131678fa 100644 --- a/tests/functional/builtins/folding/test_abs.py +++ b/tests/functional/builtins/folding/test_abs.py @@ -4,7 +4,7 @@ from vyper import ast as vy_ast from vyper.builtins import functions as vy_fn -from vyper.exceptions import OverflowException +from vyper.exceptions import InvalidType @pytest.mark.fuzzing @@ -21,7 +21,7 @@ def foo(a: int256) -> int256: vyper_ast = vy_ast.parse_to_ast(f"abs({a})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["abs"].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE["abs"]._try_fold(old_node) assert contract.foo(a) == new_node.value == abs(a) @@ -35,7 +35,7 @@ def test_abs_upper_bound_folding(get_contract, a): def foo(a: int256) -> int256: return abs({a}) """ - with pytest.raises(OverflowException): + with pytest.raises(InvalidType): get_contract(source) @@ -55,7 +55,7 @@ def test_abs_lower_bound_folded(get_contract, tx_failed): source = """ @external def foo() -> int256: - return abs(-2**255) + return abs(min_value(int256)) """ - with pytest.raises(OverflowException): + with pytest.raises(InvalidType): get_contract(source) diff --git a/tests/functional/builtins/folding/test_addmod_mulmod.py b/tests/functional/builtins/folding/test_addmod_mulmod.py index 33dcc62984..1d789f1655 100644 --- a/tests/functional/builtins/folding/test_addmod_mulmod.py +++ b/tests/functional/builtins/folding/test_addmod_mulmod.py @@ -24,6 +24,6 @@ def foo(a: uint256, b: uint256, c: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({a}, {b}, {c})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert contract.foo(a, b, c) == new_node.value diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index 63e733644f..53a6d333a0 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -13,6 +13,9 @@ st_sint256 = st.integers(min_value=-(2**255), max_value=2**255 - 1) +# TODO: move this file to tests/unit/ast/nodes/test_fold_bitwise.py + + @pytest.mark.fuzzing @settings(max_examples=50) @pytest.mark.parametrize("op", ["&", "|", "^"]) @@ -28,7 +31,7 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(a, b) == new_node.value @@ -49,7 +52,7 @@ def foo(a: uint256, b: uint256) -> uint256: old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() # force bounds check, no-op because validate_numeric_bounds # already does this, but leave in for hygiene (in case # more types are added). @@ -79,7 +82,7 @@ def foo(a: int256, b: uint256) -> int256: old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() validate_expected_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. @@ -104,6 +107,6 @@ def foo(a: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"~{value}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_epsilon.py b/tests/functional/builtins/folding/test_epsilon.py index 794648cfce..4f5e9434ec 100644 --- a/tests/functional/builtins/folding/test_epsilon.py +++ b/tests/functional/builtins/folding/test_epsilon.py @@ -15,6 +15,6 @@ def foo() -> {typ_name}: vyper_ast = vy_ast.parse_to_ast(f"epsilon({typ_name})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["epsilon"].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE["epsilon"]._try_fold(old_node) assert contract.foo() == new_node.value diff --git a/tests/functional/builtins/folding/test_floor_ceil.py b/tests/functional/builtins/folding/test_floor_ceil.py index 87db23889a..04921e504e 100644 --- a/tests/functional/builtins/folding/test_floor_ceil.py +++ b/tests/functional/builtins/folding/test_floor_ceil.py @@ -30,6 +30,6 @@ def foo(a: decimal) -> int256: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_fold_as_wei_value.py b/tests/functional/builtins/folding/test_fold_as_wei_value.py index 210ab51f0d..4287615bab 100644 --- a/tests/functional/builtins/folding/test_fold_as_wei_value.py +++ b/tests/functional/builtins/folding/test_fold_as_wei_value.py @@ -32,7 +32,7 @@ def foo(a: decimal) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value:.10f}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue().evaluate(old_node) + new_node = vy_fn.AsWeiValue()._try_fold(old_node) assert contract.foo(value) == new_node.value @@ -51,6 +51,6 @@ def foo(a: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue().evaluate(old_node) + new_node = vy_fn.AsWeiValue()._try_fold(old_node) assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_keccak_sha.py b/tests/functional/builtins/folding/test_keccak_sha.py index a2fe460dd1..8da420538f 100644 --- a/tests/functional/builtins/folding/test_keccak_sha.py +++ b/tests/functional/builtins/folding/test_keccak_sha.py @@ -22,7 +22,7 @@ def foo(a: String[100]) -> bytes32: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}('''{value}''')") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -41,7 +41,7 @@ def foo(a: Bytes[100]) -> bytes32: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -62,6 +62,6 @@ def foo(a: Bytes[100]) -> bytes32: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert f"0x{contract.foo(value).hex()}" == new_node.value diff --git a/tests/functional/builtins/folding/test_len.py b/tests/functional/builtins/folding/test_len.py index edf33120dd..967f906555 100644 --- a/tests/functional/builtins/folding/test_len.py +++ b/tests/functional/builtins/folding/test_len.py @@ -17,7 +17,7 @@ def foo(a: String[1024]) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"len('{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len().evaluate(old_node) + new_node = vy_fn.Len()._try_fold(old_node) assert contract.foo(value) == new_node.value @@ -35,7 +35,7 @@ def foo(a: Bytes[1024]) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"len(b'{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len().evaluate(old_node) + new_node = vy_fn.Len()._try_fold(old_node) assert contract.foo(value.encode()) == new_node.value @@ -53,6 +53,6 @@ def foo(a: Bytes[1024]) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"len({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len().evaluate(old_node) + new_node = vy_fn.Len()._try_fold(old_node) assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_min_max.py b/tests/functional/builtins/folding/test_min_max.py index 309f7519c0..36a611fa1b 100644 --- a/tests/functional/builtins/folding/test_min_max.py +++ b/tests/functional/builtins/folding/test_min_max.py @@ -31,7 +31,7 @@ def foo(a: decimal, b: decimal) -> decimal: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert contract.foo(left, right) == new_node.value @@ -50,7 +50,7 @@ def foo(a: int128, b: int128) -> int128: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert contract.foo(left, right) == new_node.value @@ -69,6 +69,6 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert contract.foo(left, right) == new_node.value diff --git a/tests/functional/builtins/folding/test_powmod.py b/tests/functional/builtins/folding/test_powmod.py index 8667ec93fd..a3c2567f58 100644 --- a/tests/functional/builtins/folding/test_powmod.py +++ b/tests/functional/builtins/folding/test_powmod.py @@ -21,6 +21,6 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"pow_mod256({a}, {b})") old_node = vyper_ast.body[0].value - new_node = vy_fn.PowMod256().evaluate(old_node) + new_node = vy_fn.PowMod256()._try_fold(old_node) assert contract.foo(a, b) == new_node.value diff --git a/tests/functional/codegen/calling_convention/test_default_parameters.py b/tests/functional/codegen/calling_convention/test_default_parameters.py index 03f5d9fca2..462748a9c7 100644 --- a/tests/functional/codegen/calling_convention/test_default_parameters.py +++ b/tests/functional/codegen/calling_convention/test_default_parameters.py @@ -305,6 +305,48 @@ def foo(a: address = empty(address)): def foo(a: int112 = min_value(int112)): self.A = a """, + """ +struct X: + x: int128 + y: address +BAR: constant(X) = X({x: 1, y: 0x0000000000000000000000000000000000012345}) +@external +def out_literals(a: int128 = BAR.x + 1) -> X: + return BAR + """, + """ +struct X: + x: int128 + y: address +struct Y: + x: X + y: uint256 +BAR: constant(X) = X({x: 1, y: 0x0000000000000000000000000000000000012345}) +FOO: constant(Y) = Y({x: BAR, y: 256}) +@external +def out_literals(a: int128 = FOO.x.x + 1) -> Y: + return FOO + """, + """ +struct Bar: + a: bool + +BAR: constant(Bar) = Bar({a: True}) + +@external +def foo(x: bool = True and not BAR.a): + pass + """, + """ +struct Bar: + a: uint256 + +BAR: constant(Bar) = Bar({ a: 123 }) + +@external +def foo(x: bool = BAR.a + 1 > 456): + pass + """, ] diff --git a/tests/functional/codegen/calling_convention/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py index 0360396f03..0af4f9f937 100644 --- a/tests/functional/codegen/calling_convention/test_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_external_contract_calls.py @@ -388,7 +388,7 @@ def test_int128_too_long(get_contract, tx_failed): contract_1 = """ @external def foo() -> int256: - return (2**255)-1 + return max_value(int256) """ c = get_contract(contract_1) diff --git a/tests/functional/codegen/test_call_graph_stability.py b/tests/functional/codegen/test_call_graph_stability.py index 2d8ad59791..ca0e6c8c9e 100644 --- a/tests/functional/codegen/test_call_graph_stability.py +++ b/tests/functional/codegen/test_call_graph_stability.py @@ -54,7 +54,7 @@ def foo(): t = CompilerData(code) # check the .called_functions data structure on foo() directly - foo = t.vyper_module_folded.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0] + foo = t.annotated_vyper_module.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0] foo_t = foo._metadata["func_type"] assert [f.name for f in foo_t.called_functions] == func_names diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 65d2df9038..7d363fadc0 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -435,7 +435,7 @@ def ok() -> {typ}: @external def should_fail() -> int256: - return -2**255 # OOB for all int/uint types with less than 256 bits + return min_value(int256) """ code = f""" diff --git a/tests/functional/codegen/types/numbers/test_constants.py b/tests/functional/codegen/types/numbers/test_constants.py index 8244bc5487..af871983ab 100644 --- a/tests/functional/codegen/types/numbers/test_constants.py +++ b/tests/functional/codegen/types/numbers/test_constants.py @@ -4,7 +4,7 @@ import pytest from vyper.compiler import compile_code -from vyper.exceptions import InvalidType +from vyper.exceptions import TypeMismatch from vyper.utils import MemoryPositions @@ -158,7 +158,7 @@ def test_custom_constants_fail(get_contract, assert_compile_failed, storage_type def foo() -> {return_type}: return MY_CONSTANT """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) def test_constant_address(get_contract): diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index 25dc1f1a1e..fcf71f12f0 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -3,7 +3,13 @@ import pytest -from vyper.exceptions import DecimalOverrideException, InvalidOperation, TypeMismatch +from vyper import compile_code +from vyper.exceptions import ( + DecimalOverrideException, + InvalidOperation, + OverflowException, + TypeMismatch, +) from vyper.utils import DECIMAL_EPSILON, SizeLimits @@ -24,23 +30,25 @@ def test_decimal_override(): @pytest.mark.parametrize("op", ["**", "&", "|", "^"]) -def test_invalid_ops(get_contract, assert_compile_failed, op): +def test_invalid_ops(op): code = f""" @external def foo(x: decimal, y: decimal) -> decimal: return x {op} y """ - assert_compile_failed(lambda: get_contract(code), InvalidOperation) + with pytest.raises(InvalidOperation): + compile_code(code) @pytest.mark.parametrize("op", ["not"]) -def test_invalid_unary_ops(get_contract, assert_compile_failed, op): +def test_invalid_unary_ops(op): code = f""" @external def foo(x: decimal) -> decimal: return {op} x """ - assert_compile_failed(lambda: get_contract(code), InvalidOperation) + with pytest.raises(InvalidOperation): + compile_code(code) def quantize(x: Decimal) -> Decimal: @@ -263,11 +271,32 @@ def bar(num: decimal) -> decimal: assert c.bar(Decimal("1e37")) == Decimal("-9e37") # Math lines up -def test_exponents(assert_compile_failed, get_contract): +def test_exponents(): code = """ @external def foo() -> decimal: return 2.2 ** 2.0 """ - assert_compile_failed(lambda: get_contract(code), TypeMismatch) + with pytest.raises(TypeMismatch): + compile_code(code) + + +def test_decimal_nested_intermediate_overflow(): + code = """ +@external +def foo(): + a: decimal = 18707220957835557353007165858768422651595.9365500927 + 1e-10 - 1e-10 + """ + with pytest.raises(OverflowException): + compile_code(code) + + +def test_replace_decimal_nested_intermediate_underflow(dummy_input_bundle): + code = """ +@external +def foo(): + a: decimal = -18707220957835557353007165858768422651595.9365500928 - 1e-10 + 1e-10 + """ + with pytest.raises(OverflowException): + compile_code(code) diff --git a/tests/functional/codegen/types/numbers/test_signed_ints.py b/tests/functional/codegen/types/numbers/test_signed_ints.py index 52de5b649f..a10eaee408 100644 --- a/tests/functional/codegen/types/numbers/test_signed_ints.py +++ b/tests/functional/codegen/types/numbers/test_signed_ints.py @@ -4,6 +4,7 @@ import pytest +from vyper import compile_code from vyper.exceptions import InvalidOperation, InvalidType, OverflowException, ZeroDivisionException from vyper.semantics.types import IntegerT from vyper.utils import evm_div, evm_mod @@ -206,17 +207,16 @@ def _num_min() -> {typ}: @pytest.mark.parametrize("typ", types) -def test_overflow_out_of_range(get_contract, assert_compile_failed, typ): +def test_overflow_out_of_range(get_contract, typ): code = f""" @external def num_sub() -> {typ}: return 1-2**{typ.bits} """ - if typ.bits == 256: - assert_compile_failed(lambda: get_contract(code), OverflowException) - else: - assert_compile_failed(lambda: get_contract(code), InvalidType) + exc = OverflowException if typ.bits == 256 else InvalidType + with pytest.raises(exc): + compile_code(code) ARITHMETIC_OPS = { @@ -231,7 +231,7 @@ def num_sub() -> {typ}: @pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) @pytest.mark.parametrize("typ", types) @pytest.mark.fuzzing -def test_arithmetic_thorough(get_contract, tx_failed, assert_compile_failed, op, typ): +def test_arithmetic_thorough(get_contract, tx_failed, op, typ): # both variables code_1 = f""" @external @@ -318,10 +318,12 @@ def foo() -> {typ}: elif div_by_zero: with tx_failed(): c.foo(x, y) - assert_compile_failed(lambda code=code_2: get_contract(code), ZeroDivisionException) + with pytest.raises(ZeroDivisionException): + compile_code(code_2) with tx_failed(): get_contract(code_3).foo(y) - assert_compile_failed(lambda code=code_4: get_contract(code), ZeroDivisionException) + with pytest.raises(ZeroDivisionException): + compile_code(code_4) else: with tx_failed(): c.foo(x, y) @@ -329,9 +331,8 @@ def foo() -> {typ}: get_contract(code_2).foo(x) with tx_failed(): get_contract(code_3).foo(y) - assert_compile_failed( - lambda code=code_4: get_contract(code), (InvalidType, OverflowException) - ) + with pytest.raises((InvalidType, OverflowException)): + compile_code(code_4) COMPARISON_OPS = { @@ -359,7 +360,7 @@ def foo(x: {typ}, y: {typ}) -> bool: fn = COMPARISON_OPS[op] c = get_contract(code_1) - # note: constant folding is tested in tests/ast/folding + # note: constant folding is tested in tests/unit/ast/nodes special_cases = [ lo, lo + 1, @@ -413,10 +414,21 @@ def foo(a: {typ}) -> {typ}: @pytest.mark.parametrize("typ", types) @pytest.mark.parametrize("op", ["not"]) -def test_invalid_unary_ops(get_contract, assert_compile_failed, typ, op): +def test_invalid_unary_ops(typ, op): code = f""" @external def foo(a: {typ}) -> {typ}: return {op} a """ - assert_compile_failed(lambda: get_contract(code), InvalidOperation) + with pytest.raises(InvalidOperation): + compile_code(code) + + +def test_binop_nested_intermediate_underflow(): + code = """ +@external +def foo(): + a: int256 = -2**255 * 2 - 10 + 100 + """ + with pytest.raises(InvalidType): + compile_code(code) diff --git a/tests/functional/codegen/types/numbers/test_unsigned_ints.py b/tests/functional/codegen/types/numbers/test_unsigned_ints.py index 8982065b5d..f10e861689 100644 --- a/tests/functional/codegen/types/numbers/test_unsigned_ints.py +++ b/tests/functional/codegen/types/numbers/test_unsigned_ints.py @@ -4,9 +4,10 @@ import pytest +from vyper import compile_code from vyper.exceptions import InvalidOperation, InvalidType, OverflowException, ZeroDivisionException from vyper.semantics.types import IntegerT -from vyper.utils import evm_div, evm_mod +from vyper.utils import SizeLimits, evm_div, evm_mod types = sorted(IntegerT.unsigneds()) @@ -85,7 +86,7 @@ def foo(x: {typ}) -> {typ}: @pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) @pytest.mark.parametrize("typ", types) @pytest.mark.fuzzing -def test_arithmetic_thorough(get_contract, tx_failed, assert_compile_failed, op, typ): +def test_arithmetic_thorough(get_contract, tx_failed, op, typ): # both variables code_1 = f""" @external @@ -192,7 +193,7 @@ def foo(x: {typ}, y: {typ}) -> bool: lo, hi = typ.ast_bounds - # note: constant folding is tested in tests/ast/folding + # note: folding is tested in tests/unit/ast/nodes special_cases = [0, 1, 2, 3, hi // 2 - 1, hi // 2, hi // 2 + 1, hi - 2, hi - 1, hi] xs = special_cases.copy() @@ -204,7 +205,7 @@ def foo(x: {typ}, y: {typ}) -> bool: @pytest.mark.parametrize("typ", types) -def test_uint_literal(get_contract, assert_compile_failed, typ): +def test_uint_literal(get_contract, typ): lo, hi = typ.ast_bounds good_cases = [0, 1, 2, 3, hi // 2 - 1, hi // 2, hi // 2 + 1, hi - 1, hi] @@ -221,7 +222,13 @@ def test() -> {typ}: assert c.test() == val for val in bad_cases: - assert_compile_failed(lambda v=val: get_contract(code_template.format(typ=typ, val=v))) + exc = ( + InvalidType + if SizeLimits.MIN_INT256 <= val <= SizeLimits.MAX_UINT256 + else OverflowException + ) + with pytest.raises(exc): + compile_code(code_template.format(typ=typ, val=val)) @pytest.mark.parametrize("typ", types) @@ -232,4 +239,15 @@ def test_invalid_unary_ops(get_contract, assert_compile_failed, typ, op): def foo(a: {typ}) -> {typ}: return {op} a """ - assert_compile_failed(lambda: get_contract(code), InvalidOperation) + with pytest.raises(InvalidOperation): + compile_code(code) + + +def test_binop_nested_intermediate_overflow(): + code = """ +@external +def foo(): + a: uint256 = 2**255 * 2 / 10 + """ + with pytest.raises(OverflowException): + compile_code(code) diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 4ef6874ae9..70a68e3206 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -2,6 +2,7 @@ import pytest +from vyper.compiler import compile_code from vyper.exceptions import ( ArgumentException, ArrayIndexException, @@ -315,6 +316,21 @@ def test_array(x: int128, y: int128, z: int128, w: int128) -> int128: def test_array_negative_accessor(get_contract_with_gas_estimation, assert_compile_failed): + array_constant_negative_accessor = """ +FOO: constant(int128) = -1 +@external +def test_array(x: int128, y: int128, z: int128, w: int128) -> int128: + a: int128[4] = [0, 0, 0, 0] + a[0] = x + a[1] = y + a[2] = z + a[3] = w + return a[-4] * 1000 + a[-3] * 100 + a[-2] * 10 + a[FOO] + """ + + with pytest.raises(ArrayIndexException): + compile_code(array_constant_negative_accessor) + array_negative_accessor = """ @external def test_array(x: int128, y: int128, z: int128, w: int128) -> int128: @@ -1728,7 +1744,7 @@ def test_constant_list_fail(get_contract, assert_compile_failed, storage_type, r def foo() -> DynArray[{return_type}, 3]: return MY_CONSTANT """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) @@ -1740,7 +1756,7 @@ def test_constant_list_fail_2(get_contract, assert_compile_failed, storage_type, def foo() -> {return_type}: return MY_CONSTANT[0] """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index 657c4ba0b8..b5b9538c20 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -2,7 +2,7 @@ import pytest -from vyper.exceptions import ArrayIndexException, InvalidType, OverflowException, TypeMismatch +from vyper.exceptions import ArrayIndexException, OverflowException, TypeMismatch def test_list_tester_code(get_contract_with_gas_estimation): @@ -705,7 +705,7 @@ def test_constant_list_fail(get_contract, assert_compile_failed, storage_type, r def foo() -> {return_type}[3]: return MY_CONSTANT """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) @@ -717,7 +717,7 @@ def test_constant_list_fail_2(get_contract, assert_compile_failed, storage_type, def foo() -> {return_type}: return MY_CONSTANT[0] """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) @@ -824,7 +824,7 @@ def test_constant_nested_list_fail(get_contract, assert_compile_failed, storage_ def foo() -> {return_type}[2][3]: return MY_CONSTANT """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) @@ -838,4 +838,4 @@ def test_constant_nested_list_fail_2( def foo() -> {return_type}: return MY_CONSTANT[0][0] """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) diff --git a/tests/functional/syntax/exceptions/test_argument_exception.py b/tests/functional/syntax/exceptions/test_argument_exception.py index fc06395015..0b7ec21bdb 100644 --- a/tests/functional/syntax/exceptions/test_argument_exception.py +++ b/tests/functional/syntax/exceptions/test_argument_exception.py @@ -1,13 +1,13 @@ import pytest -from vyper import compiler +from vyper import compile_code from vyper.exceptions import ArgumentException fail_list = [ """ @external def foo(): - x = as_wei_value(5, "vader") + x: uint256 = as_wei_value(5, "vader") """, """ @external @@ -95,4 +95,4 @@ def foo(): @pytest.mark.parametrize("bad_code", fail_list) def test_function_declaration_exception(bad_code): with pytest.raises(ArgumentException): - compiler.compile_code(bad_code) + compile_code(bad_code) diff --git a/tests/functional/syntax/test_abi_decode.py b/tests/functional/syntax/test_abi_decode.py index f05ff429cd..a6665bb84c 100644 --- a/tests/functional/syntax/test_abi_decode.py +++ b/tests/functional/syntax/test_abi_decode.py @@ -26,7 +26,7 @@ def bar(j: String[32]) -> bool: @pytest.mark.parametrize("bad_code,exc", fail_list) -def test_abi_encode_fail(bad_code, exc): +def test_abi_decode_fail(bad_code, exc): with pytest.raises(exc): compiler.compile_code(bad_code) @@ -41,5 +41,5 @@ def foo(x: Bytes[32]) -> uint256: @pytest.mark.parametrize("good_code", valid_list) -def test_abi_encode_success(good_code): +def test_abi_decode_success(good_code): assert compiler.compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_abs.py b/tests/functional/syntax/test_abs.py new file mode 100644 index 0000000000..0841ff05d6 --- /dev/null +++ b/tests/functional/syntax/test_abs.py @@ -0,0 +1,40 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidType + +fail_list = [ + ( + """ +@external +def foo(): + y: int256 = abs( + -57896044618658097711785492504343953926634992332820282019728792003956564819968 + ) + """, + InvalidType, + ) +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_abs_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +valid_list = [ + """ +FOO: constant(int256) = -3 +BAR: constant(int256) = abs(FOO) + +@external +def foo(): + a: int256 = BAR + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_abs_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_addmulmod.py b/tests/functional/syntax/test_addmulmod.py index ddff4d3e01..17c7b3ab8c 100644 --- a/tests/functional/syntax/test_addmulmod.py +++ b/tests/functional/syntax/test_addmulmod.py @@ -1,5 +1,6 @@ import pytest +from vyper import compile_code from vyper.exceptions import InvalidType fail_list = [ @@ -25,3 +26,24 @@ def foo() -> uint256: @pytest.mark.parametrize("code,exc", fail_list) def test_add_mod_fail(assert_compile_failed, get_contract, code, exc): assert_compile_failed(lambda: get_contract(code), exc) + + +valid_list = [ + """ +FOO: constant(uint256) = 3 +BAR: constant(uint256) = 5 +BAZ: constant(uint256) = 19 +BAX: constant(uint256) = uint256_addmod(FOO, BAR, BAZ) + """, + """ +FOO: constant(uint256) = 3 +BAR: constant(uint256) = 5 +BAZ: constant(uint256) = 19 +BAX: constant(uint256) = uint256_mulmod(FOO, BAR, BAZ) + """, +] + + +@pytest.mark.parametrize("code", valid_list) +def test_addmulmod_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_as_wei_value.py b/tests/functional/syntax/test_as_wei_value.py index a5232a5c9a..056d0348e9 100644 --- a/tests/functional/syntax/test_as_wei_value.py +++ b/tests/functional/syntax/test_as_wei_value.py @@ -1,13 +1,31 @@ import pytest -from vyper.exceptions import ArgumentException, InvalidType, StructureException +from vyper import compile_code +from vyper.exceptions import ( + ArgumentException, + InvalidLiteral, + InvalidType, + OverflowException, + StructureException, + UndeclaredDefinition, +) + +# CMC 2023-12-31 these tests could probably go in builtins/folding/ fail_list = [ ( """ @external def foo(): - x: int128 = as_wei_value(5, szabo) + x: uint256 = as_wei_value(5, szabo) + """, + UndeclaredDefinition, + ), + ( + """ +@external +def foo(): + x: uint256 = as_wei_value(5, "szaboo") """, ArgumentException, ), @@ -28,12 +46,50 @@ def foo(): """, InvalidType, ), + ( + """ +@external +def foo() -> uint256: + return as_wei_value( + 115792089237316195423570985008687907853269984665640564039457584007913129639937, + 'milliether' + ) + """, + OverflowException, + ), + ( + """ +@external +def foo(): + x: uint256 = as_wei_value(-1, "szabo") + """, + InvalidLiteral, + ), + ( + """ +FOO: constant(uint256) = as_wei_value(5, szabo) + """, + UndeclaredDefinition, + ), + ( + """ +FOO: constant(uint256) = as_wei_value(5, "szaboo") + """, + ArgumentException, + ), + ( + """ +FOO: constant(uint256) = as_wei_value(-1, "szabo") + """, + InvalidLiteral, + ), ] @pytest.mark.parametrize("bad_code,exc", fail_list) -def test_as_wei_fail(get_contract_with_gas_estimation, bad_code, exc, assert_compile_failed): - assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), exc) +def test_as_wei_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) valid_list = [ @@ -59,6 +115,14 @@ def foo() -> uint256: x: address = 0x1234567890123456789012345678901234567890 return x.balance """, + """ +y: constant(String[5]) = "szabo" +x: constant(uint256) = as_wei_value(5, y) + +@external +def foo(): + a: uint256 = x + """, ] diff --git a/tests/functional/syntax/test_ceil.py b/tests/functional/syntax/test_ceil.py new file mode 100644 index 0000000000..41f4175d01 --- /dev/null +++ b/tests/functional/syntax/test_ceil.py @@ -0,0 +1,19 @@ +import pytest + +from vyper import compile_code + +valid_list = [ + """ +BAR: constant(decimal) = 2.5 +FOO: constant(int256) = ceil(BAR) + +@external +def foo(): + a: int256 = FOO + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_ceil_good(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_dynamic_array.py b/tests/functional/syntax/test_dynamic_array.py index 99a01a17c8..f566a80625 100644 --- a/tests/functional/syntax/test_dynamic_array.py +++ b/tests/functional/syntax/test_dynamic_array.py @@ -1,6 +1,6 @@ import pytest -from vyper import compiler +from vyper import compile_code from vyper.exceptions import StructureException fail_list = [ @@ -24,12 +24,21 @@ def foo(): """, StructureException, ), + ( + """ +@external +def foo(): + a: DynArray[uint256, FOO] = [1, 2, 3] + """, + StructureException, + ), ] @pytest.mark.parametrize("bad_code,exc", fail_list) -def test_block_fail(assert_compile_failed, get_contract, bad_code, exc): - assert_compile_failed(lambda: get_contract(bad_code), exc) +def test_block_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) valid_list = [ @@ -48,4 +57,4 @@ def test_block_fail(assert_compile_failed, get_contract, bad_code, exc): @pytest.mark.parametrize("good_code", valid_list) def test_dynarray_pass(good_code): - assert compiler.compile_code(good_code) is not None + assert compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_epsilon.py b/tests/functional/syntax/test_epsilon.py new file mode 100644 index 0000000000..0e80d2b4bf --- /dev/null +++ b/tests/functional/syntax/test_epsilon.py @@ -0,0 +1,20 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidType + +# CMC 2023-12-31 this could probably go in builtins/folding/ +fail_list = [ + ( + """ +FOO: constant(address) = epsilon(address) + """, + InvalidType, + ) +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_block_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) diff --git a/tests/functional/syntax/test_floor.py b/tests/functional/syntax/test_floor.py new file mode 100644 index 0000000000..5c30aecbe1 --- /dev/null +++ b/tests/functional/syntax/test_floor.py @@ -0,0 +1,19 @@ +import pytest + +from vyper import compile_code + +valid_list = [ + """ +BAR: constant(decimal) = 2.5 +FOO: constant(int256) = floor(BAR) + +@external +def foo(): + a: int256 = FOO + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_floor_good(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index 7c7f9c476d..a9c3ad5cab 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -3,7 +3,12 @@ import pytest from vyper import compiler -from vyper.exceptions import ArgumentException, StateAccessViolation, StructureException +from vyper.exceptions import ( + ArgumentException, + StateAccessViolation, + StructureException, + TypeMismatch, +) fail_list = [ ( @@ -20,6 +25,17 @@ def foo(): ( """ @external +def bar(): + for i in range(1,2,bound=0): + pass + """, + StructureException, + "Bound must be at least 1", + "0", + ), + ( + """ +@external def foo(): x: uint256 = 100 for _ in range(10, bound=x): @@ -181,6 +197,44 @@ def foo(x: int128): "Value must be a literal integer, unless a bound is specified", "x", ), + ( + """ +@external +def bar(x: uint256): + for i in range(3, x): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +FOO: constant(int128) = 3 +BAR: constant(uint256) = 7 + +@external +def foo(): + for i in range(FOO, BAR): + pass + """, + TypeMismatch, + "Iterator values are of different types", + "range(FOO, BAR)", + ), + ( + """ +FOO: constant(int128) = -1 + +@external +def foo(): + for i in range(10, bound=FOO): + pass + """, + StructureException, + "Bound must be at least 1", + "-1", + ), ] for_code_regex = re.compile(r"for .+ in (.*):") diff --git a/tests/functional/syntax/test_len.py b/tests/functional/syntax/test_len.py index bbde7e4897..b8cc61df1d 100644 --- a/tests/functional/syntax/test_len.py +++ b/tests/functional/syntax/test_len.py @@ -1,7 +1,6 @@ import pytest -from pytest import raises -from vyper import compiler +from vyper import compile_code from vyper.exceptions import TypeMismatch fail_list = [ @@ -21,11 +20,11 @@ def foo(inp: int128) -> uint256: @pytest.mark.parametrize("bad_code", fail_list) def test_block_fail(bad_code): if isinstance(bad_code, tuple): - with raises(bad_code[1]): - compiler.compile_code(bad_code[0]) + with pytest.raises(bad_code[1]): + compile_code(bad_code[0]) else: - with raises(TypeMismatch): - compiler.compile_code(bad_code) + with pytest.raises(TypeMismatch): + compile_code(bad_code) valid_list = [ @@ -39,9 +38,18 @@ def foo(inp: Bytes[10]) -> uint256: def foo(inp: String[10]) -> uint256: return len(inp) """, + """ +BAR: constant(String[5]) = "vyper" +FOO: constant(uint256) = len(BAR) + +@external +def foo() -> uint256: + a: uint256 = FOO + return a + """, ] @pytest.mark.parametrize("good_code", valid_list) def test_list_success(good_code): - assert compiler.compile_code(good_code) is not None + assert compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_method_id.py b/tests/functional/syntax/test_method_id.py new file mode 100644 index 0000000000..849c1b0d55 --- /dev/null +++ b/tests/functional/syntax/test_method_id.py @@ -0,0 +1,50 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidLiteral, InvalidType + +fail_list = [ + ( + """ +@external +def foo(): + a: Bytes[4] = method_id("bar ()") + """, + InvalidLiteral, + ), + ( + """ +FOO: constant(Bytes[4]) = method_id(1) + """, + InvalidType, + ), + ( + """ +FOO: constant(Bytes[4]) = method_id("bar ()") + """, + InvalidLiteral, + ), +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_method_id_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +valid_list = [ + """ +FOO: constant(String[5]) = "foo()" +BAR: constant(Bytes[4]) = method_id(FOO) + +@external +def foo(a: Bytes[4] = BAR): + pass + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_method_id_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_minmax.py b/tests/functional/syntax/test_minmax.py index 2ad3d363f1..78ee74635c 100644 --- a/tests/functional/syntax/test_minmax.py +++ b/tests/functional/syntax/test_minmax.py @@ -1,6 +1,7 @@ import pytest -from vyper.exceptions import InvalidType, TypeMismatch +from vyper import compile_code +from vyper.exceptions import InvalidType, OverflowException, TypeMismatch fail_list = [ ( @@ -19,9 +20,45 @@ def foo(): """, TypeMismatch, ), + ( + """ +@external +def foo(): + a: decimal = min(1.0, 18707220957835557353007165858768422651595.9365500928) + """, + OverflowException, + ), ] @pytest.mark.parametrize("bad_code,exc", fail_list) -def test_block_fail(assert_compile_failed, get_contract_with_gas_estimation, bad_code, exc): - assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), exc) +def test_block_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +valid_list = [ + """ +FOO: constant(uint256) = 123 +BAR: constant(uint256) = 456 +BAZ: constant(uint256) = min(FOO, BAR) + +@external +def foo(): + a: uint256 = BAZ + """, + """ +FOO: constant(uint256) = 123 +BAR: constant(uint256) = 456 +BAZ: constant(uint256) = max(FOO, BAR) + +@external +def foo(): + a: uint256 = BAZ + """, +] + + +@pytest.mark.parametrize("good_code", valid_list) +def test_block_success(good_code): + assert compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_minmax_value.py b/tests/functional/syntax/test_minmax_value.py index e154cad23f..8cc3370b42 100644 --- a/tests/functional/syntax/test_minmax_value.py +++ b/tests/functional/syntax/test_minmax_value.py @@ -1,21 +1,39 @@ import pytest +from vyper import compile_code from vyper.exceptions import InvalidType fail_list = [ - """ + ( + """ @external def foo(): a: address = min_value(address) """, - """ + InvalidType, + ), + ( + """ @external def foo(): a: address = max_value(address) """, + InvalidType, + ), + ( + """ +FOO: constant(address) = min_value(address) + +@external +def foo(): + a: address = FOO + """, + InvalidType, + ), ] -@pytest.mark.parametrize("bad_code", fail_list) -def test_block_fail(assert_compile_failed, get_contract_with_gas_estimation, bad_code): - assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), InvalidType) +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_block_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) diff --git a/tests/functional/syntax/test_powmod.py b/tests/functional/syntax/test_powmod.py new file mode 100644 index 0000000000..12ea23152c --- /dev/null +++ b/tests/functional/syntax/test_powmod.py @@ -0,0 +1,39 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidType + +fail_list = [ + ( + """ +@external +def foo(): + a: uint256 = pow_mod256(-1, -1) + """, + InvalidType, + ) +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_powmod_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +valid_list = [ + """ +FOO: constant(uint256) = 3 +BAR: constant(uint256) = 5 +BAZ: constant(uint256) = pow_mod256(FOO, BAR) + +@external +def foo(): + a: uint256 = BAZ + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_powmod_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_raw_call.py b/tests/functional/syntax/test_raw_call.py index b1286e7a8e..c0b38d1d1e 100644 --- a/tests/functional/syntax/test_raw_call.py +++ b/tests/functional/syntax/test_raw_call.py @@ -1,6 +1,6 @@ import pytest -from vyper import compiler +from vyper import compile_code from vyper.exceptions import ArgumentException, InvalidType, SyntaxException, TypeMismatch fail_list = [ @@ -39,7 +39,7 @@ def foo(): @pytest.mark.parametrize("bad_code,exc", fail_list) def test_raw_call_fail(bad_code, exc): with pytest.raises(exc): - compiler.compile_code(bad_code) + compile_code(bad_code) valid_list = [ @@ -90,9 +90,23 @@ def foo(): value=self.balance - self.balances[0] ) """, + # test constants + """ +OUTSIZE: constant(uint256) = 4 +REVERT_ON_FAILURE: constant(bool) = True +@external +def foo(): + x: Bytes[9] = raw_call( + 0x1234567890123456789012345678901234567890, + b"cow", + max_outsize=OUTSIZE, + gas=595757, + revert_on_failure=REVERT_ON_FAILURE + ) + """, ] @pytest.mark.parametrize("good_code", valid_list) def test_raw_call_success(good_code): - assert compiler.compile_code(good_code) is not None + assert compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_ternary.py b/tests/functional/syntax/test_ternary.py index 325be3e43b..6a2bb9c072 100644 --- a/tests/functional/syntax/test_ternary.py +++ b/tests/functional/syntax/test_ternary.py @@ -1,6 +1,6 @@ import pytest -from vyper.compiler import compile_code +from vyper import compile_code from vyper.exceptions import InvalidType, TypeMismatch good_list = [ @@ -82,7 +82,7 @@ def foo() -> uint256: def foo() -> uint256: return 1 if TEST else 2 """, - InvalidType, + TypeMismatch, ), ( # bad test type: variable """ diff --git a/tests/functional/syntax/test_uint2str.py b/tests/functional/syntax/test_uint2str.py new file mode 100644 index 0000000000..9e6dde30cc --- /dev/null +++ b/tests/functional/syntax/test_uint2str.py @@ -0,0 +1,19 @@ +import pytest + +from vyper import compile_code + +valid_list = [ + """ +FOO: constant(uint256) = 3 +BAR: constant(String[78]) = uint2str(FOO) + +@external +def foo(): + a: String[78] = BAR + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_addmulmod_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_unary.py b/tests/functional/syntax/test_unary.py new file mode 100644 index 0000000000..5942ee15db --- /dev/null +++ b/tests/functional/syntax/test_unary.py @@ -0,0 +1,21 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidType + +fail_list = [ + ( + """ +@external +def foo() -> int128: + return -2**127 + """, + InvalidType, + ) +] + + +@pytest.mark.parametrize("code,exc", fail_list) +def test_unary_fail(code, exc): + with pytest.raises(exc): + compile_code(code) diff --git a/tests/unit/ast/nodes/test_evaluate_binop_decimal.py b/tests/unit/ast/nodes/test_fold_binop_decimal.py similarity index 93% rename from tests/unit/ast/nodes/test_evaluate_binop_decimal.py rename to tests/unit/ast/nodes/test_fold_binop_decimal.py index 44b82e321d..e426a11de9 100644 --- a/tests/unit/ast/nodes/test_evaluate_binop_decimal.py +++ b/tests/unit/ast/nodes/test_fold_binop_decimal.py @@ -31,7 +31,7 @@ def foo(a: decimal, b: decimal) -> decimal: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() is_valid = True except ZeroDivisionException: is_valid = False @@ -49,7 +49,7 @@ def test_binop_pow(): old_node = vyper_ast.body[0].value with pytest.raises(TypeMismatch): - old_node.evaluate() + old_node.get_folded_value() @pytest.mark.fuzzing @@ -74,8 +74,8 @@ def foo({input_value}) -> decimal: literal_op = literal_op.rsplit(maxsplit=1)[0] vyper_ast = vy_ast.parse_to_ast(literal_op) try: - vy_ast.folding.replace_literal_ops(vyper_ast) - expected = vyper_ast.body[0].value.value + new_node = vyper_ast.body[0].value.get_folded_value() + expected = new_node.value is_valid = -(2**127) <= expected < 2**127 except (OverflowException, ZeroDivisionException): # for overflow or division/modulus by 0, expect the contract call to revert diff --git a/tests/unit/ast/nodes/test_evaluate_binop_int.py b/tests/unit/ast/nodes/test_fold_binop_int.py similarity index 93% rename from tests/unit/ast/nodes/test_evaluate_binop_int.py rename to tests/unit/ast/nodes/test_fold_binop_int.py index 405d557f7d..904b36c167 100644 --- a/tests/unit/ast/nodes/test_evaluate_binop_int.py +++ b/tests/unit/ast/nodes/test_fold_binop_int.py @@ -27,7 +27,7 @@ def foo(a: int128, b: int128) -> int128: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() is_valid = True except ZeroDivisionException: is_valid = False @@ -57,7 +57,7 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() is_valid = new_node.value >= 0 except ZeroDivisionException: is_valid = False @@ -85,7 +85,7 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{left} ** {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -115,8 +115,8 @@ def foo({input_value}) -> int128: vyper_ast = vy_ast.parse_to_ast(literal_op) try: - vy_ast.folding.replace_literal_ops(vyper_ast) - expected = vyper_ast.body[0].value.value + new_node = vyper_ast.body[0].value.get_folded_value() + expected = new_node.value is_valid = True except ZeroDivisionException: is_valid = False diff --git a/tests/unit/ast/nodes/test_evaluate_boolop.py b/tests/unit/ast/nodes/test_fold_boolop.py similarity index 92% rename from tests/unit/ast/nodes/test_evaluate_boolop.py rename to tests/unit/ast/nodes/test_fold_boolop.py index 8b70537c39..3c42da0d26 100644 --- a/tests/unit/ast/nodes/test_evaluate_boolop.py +++ b/tests/unit/ast/nodes/test_fold_boolop.py @@ -26,7 +26,7 @@ def foo({input_value}) -> bool: vyper_ast = vy_ast.parse_to_ast(literal_op) old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(*values) == new_node.value @@ -53,7 +53,7 @@ def foo({input_value}) -> bool: literal_op = literal_op.rsplit(maxsplit=1)[0] vyper_ast = vy_ast.parse_to_ast(literal_op) - vy_ast.folding.replace_literal_ops(vyper_ast) - expected = vyper_ast.body[0].value.value + new_node = vyper_ast.body[0].value.get_folded_value() + expected = new_node.value assert contract.foo(*values) == expected diff --git a/tests/unit/ast/nodes/test_evaluate_compare.py b/tests/unit/ast/nodes/test_fold_compare.py similarity index 94% rename from tests/unit/ast/nodes/test_evaluate_compare.py rename to tests/unit/ast/nodes/test_fold_compare.py index 07f8e70de6..2b7c0f09d7 100644 --- a/tests/unit/ast/nodes/test_evaluate_compare.py +++ b/tests/unit/ast/nodes/test_fold_compare.py @@ -21,7 +21,7 @@ def foo(a: int128, b: int128) -> bool: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -41,7 +41,7 @@ def foo(a: uint128, b: uint128) -> bool: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -65,7 +65,7 @@ def bar(a: int128) -> bool: vyper_ast = vy_ast.parse_to_ast(f"{left} in {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() # check runtime == fully folded assert contract.foo(left, right) == new_node.value @@ -94,7 +94,7 @@ def bar(a: int128) -> bool: vyper_ast = vy_ast.parse_to_ast(f"{left} not in {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() # check runtime == fully folded assert contract.foo(left, right) == new_node.value @@ -109,4 +109,4 @@ def test_compare_type_mismatch(op): vyper_ast = vy_ast.parse_to_ast(f"1 {op} 1.0") old_node = vyper_ast.body[0].value with pytest.raises(UnfoldableNode): - old_node.evaluate() + old_node.get_folded_value() diff --git a/tests/unit/ast/nodes/test_evaluate_subscript.py b/tests/unit/ast/nodes/test_fold_subscript.py similarity index 93% rename from tests/unit/ast/nodes/test_evaluate_subscript.py rename to tests/unit/ast/nodes/test_fold_subscript.py index ca50a076a5..1884abf73b 100644 --- a/tests/unit/ast/nodes/test_evaluate_subscript.py +++ b/tests/unit/ast/nodes/test_fold_subscript.py @@ -21,6 +21,6 @@ def foo(array: int128[10], idx: uint256) -> int128: vyper_ast = vy_ast.parse_to_ast(f"{array}[{idx}]") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(array, idx) == new_node.value diff --git a/tests/unit/ast/nodes/test_evaluate_unaryop.py b/tests/unit/ast/nodes/test_fold_unaryop.py similarity index 86% rename from tests/unit/ast/nodes/test_evaluate_unaryop.py rename to tests/unit/ast/nodes/test_fold_unaryop.py index 63d7a0b7ff..ff48adfe71 100644 --- a/tests/unit/ast/nodes/test_evaluate_unaryop.py +++ b/tests/unit/ast/nodes/test_fold_unaryop.py @@ -14,7 +14,7 @@ def foo(a: bool) -> bool: vyper_ast = vy_ast.parse_to_ast(f"not {bool_cond}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(bool_cond) == new_node.value @@ -31,7 +31,7 @@ def foo(a: bool) -> bool: literal_op = f"{'not ' * count}{bool_cond}" vyper_ast = vy_ast.parse_to_ast(literal_op) - vy_ast.folding.replace_literal_ops(vyper_ast) - expected = vyper_ast.body[0].value.value + new_node = vyper_ast.body[0].value.get_folded_value() + expected = new_node.value assert contract.foo(bool_cond) == expected diff --git a/tests/unit/ast/nodes/test_replace_in_tree.py b/tests/unit/ast/nodes/test_replace_in_tree.py deleted file mode 100644 index 682e7ce7de..0000000000 --- a/tests/unit/ast/nodes/test_replace_in_tree.py +++ /dev/null @@ -1,70 +0,0 @@ -import pytest - -from vyper import ast as vy_ast -from vyper.exceptions import CompilerPanic - - -def test_assumptions(): - # ASTs generated separately from the same source should compare equal - test_tree = vy_ast.parse_to_ast("foo = 42") - expected_tree = vy_ast.parse_to_ast("foo = 42") - assert vy_ast.compare_nodes(test_tree, expected_tree) - - # ASTs generated separately with different source should compare not-equal - test_tree = vy_ast.parse_to_ast("foo = 42") - expected_tree = vy_ast.parse_to_ast("bar = 666") - assert not vy_ast.compare_nodes(test_tree, expected_tree) - - -def test_simple_replacement(): - test_tree = vy_ast.parse_to_ast("foo = 42") - expected_tree = vy_ast.parse_to_ast("bar = 42") - - old_node = test_tree.body[0].target - new_node = vy_ast.parse_to_ast("bar").body[0].value - - test_tree.replace_in_tree(old_node, new_node) - - assert vy_ast.compare_nodes(test_tree, expected_tree) - - -def test_list_replacement_similar_nodes(): - test_tree = vy_ast.parse_to_ast("foo = [1, 1, 1, 1, 1]") - expected_tree = vy_ast.parse_to_ast("foo = [1, 1, 31337, 1, 1]") - - old_node = test_tree.body[0].value.elements[2] - new_node = vy_ast.parse_to_ast("31337").body[0].value - - test_tree.replace_in_tree(old_node, new_node) - - assert vy_ast.compare_nodes(test_tree, expected_tree) - - -def test_parents_children(): - test_tree = vy_ast.parse_to_ast("foo = 42") - - old_node = test_tree.body[0].target - parent = old_node.get_ancestor() - - new_node = vy_ast.parse_to_ast("bar").body[0].value - test_tree.replace_in_tree(old_node, new_node) - - assert old_node.get_ancestor() == new_node.get_ancestor() - - assert old_node not in parent.get_children() - assert new_node in parent.get_children() - - assert old_node not in test_tree.get_descendants() - assert new_node in test_tree.get_descendants() - - -def test_cannot_replace_twice(): - test_tree = vy_ast.parse_to_ast("foo = 42") - old_node = test_tree.body[0].target - - new_node = vy_ast.parse_to_ast("42").body[0].value - - test_tree.replace_in_tree(old_node, new_node) - - with pytest.raises(CompilerPanic): - test_tree.replace_in_tree(old_node, new_node) diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index dc49f72561..20390f3d5e 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -41,8 +41,8 @@ def test_basic_ast(): code = """ a: int128 """ - dict_out = compiler.compile_code(code, output_formats=["ast_dict"], source_id=0) - assert dict_out["ast_dict"]["ast"]["body"][0] == { + dict_out = compiler.compile_code(code, output_formats=["annotated_ast_dict"], source_id=0) + assert dict_out["annotated_ast_dict"]["ast"]["body"][0] == { "annotation": { "ast_type": "Name", "col_offset": 3, @@ -69,12 +69,14 @@ def test_basic_ast(): "lineno": 2, "node_id": 2, "src": "1:1:0", + "type": "int128", }, "value": None, "is_constant": False, "is_immutable": False, "is_public": False, "is_transient": False, + "type": "int128", } diff --git a/tests/unit/ast/test_folding.py b/tests/unit/ast/test_folding.py deleted file mode 100644 index 62a7140e97..0000000000 --- a/tests/unit/ast/test_folding.py +++ /dev/null @@ -1,272 +0,0 @@ -import pytest - -from vyper import ast as vy_ast -from vyper.ast import folding -from vyper.exceptions import OverflowException - - -def test_integration(): - test_ast = vy_ast.parse_to_ast("[1+2, 6+7][8-8]") - expected_ast = vy_ast.parse_to_ast("3") - - folding.fold(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_binop_simple(): - test_ast = vy_ast.parse_to_ast("1 + 2") - expected_ast = vy_ast.parse_to_ast("3") - - folding.replace_literal_ops(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_binop_nested(): - test_ast = vy_ast.parse_to_ast("((6 + (2**4)) * 4) / 2") - expected_ast = vy_ast.parse_to_ast("44") - - folding.replace_literal_ops(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_binop_nested_intermediate_overflow(): - test_ast = vy_ast.parse_to_ast("2**255 * 2 / 10") - with pytest.raises(OverflowException): - folding.fold(test_ast) - - -def test_replace_binop_nested_intermediate_underflow(): - test_ast = vy_ast.parse_to_ast("-2**255 * 2 - 10 + 100") - with pytest.raises(OverflowException): - folding.fold(test_ast) - - -def test_replace_decimal_nested_intermediate_overflow(): - test_ast = vy_ast.parse_to_ast( - "18707220957835557353007165858768422651595.9365500927 + 1e-10 - 1e-10" - ) - with pytest.raises(OverflowException): - folding.fold(test_ast) - - -def test_replace_decimal_nested_intermediate_underflow(): - test_ast = vy_ast.parse_to_ast( - "-18707220957835557353007165858768422651595.9365500928 - 1e-10 + 1e-10" - ) - with pytest.raises(OverflowException): - folding.fold(test_ast) - - -def test_replace_literal_ops(): - test_ast = vy_ast.parse_to_ast("[not True, True and False, True or False]") - expected_ast = vy_ast.parse_to_ast("[False, False, True]") - - folding.replace_literal_ops(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_subscripts_simple(): - test_ast = vy_ast.parse_to_ast("[foo, bar, baz][1]") - expected_ast = vy_ast.parse_to_ast("bar") - - folding.replace_subscripts(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_subscripts_nested(): - test_ast = vy_ast.parse_to_ast("[[0, 1], [2, 3], [4, 5]][2][1]") - expected_ast = vy_ast.parse_to_ast("5") - - folding.replace_subscripts(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -constants_modified = [ - "bar = FOO", - "bar: int128[FOO]", - "[1, 2, FOO]", - "def bar(a: int128 = FOO): pass", - "log bar(FOO)", - "FOO + 1", - "a: int128[FOO / 2]", - "a[FOO - 1] = 44", -] - - -@pytest.mark.parametrize("source", constants_modified) -def test_replace_constant(source): - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), True) - - assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -constants_unmodified = [ - "FOO = 42", - "self.FOO = 42", - "bar = FOO()", - "FOO()", - "bar = FOO()", - "bar = self.FOO", - "log FOO(bar)", - "[1, 2, FOO()]", - "FOO[42] = 2", -] - - -@pytest.mark.parametrize("source", constants_unmodified) -def test_replace_constant_no(source): - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), True) - - assert vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -userdefined_modified = [ - "FOO", - "foo = FOO", - "foo: int128[FOO] = 42", - "foo = [FOO]", - "foo += FOO", - "def foo(bar: int128 = FOO): pass", - "def foo(): bar = FOO", - "def foo(): return FOO", -] - - -@pytest.mark.parametrize("source", userdefined_modified) -def test_replace_userdefined_constant(source): - source = f"FOO: constant(int128) = 42\n{source}" - - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - folding.replace_user_defined_constants(folded_ast) - - assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -userdefined_unmodified = [ - "FOO: constant(int128) = 42", - "FOO = 42", - "FOO += 42", - "FOO()", - "def foo(FOO: int128 = 42): pass", - "def foo(): FOO = 42", - "def FOO(): pass", -] - - -@pytest.mark.parametrize("source", userdefined_unmodified) -def test_replace_userdefined_constant_no(source): - source = f"FOO: constant(int128) = 42\n{source}" - - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - folding.replace_user_defined_constants(folded_ast) - - assert vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -dummy_address = "0x000000000000000000000000000000000000dEaD" -userdefined_attributes = [("b: uint256 = ADDR.balance", f"b: uint256 = {dummy_address}.balance")] - - -@pytest.mark.parametrize("source", userdefined_attributes) -def test_replace_userdefined_attribute(source): - preamble = f"ADDR: constant(address) = {dummy_address}" - l_source = f"{preamble}\n{source[0]}" - r_source = f"{preamble}\n{source[1]}" - - l_ast = vy_ast.parse_to_ast(l_source) - folding.replace_user_defined_constants(l_ast) - - r_ast = vy_ast.parse_to_ast(r_source) - - assert vy_ast.compare_nodes(l_ast, r_ast) - - -userdefined_struct = [("b: Foo = FOO", "b: Foo = Foo({a: 123, b: 456})")] - - -@pytest.mark.parametrize("source", userdefined_struct) -def test_replace_userdefined_struct(source): - preamble = """ -struct Foo: - a: uint256 - b: uint256 - -FOO: constant(Foo) = Foo({a: 123, b: 456}) - """ - l_source = f"{preamble}\n{source[0]}" - r_source = f"{preamble}\n{source[1]}" - - l_ast = vy_ast.parse_to_ast(l_source) - folding.replace_user_defined_constants(l_ast) - - r_ast = vy_ast.parse_to_ast(r_source) - - assert vy_ast.compare_nodes(l_ast, r_ast) - - -userdefined_nested_struct = [ - ("b: Foo = FOO", "b: Foo = Foo({f1: Bar({b1: 123, b2: 456}), f2: 789})") -] - - -@pytest.mark.parametrize("source", userdefined_nested_struct) -def test_replace_userdefined_nested_struct(source): - preamble = """ -struct Bar: - b1: uint256 - b2: uint256 - -struct Foo: - f1: Bar - f2: uint256 - -FOO: constant(Foo) = Foo({f1: Bar({b1: 123, b2: 456}), f2: 789}) - """ - l_source = f"{preamble}\n{source[0]}" - r_source = f"{preamble}\n{source[1]}" - - l_ast = vy_ast.parse_to_ast(l_source) - folding.replace_user_defined_constants(l_ast) - - r_ast = vy_ast.parse_to_ast(r_source) - - assert vy_ast.compare_nodes(l_ast, r_ast) - - -builtin_folding_functions = [("ceil(4.2)", "5"), ("floor(4.2)", "4")] - -builtin_folding_sources = [ - "{}", - "foo = {}", - "foo = [{0}, {0}]", - "def foo(): {}", - "def foo(): return {}", - "def foo(bar: {}): pass", -] - - -@pytest.mark.parametrize("source", builtin_folding_sources) -@pytest.mark.parametrize("original,result", builtin_folding_functions) -def test_replace_builtins(source, original, result): - original_ast = vy_ast.parse_to_ast(source.format(original)) - target_ast = vy_ast.parse_to_ast(source.format(result)) - - folding.replace_builtin_functions(original_ast) - - assert vy_ast.compare_nodes(original_ast, target_ast) diff --git a/tests/unit/ast/test_natspec.py b/tests/unit/ast/test_natspec.py index c2133468aa..22167f8694 100644 --- a/tests/unit/ast/test_natspec.py +++ b/tests/unit/ast/test_natspec.py @@ -60,7 +60,7 @@ def doesEat(food: String[30], qty: uint256) -> bool: def parse_natspec(code): - vyper_ast = CompilerData(code).vyper_module_folded + vyper_ast = CompilerData(code).annotated_vyper_module return vy_ast.parse_natspec(vyper_ast) diff --git a/vyper/ast/README.md b/vyper/ast/README.md index 320c69da0c..7400091993 100644 --- a/vyper/ast/README.md +++ b/vyper/ast/README.md @@ -12,8 +12,6 @@ and parsing NatSpec docstrings. * [`annotation.py`](annotation.py): Contains the `AnnotatingVisitor` class, used to annotate and modify the Python AST prior to converting it to a Vyper AST. -* [`folding.py`](folding.py): Functions for evaluating and replacing literal -nodes within the Vyper AST. * [`natspec.py`](natspec.py): Functions for parsing NatSpec docstrings within the source. * [`nodes.py`](nodes.py): Contains the Vyper node classes, and the `get_node` @@ -70,25 +68,6 @@ or parents that match a desired pattern. To learn more about these methods, read their docstrings in the `VyperNode` class in [`nodes.py`](nodes.py). -### Modifying the AST - -[`folding.py`](folding.py) contains the `fold` function, a high-level method called -to evaluating and replacing literal nodes within the AST. Some examples of literal -folding include: - -* arithmetic operations (`3+2` becomes `5`) -* references to literal arrays (`["foo", "bar"][1]` becomes `"bar"`) -* builtin functions applied to literals (`min(1,2)` becomes `1`) - -The process of literal folding includes: - -1. Foldable node classes are evaluated via their `evaluate` method, which attempts -to create a new `Constant` from the content of the given node. -2. Replacement nodes are generated using the `from_node` class method within the new -node class. -3. The modification of the tree is handled by `Module.replace_in_tree`, which locates -the existing node and replaces it with a new one. - ## Design ### `__slots__` diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py index 4b46801153..bc08626b59 100644 --- a/vyper/ast/__init__.py +++ b/vyper/ast/__init__.py @@ -17,4 +17,4 @@ # required to avoid circular dependency -from . import expansion, folding # noqa: E402 +from . import expansion # noqa: E402 diff --git a/vyper/ast/__init__.pyi b/vyper/ast/__init__.pyi index eac8ffdef5..5581e82fe2 100644 --- a/vyper/ast/__init__.pyi +++ b/vyper/ast/__init__.pyi @@ -1,7 +1,7 @@ import ast as python_ast from typing import Any, Optional, Union -from . import expansion, folding, nodes, validation +from . import expansion, nodes, validation from .natspec import parse_natspec as parse_natspec from .nodes import * from .parse import parse_to_ast as parse_to_ast diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py deleted file mode 100644 index 087708a356..0000000000 --- a/vyper/ast/folding.py +++ /dev/null @@ -1,263 +0,0 @@ -from typing import Optional, Union - -from vyper.ast import nodes as vy_ast -from vyper.builtins.functions import DISPATCH_TABLE -from vyper.exceptions import UnfoldableNode, UnknownType -from vyper.semantics.types.base import VyperType -from vyper.semantics.types.utils import type_from_annotation - - -def fold(vyper_module: vy_ast.Module) -> None: - """ - Perform literal folding operations on a Vyper AST. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - """ - changed_nodes = 1 - while changed_nodes: - changed_nodes = 0 - changed_nodes += replace_user_defined_constants(vyper_module) - changed_nodes += replace_literal_ops(vyper_module) - changed_nodes += replace_subscripts(vyper_module) - changed_nodes += replace_builtin_functions(vyper_module) - - -def replace_literal_ops(vyper_module: vy_ast.Module) -> int: - """ - Find and evaluate operation and comparison nodes within the Vyper AST, - replacing them with Constant nodes where possible. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - node_types = (vy_ast.BoolOp, vy_ast.BinOp, vy_ast.UnaryOp, vy_ast.Compare) - for node in vyper_module.get_descendants(node_types, reverse=True): - try: - new_node = node.evaluate() - except UnfoldableNode: - continue - - changed_nodes += 1 - vyper_module.replace_in_tree(node, new_node) - - return changed_nodes - - -def replace_subscripts(vyper_module: vy_ast.Module) -> int: - """ - Find and evaluate Subscript nodes within the Vyper AST, replacing them with - Constant nodes where possible. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - for node in vyper_module.get_descendants(vy_ast.Subscript, reverse=True): - try: - new_node = node.evaluate() - except UnfoldableNode: - continue - - changed_nodes += 1 - vyper_module.replace_in_tree(node, new_node) - - return changed_nodes - - -def replace_builtin_functions(vyper_module: vy_ast.Module) -> int: - """ - Find and evaluate builtin function calls within the Vyper AST, replacing - them with Constant nodes where possible. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - for node in vyper_module.get_descendants(vy_ast.Call, reverse=True): - if not isinstance(node.func, vy_ast.Name): - continue - - name = node.func.id - func = DISPATCH_TABLE.get(name) - if func is None or not hasattr(func, "evaluate"): - continue - try: - new_node = func.evaluate(node) # type: ignore - except UnfoldableNode: - continue - - changed_nodes += 1 - vyper_module.replace_in_tree(node, new_node) - - return changed_nodes - - -def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int: - """ - Find user-defined constant assignments, and replace references - to the constants with their literal values. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - for node in vyper_module.get_children(vy_ast.VariableDecl): - if not isinstance(node.target, vy_ast.Name): - # left-hand-side of assignment is not a variable - continue - if not node.is_constant: - # annotation is not wrapped in `constant(...)` - continue - - # Extract type definition from propagated annotation - type_ = None - try: - type_ = type_from_annotation(node.annotation) - except UnknownType: - # handle user-defined types e.g. structs - it's OK to not - # propagate the type annotation here because user-defined - # types can be unambiguously inferred at typechecking time - pass - - changed_nodes += replace_constant( - vyper_module, node.target.id, node.value, False, type_=type_ - ) - - return changed_nodes - - -# TODO constant folding on log events - - -def _replace(old_node, new_node, type_=None): - if isinstance(new_node, vy_ast.Constant): - new_node = new_node.from_node(old_node, value=new_node.value) - if type_: - new_node._metadata["type"] = type_ - return new_node - elif isinstance(new_node, vy_ast.List): - base_type = type_.value_type if type_ else None - list_values = [_replace(old_node, i, type_=base_type) for i in new_node.elements] - new_node = new_node.from_node(old_node, elements=list_values) - if type_: - new_node._metadata["type"] = type_ - return new_node - elif isinstance(new_node, vy_ast.Call): - # Replace `Name` node with `Call` node - keyword = keywords = None - if hasattr(new_node, "keyword"): - keyword = new_node.keyword - if hasattr(new_node, "keywords"): - keywords = new_node.keywords - new_node = new_node.from_node( - old_node, func=new_node.func, args=new_node.args, keyword=keyword, keywords=keywords - ) - return new_node - else: - raise UnfoldableNode - - -def replace_constant( - vyper_module: vy_ast.Module, - id_: str, - replacement_node: Union[vy_ast.Constant, vy_ast.List, vy_ast.Call], - raise_on_error: bool, - type_: Optional[VyperType] = None, -) -> int: - """ - Replace references to a variable name with a literal value. - - Arguments - --------- - vyper_module : Module - Module-level ast node to perform replacement in. - id_ : str - String representing the `.id` attribute of the node(s) to be replaced. - replacement_node : Constant | List | Call - Vyper ast node representing the literal value to be substituted in. - `Call` nodes are for struct constants. - raise_on_error: bool - Boolean indicating if `UnfoldableNode` exception should be raised or ignored. - type_ : VyperType, optional - Type definition to be propagated to type checker. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - for node in vyper_module.get_descendants(vy_ast.Name, {"id": id_}, reverse=True): - parent = node.get_ancestor() - - if isinstance(parent, vy_ast.Call) and node == parent.func: - # do not replace calls because splicing a constant into a callable site is - # never valid and it worsens the error message - continue - - # do not replace dictionary keys - if isinstance(parent, vy_ast.Dict) and node in parent.keys: - continue - - if not node.get_ancestor(vy_ast.Index): - # do not replace left-hand side of assignments - assign = node.get_ancestor( - (vy_ast.Assign, vy_ast.AnnAssign, vy_ast.AugAssign, vy_ast.VariableDecl) - ) - - if assign and node in assign.target.get_descendants(include_self=True): - continue - - # do not replace enum members - if node.get_ancestor(vy_ast.FlagDef): - continue - - try: - # note: _replace creates a copy of the replacement_node - new_node = _replace(node, replacement_node, type_=type_) - except UnfoldableNode: - if raise_on_error: - raise - continue - - changed_nodes += 1 - vyper_module.replace_in_tree(node, new_node) - - return changed_nodes diff --git a/vyper/ast/natspec.py b/vyper/ast/natspec.py index 41905b178a..41a6703b6e 100644 --- a/vyper/ast/natspec.py +++ b/vyper/ast/natspec.py @@ -11,13 +11,13 @@ USERDOCS_FIELDS = ("notice",) -def parse_natspec(vyper_module_folded: vy_ast.Module) -> Tuple[dict, dict]: +def parse_natspec(annotated_vyper_module: vy_ast.Module) -> Tuple[dict, dict]: """ Parses NatSpec documentation from a contract. Arguments --------- - vyper_module_folded : Module + annotated_vyper_module: Module Module-level vyper ast node. interface_codes: Dict, optional Dict containing relevant data for any import statements related to @@ -33,15 +33,15 @@ def parse_natspec(vyper_module_folded: vy_ast.Module) -> Tuple[dict, dict]: from vyper.semantics.types.function import FunctionVisibility userdoc, devdoc = {}, {} - source: str = vyper_module_folded.full_source_code + source: str = annotated_vyper_module.full_source_code - docstring = vyper_module_folded.get("doc_string.value") + docstring = annotated_vyper_module.get("doc_string.value") if docstring: devdoc.update(_parse_docstring(source, docstring, ("param", "return"))) if "notice" in devdoc: userdoc["notice"] = devdoc.pop("notice") - for node in [i for i in vyper_module_folded.body if i.get("doc_string.value")]: + for node in [i for i in annotated_vyper_module.body if i.get("doc_string.value")]: docstring = node.doc_string.value func_type = node._metadata["func_type"] if func_type.visibility != FunctionVisibility.EXTERNAL: diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index dba9f2a22d..efab5117d4 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -11,7 +11,6 @@ from vyper.compiler.settings import VYPER_ERROR_CONTEXT_LINES, VYPER_ERROR_LINE_NUMBERS from vyper.exceptions import ( ArgumentException, - CompilerPanic, InvalidLiteral, InvalidOperation, OverflowException, @@ -19,6 +18,7 @@ SyntaxException, TypeMismatch, UnfoldableNode, + VariableDeclarationException, VyperException, ZeroDivisionException, ) @@ -210,23 +210,6 @@ def _raise_syntax_exc(error_msg: str, ast_struct: dict) -> None: ) -def _validate_numeric_bounds( - node: Union["BinOp", "UnaryOp"], value: Union[decimal.Decimal, int] -) -> None: - if isinstance(value, decimal.Decimal): - # this will change if/when we add more decimal types - lower, upper = SizeLimits.MIN_AST_DECIMAL, SizeLimits.MAX_AST_DECIMAL - elif isinstance(value, int): - lower, upper = SizeLimits.MIN_INT256, SizeLimits.MAX_UINT256 - else: - raise CompilerPanic(f"Unexpected return type from {node._op}: {type(value)}") - if not lower <= value <= upper: - raise OverflowException( - f"Result of {node.op.description} ({value}) is outside bounds of all numeric types", - node, - ) - - class VyperNode: """ Base class for all vyper AST nodes. @@ -246,7 +229,7 @@ class VyperNode: Field names that, if present, must be set to None or a `SyntaxException` is raised. This attribute is used to exclude syntax that is valid in Python but not in Vyper. - _terminus : bool, optional + _is_terminus : bool, optional If `True`, indicates that execution halts upon reaching this node. _translated_fields : Dict, optional Field names that are reassigned if encountered. Used to normalize fields @@ -390,22 +373,67 @@ def description(self): """ return getattr(self, "_description", type(self).__name__) - def evaluate(self) -> "VyperNode": + @property + def is_literal_value(self): + """ + Check if the node is a literal value. + """ + return False + + @property + def has_folded_value(self): + """ + Property method to check if the node has a folded value. + """ + return "folded_value" in self._metadata + + def get_folded_value(self) -> "VyperNode": """ - Attempt to evaluate the content of a node and generate a new node from it. + Attempt to get the folded value, bubbling up UnfoldableNode if the node + is not foldable. + + + The returned value is cached on `_metadata["folded_value"]`. - If a node cannot be evaluated it should raise `UnfoldableNode`. This base - method acts as a catch-all to raise on any inherited classes that do not - implement the method. + For constant/literal nodes, the node should be directly returned + without caching to the metadata. """ - raise UnfoldableNode(f"{type(self)} cannot be evaluated") + if self.is_literal_value: + return self + + if "folded_value" not in self._metadata: + res = self._try_fold() # possibly throws UnfoldableNode + self._set_folded_value(res) + + return self._metadata["folded_value"] + + def _set_folded_value(self, node: "VyperNode") -> None: + # sanity check this is only called once + assert "folded_value" not in self._metadata + + # set the folded node's parent so that get_ancestor works + # this is mainly important for error messages. + node._parent = self._parent + + self._metadata["folded_value"] = node + + def _try_fold(self) -> "VyperNode": + """ + Attempt to constant-fold the content of a node, returning the result of + constant-folding if possible. + + If a node cannot be folded, it should raise `UnfoldableNode`. This + base implementation acts as a catch-all to raise on any inherited + classes that do not implement the method. + """ + raise UnfoldableNode(f"{type(self)} cannot be folded") def validate(self) -> None: """ Validate the content of a node. - Called by `ast.validation.validate_literal_nodes` to verify values within - literal nodes. + Called by `ast.validation.validate_literal_nodes` to verify values + within literal nodes. Returns `None` if the node is valid, raises `InvalidLiteral` or another more expressive exception if the value cannot be valid within a Vyper @@ -609,48 +637,6 @@ class Module(TopLevel): # metadata __slots__ = ("path", "resolved_path", "source_id") - def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: - """ - Perform an in-place substitution of a node within the tree. - - Parameters - ---------- - old_node : VyperNode - Node object to be replaced. - new_node : VyperNode - Node object to replace new_node. - - Returns - ------- - None - """ - parent = old_node._parent - - if old_node not in parent._children: - raise CompilerPanic("Node to be replaced does not exist within parent children") - - is_replaced = False - for key in parent.get_fields(): - obj = getattr(parent, key, None) - if obj == old_node: - if is_replaced: - raise CompilerPanic("Node to be replaced exists as multiple members in parent") - setattr(parent, key, new_node) - is_replaced = True - elif isinstance(obj, list) and obj.count(old_node): - if is_replaced or obj.count(old_node) > 1: - raise CompilerPanic("Node to be replaced exists as multiple members in parent") - obj[obj.index(old_node)] = new_node - is_replaced = True - if not is_replaced: - raise CompilerPanic("Node to be replaced does not exist within parent members") - - parent._children.remove(old_node) - - new_node._parent = parent - new_node._depth = old_node._depth - parent._children.add(new_node) - def add_to_body(self, node: VyperNode) -> None: """ Add a new node to the body of this node. @@ -769,6 +755,10 @@ class Constant(ExprNode): # inherited class for all simple constant node types __slots__ = ("value",) + @property + def is_literal_value(self): + return True + class Num(Constant): # inherited class for all numeric constant node types @@ -862,7 +852,14 @@ def n_bytes(self): """ The number of bytes this hex value represents """ - return self.n_nibbles // 2 + return len(self.bytes_value) + + @property + def bytes_value(self): + """ + This value as bytes + """ + return bytes.fromhex(self.value.removeprefix("0x")) class Str(Constant): @@ -905,19 +902,39 @@ class List(ExprNode): __slots__ = ("elements",) _translated_fields = {"elts": "elements"} + @property + def is_literal_value(self): + return all(e.is_literal_value for e in self.elements) + + def _try_fold(self) -> ExprNode: + elements = [e.get_folded_value() for e in self.elements] + return type(self).from_node(self, elements=elements) + class Tuple(ExprNode): __slots__ = ("elements",) _translated_fields = {"elts": "elements"} + @property + def is_literal_value(self): + return all(e.is_literal_value for e in self.elements) + def validate(self): if not self.elements: raise InvalidLiteral("Cannot have an empty tuple", self) + def _try_fold(self) -> ExprNode: + elements = [e.get_folded_value() for e in self.elements] + return type(self).from_node(self, elements=elements) + class NameConstant(Constant): __slots__ = () + def validate(self): + if self.value is None: + raise InvalidLiteral("`None` is not a valid vyper value!", self) + class Ellipsis(Constant): __slots__ = () @@ -926,6 +943,14 @@ class Ellipsis(Constant): class Dict(ExprNode): __slots__ = ("keys", "values") + @property + def is_literal_value(self): + return all(v.is_literal_value for v in self.values) + + def _try_fold(self) -> ExprNode: + values = [v.get_folded_value() for v in self.values] + return type(self).from_node(self, values=values) + class Name(ExprNode): __slots__ = ("id",) @@ -934,7 +959,7 @@ class Name(ExprNode): class UnaryOp(ExprNode): __slots__ = ("op", "operand") - def evaluate(self) -> ExprNode: + def _try_fold(self) -> ExprNode: """ Attempt to evaluate the unary operation. @@ -943,16 +968,17 @@ def evaluate(self) -> ExprNode: Int | Decimal Node representing the result of the evaluation. """ - if isinstance(self.op, Not) and not isinstance(self.operand, NameConstant): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if isinstance(self.op, USub) and not isinstance(self.operand, (Int, Decimal)): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if isinstance(self.op, Invert) and not isinstance(self.operand, Int): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") + operand = self.operand.get_folded_value() - value = self.op._op(self.operand.value) - _validate_numeric_bounds(self, value) - return type(self.operand).from_node(self, value=value) + if isinstance(self.op, Not) and not isinstance(operand, NameConstant): + raise UnfoldableNode("not a boolean!", self.operand) + if isinstance(self.op, USub) and not isinstance(operand, Num): + raise UnfoldableNode("not a number!", self.operand) + if isinstance(self.op, Invert) and not isinstance(operand, Int): + raise UnfoldableNode("not an int!", self.operand) + + value = self.op._op(operand.value) + return type(operand).from_node(self, value=value) class Operator(VyperNode): @@ -982,7 +1008,7 @@ def _op(self, value): class BinOp(ExprNode): __slots__ = ("left", "op", "right") - def evaluate(self) -> ExprNode: + def _try_fold(self) -> ExprNode: """ Attempt to evaluate the arithmetic operation. @@ -991,20 +1017,19 @@ def evaluate(self) -> ExprNode: Int | Decimal Node representing the result of the evaluation. """ - left, right = self.left, self.right + left, right = [i.get_folded_value() for i in (self.left, self.right)] if type(left) is not type(right): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if not isinstance(left, (Int, Decimal)): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") + raise UnfoldableNode("invalid operation", self) + if not isinstance(left, Num): + raise UnfoldableNode("not a number!", self.left) # this validation is performed to prevent the compiler from hanging # on very large shifts and improve the error message for negative # values. if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256): - raise InvalidLiteral("Shift bits must be between 0 and 256", right) + raise InvalidLiteral("Shift bits must be between 0 and 256", self.right) value = self.op._op(left.value, right.value) - _validate_numeric_bounds(self, value) return type(left).from_node(self, value=value) @@ -1132,7 +1157,7 @@ class RShift(Operator): class BoolOp(ExprNode): __slots__ = ("op", "values") - def evaluate(self) -> ExprNode: + def _try_fold(self) -> ExprNode: """ Attempt to evaluate the boolean operation. @@ -1141,13 +1166,12 @@ def evaluate(self) -> ExprNode: NameConstant Node representing the result of the evaluation. """ - if next((i for i in self.values if not isinstance(i, NameConstant)), None): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") + values = [v.get_folded_value() for v in self.values] - values = [i.value for i in self.values] - if None in values: + if any(not isinstance(v, NameConstant) for v in values): raise UnfoldableNode("Node contains invalid field(s) for evaluation") + values = [v.value for v in values] value = self.op._op(values) return NameConstant.from_node(self, value=value) @@ -1188,7 +1212,7 @@ def __init__(self, *args, **kwargs): kwargs["right"] = kwargs.pop("comparators")[0] super().__init__(*args, **kwargs) - def evaluate(self) -> ExprNode: + def _try_fold(self) -> ExprNode: """ Attempt to evaluate the comparison. @@ -1197,7 +1221,7 @@ def evaluate(self) -> ExprNode: NameConstant Node representing the result of the evaluation. """ - left, right = self.left, self.right + left, right = [i.get_folded_value() for i in (self.left, self.right)] if not isinstance(left, Constant): raise UnfoldableNode("Node contains invalid field(s) for evaluation") @@ -1278,6 +1302,21 @@ def _op(self, left, right): class Call(ExprNode): __slots__ = ("func", "args", "keywords") + # try checking if this is a builtin, which is foldable + def _try_fold(self): + if not isinstance(self.func, Name): + raise UnfoldableNode("not a builtin", self) + + # cursed import cycle! + from vyper.builtins.functions import DISPATCH_TABLE + + func_name = self.func.id + if func_name not in DISPATCH_TABLE: + raise UnfoldableNode("not a builtin", self) + + builtin_t = DISPATCH_TABLE[func_name] + return builtin_t._try_fold(self) + class keyword(VyperNode): __slots__ = ("arg", "value") @@ -1290,7 +1329,7 @@ class Attribute(ExprNode): class Subscript(ExprNode): __slots__ = ("slice", "value") - def evaluate(self) -> ExprNode: + def _try_fold(self) -> ExprNode: """ Attempt to evaluate the subscript. @@ -1302,14 +1341,22 @@ def evaluate(self) -> ExprNode: ExprNode Node representing the result of the evaluation. """ - if not isinstance(self.value, List): + slice_ = self.slice.value.get_folded_value() + value = self.value.get_folded_value() + + if not isinstance(value, List): raise UnfoldableNode("Subscript object is not a literal list") - elements = self.value.elements + + elements = value.elements if len(set([type(i) for i in elements])) > 1: raise UnfoldableNode("List contains multiple node types") - idx = self.slice.get("value.value") - if not isinstance(idx, int) or idx < 0 or idx >= len(elements): - raise UnfoldableNode("Invalid index value") + + if not isinstance(slice_, Int): + raise UnfoldableNode("invalid index type", slice_) + + idx = slice_.value + if idx < 0 or idx >= len(elements): + raise UnfoldableNode("invalid index value") return elements[idx] @@ -1410,6 +1457,24 @@ def _check_args(annotation, call_name): if isinstance(self.annotation, Call): _raise_syntax_exc("Invalid scope for variable declaration", self.annotation) + def _pretty_location(self) -> str: + if self.is_constant: + return "Constant" + if self.is_transient: + return "Transient" + if self.is_immutable: + return "Immutable" + return "Storage" + + def validate(self): + if self.is_constant and self.value is None: + raise VariableDeclarationException("Constant must be declared with a value", self) + + if not self.is_constant and self.value is not None: + raise VariableDeclarationException( + f"{self._pretty_location} variables cannot have an initial value", self.value + ) + class AugAssign(Stmt): __slots__ = ("op", "target", "value") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 47856b6021..8bc4a4eb57 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -24,9 +24,15 @@ class VyperNode: def __eq__(self, other: Any) -> Any: ... @property def description(self): ... + @property + def is_literal_value(self): ... + @property + def has_folded_value(self): ... @classmethod def get_fields(cls: Any) -> set: ... - def evaluate(self) -> VyperNode: ... + def get_folded_value(self) -> VyperNode: ... + def _try_fold(self) -> VyperNode: ... + def _set_folded_value(self, node: VyperNode) -> None: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... def to_dict(self) -> dict: ... @@ -35,14 +41,14 @@ class VyperNode: node_type: Union[Type[VyperNode], Sequence[Type[VyperNode]], None] = ..., filters: Optional[dict] = ..., reverse: bool = ..., - ) -> Sequence: ... + ) -> list: ... def get_descendants( self, node_type: Union[Type[VyperNode], Sequence[Type[VyperNode]], None] = ..., filters: Optional[dict] = ..., include_self: bool = ..., reverse: bool = ..., - ) -> Sequence: ... + ) -> list: ... def get_ancestor( self, node_type: Union[Type[VyperNode], Sequence[Type[VyperNode]], None] = ... ) -> VyperNode: ... @@ -61,7 +67,6 @@ class TopLevel(VyperNode): class Module(TopLevel): path: str = ... resolved_path: str = ... - def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: ... def add_to_body(self, node: VyperNode) -> None: ... def remove_from_body(self, node: VyperNode) -> None: ... def namespace(self) -> Any: ... # context manager diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index a2f2542179..38a9d31695 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -81,6 +81,7 @@ def parse_to_ast_with_settings( # Convert to Vyper AST. module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint + return settings, module diff --git a/vyper/ast/validation.py b/vyper/ast/validation.py index 36a6a0484c..387f7734b9 100644 --- a/vyper/ast/validation.py +++ b/vyper/ast/validation.py @@ -1,11 +1,11 @@ # validation utils for ast -# TODO this really belongs in vyper/semantics/validation/utils from typing import Optional, Union from vyper.ast import nodes as vy_ast from vyper.exceptions import ArgumentException, CompilerPanic, StructureException +# TODO this really belongs in vyper/semantics/validation/utils def validate_call_args( node: vy_ast.Call, arg_count: Union[int, tuple], kwargs: Optional[list] = None ) -> None: @@ -101,14 +101,13 @@ def validate_literal_nodes(vyper_module: vy_ast.Module) -> None: """ Individually validate Vyper AST nodes. - Calls the `validate` method of each node to verify that literal nodes - do not contain invalid values. + Recursively calls the `validate` method of each node to verify that + literal nodes do not contain invalid values. Arguments --------- vyper_module : vy_ast.Module Top level Vyper AST node. """ - for node in vyper_module.get_descendants(): - if hasattr(node, "validate"): - node.validate() + for node in vyper_module.get_descendants(include_self=True): + node.validate() diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index a5949dfd85..aac008ad1e 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -1,12 +1,17 @@ import functools from typing import Any, Optional -from vyper.ast import nodes as vy_ast +from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args from vyper.codegen.expr import Expr from vyper.codegen.ir_node import IRnode -from vyper.exceptions import CompilerPanic, TypeMismatch -from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type +from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode +from vyper.semantics.analysis.base import Modifiability +from vyper.semantics.analysis.utils import ( + check_modifiability, + get_exact_type_from_node, + validate_expected_type, +) from vyper.semantics.types import TYPE_T, KwargSettings, VyperType from vyper.semantics.types.utils import type_from_annotation @@ -29,7 +34,7 @@ def process_arg(arg, expected_arg_type, context): def process_kwarg(kwarg_node, kwarg_settings, expected_kwarg_type, context): if kwarg_settings.require_literal: - return kwarg_node.value + return kwarg_node.get_folded_value().value return process_arg(kwarg_node, expected_kwarg_type, context) @@ -78,6 +83,7 @@ class BuiltinFunctionT(VyperType): _has_varargs = False _inputs: list[tuple[str, Any]] = [] _kwargs: dict[str, KwargSettings] = {} + _modifiability: Modifiability = Modifiability.MODIFIABLE _return_type: Optional[VyperType] = None # helper function to deal with TYPE_DEFINITIONs @@ -106,8 +112,10 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: for kwarg in node.keywords: kwarg_settings = self._kwargs[kwarg.arg] - if kwarg_settings.require_literal and not isinstance(kwarg.value, vy_ast.Constant): - raise TypeMismatch("Value for kwarg must be a literal", kwarg.value) + if kwarg_settings.require_literal and not check_modifiability( + kwarg.value, Modifiability.CONSTANT + ): + raise TypeMismatch("Value must be literal", kwarg.value) self._validate_single(kwarg.value, kwarg_settings.typ) # typecheck varargs. we don't have type info from the signature, @@ -125,7 +133,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: return self._return_type - def infer_arg_types(self, node: vy_ast.Call) -> list[VyperType]: + def infer_arg_types(self, node: vy_ast.Call, expected_return_typ=None) -> list[VyperType]: self._validate_arg_types(node) ret = [expected for (_, expected) in self._inputs] @@ -142,3 +150,6 @@ def infer_kwarg_types(self, node: vy_ast.Call) -> dict[str, VyperType]: def __repr__(self): return f"(builtin) {self._id}" + + def _try_fold(self, node): + raise UnfoldableNode(f"not foldable: {self}", node) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index d50a31767d..c896fc7ef6 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -1,7 +1,6 @@ import hashlib import math import operator -from decimal import Decimal from vyper import ast as vy_ast from vyper.abi_types import ABI_Tuple @@ -44,14 +43,13 @@ CompilerPanic, InvalidLiteral, InvalidType, - OverflowException, StateAccessViolation, StructureException, TypeMismatch, UnfoldableNode, ZeroDivisionException, ) -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.analysis.utils import ( get_common_types, get_exact_type_from_node, @@ -88,7 +86,6 @@ EIP_170_LIMIT, SHA3_PER_WORD, MemoryPositions, - SizeLimits, bytes_to_int, ceil32, fourbytes_to_int, @@ -108,9 +105,7 @@ class FoldedFunctionT(BuiltinFunctionT): # Base class for nodes which should always be folded - # Since foldable builtin functions are not folded before semantics validation, - # this flag is used for `check_kwargable` in semantics validation. - _kwargable = True + _modifiability = Modifiability.CONSTANT class TypenameFoldedFunctionT(FoldedFunctionT): @@ -126,7 +121,7 @@ def fetch_call_return(self, node): type_ = self.infer_arg_types(node)[0].typedef return type_ - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): validate_call_args(node, 1) input_typedef = TYPE_T(type_from_annotation(node.args[0])) return [input_typedef] @@ -138,12 +133,13 @@ class Floor(BuiltinFunctionT): # TODO: maybe use int136? _return_type = INT256_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Decimal): + value = node.args[0].get_folded_value() + if not isinstance(value, vy_ast.Decimal): raise UnfoldableNode - value = math.floor(node.args[0].value) + value = math.floor(value.value) return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -168,12 +164,13 @@ class Ceil(BuiltinFunctionT): # TODO: maybe use int136? _return_type = INT256_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Decimal): + value = node.args[0].get_folded_value() + if not isinstance(value, vy_ast.Decimal): raise UnfoldableNode - value = math.ceil(node.args[0].value) + value = math.ceil(value.value) return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -202,7 +199,7 @@ def fetch_call_return(self, node): return target_typedef.typedef # TODO: push this down into convert.py for more consistency - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): validate_call_args(node, 2) target_type = type_from_annotation(node.args[1]) @@ -337,7 +334,7 @@ def fetch_call_return(self, node): return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `b` b_type = get_possible_types_from_node(node.args[0]).pop() @@ -461,20 +458,19 @@ class Len(BuiltinFunctionT): _inputs = [("b", (StringT.any(), BytesT.any(), DArrayT.any()))] _return_type = UINT256_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - arg = node.args[0] + arg = node.args[0].get_folded_value() if isinstance(arg, (vy_ast.Str, vy_ast.Bytes)): length = len(arg.value) elif isinstance(arg, vy_ast.Hex): - # 2 characters represent 1 byte and we subtract 1 to ignore the leading `0x` - length = len(arg.value) // 2 - 1 + length = len(arg.bytes_value) else: raise UnfoldableNode return vy_ast.Int.from_node(node, value=length) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type typ = get_possible_types_from_node(node.args[0]).pop() @@ -504,7 +500,7 @@ def fetch_call_return(self, node): return_type.set_length(length) return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): if len(node.args) < 2: raise ArgumentException("Invalid argument count: expected at least 2", node) @@ -598,22 +594,22 @@ class Keccak256(BuiltinFunctionT): _inputs = [("value", (BytesT.any(), BYTES32_T, StringT.any()))] _return_type = BYTES32_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - if isinstance(node.args[0], vy_ast.Bytes): - value = node.args[0].value - elif isinstance(node.args[0], vy_ast.Str): - value = node.args[0].value.encode() - elif isinstance(node.args[0], vy_ast.Hex): - length = len(node.args[0].value) // 2 - 1 - value = int(node.args[0].value, 16).to_bytes(length, "big") + value = node.args[0].get_folded_value() + if isinstance(value, vy_ast.Bytes): + value = value.value + elif isinstance(value, vy_ast.Str): + value = value.value.encode() + elif isinstance(value, vy_ast.Hex): + value = value.bytes_value else: raise UnfoldableNode hash_ = f"0x{keccak256(value).hex()}" return vy_ast.Hex.from_node(node, value=hash_) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `value` value_type = get_possible_types_from_node(node.args[0]).pop() @@ -645,22 +641,22 @@ class Sha256(BuiltinFunctionT): _inputs = [("value", (BYTES32_T, BytesT.any(), StringT.any()))] _return_type = BYTES32_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - if isinstance(node.args[0], vy_ast.Bytes): - value = node.args[0].value - elif isinstance(node.args[0], vy_ast.Str): - value = node.args[0].value.encode() - elif isinstance(node.args[0], vy_ast.Hex): - length = len(node.args[0].value) // 2 - 1 - value = int(node.args[0].value, 16).to_bytes(length, "big") + value = node.args[0].get_folded_value() + if isinstance(value, vy_ast.Bytes): + value = value.value + elif isinstance(value, vy_ast.Str): + value = value.value.encode() + elif isinstance(value, vy_ast.Hex): + value = value.bytes_value else: raise UnfoldableNode hash_ = f"0x{hashlib.sha256(value).hexdigest()}" return vy_ast.Hex.from_node(node, value=hash_) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `value` value_type = get_possible_types_from_node(node.args[0]).pop() @@ -714,18 +710,20 @@ def build_IR(self, expr, args, kwargs, context): class MethodID(FoldedFunctionT): _id = "method_id" + _inputs = [("value", StringT.any())] + _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BytesT(4))} - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1, ["output_type"]) - args = node.args - if not isinstance(args[0], vy_ast.Str): - raise InvalidType("method id must be given as a literal string", args[0]) - if " " in args[0].value: - raise InvalidLiteral("Invalid function signature - no spaces allowed.") + value = node.args[0].get_folded_value() + if not isinstance(value, vy_ast.Str): + raise InvalidType("method id must be given as a literal string", node.args[0]) + if " " in value.value: + raise InvalidLiteral("Invalid function signature - no spaces allowed.", node.args[0]) - return_type = self.infer_kwarg_types(node) - value = method_id_int(args[0].value) + return_type = self.infer_kwarg_types(node)["output_type"].typedef + value = method_id_int(value.value) if return_type.compare_type(BYTES4_T): return vy_ast.Hex.from_node(node, value=hex(value)) @@ -735,21 +733,22 @@ def evaluate(self, node): def fetch_call_return(self, node): validate_call_args(node, 1, ["output_type"]) - type_ = self.infer_kwarg_types(node) + type_ = self.infer_kwarg_types(node)["output_type"].typedef return type_ + def infer_arg_types(self, node, expected_return_typ=None): + return [self._inputs[0][1]] + def infer_kwarg_types(self, node): if node.keywords: - return_type = type_from_annotation(node.keywords[0].value) - if return_type.compare_type(BYTES4_T): - return BYTES4_T - elif isinstance(return_type, BytesT) and return_type.length == 4: - return BytesT(4) - else: + output_type = type_from_annotation(node.keywords[0].value) + if output_type not in (BytesT(4), BYTES4_T): raise ArgumentException("output_type must be Bytes[4] or bytes4", node.keywords[0]) + else: + # default to `Bytes[4]` + output_type = BytesT(4) - # If `output_type` is not given, default to `Bytes[4]` - return BytesT(4) + return {"output_type": TYPE_T(output_type)} class ECRecover(BuiltinFunctionT): @@ -762,7 +761,7 @@ class ECRecover(BuiltinFunctionT): ] _return_type = AddressT() - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) v_t, r_t, s_t = [get_possible_types_from_node(arg).pop() for arg in node.args[1:]] return [BYTES32_T, v_t, r_t, s_t] @@ -859,7 +858,7 @@ def fetch_call_return(self, node): return_type = self.infer_kwarg_types(node)["output_type"].typedef return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) input_type = get_possible_types_from_node(node.args[0]).pop() return [input_type, UINT256_T] @@ -974,42 +973,37 @@ class AsWeiValue(BuiltinFunctionT): } def get_denomination(self, node): - if not isinstance(node.args[1], vy_ast.Str): + value = node.args[1].get_folded_value() + if not isinstance(value, vy_ast.Str): raise ArgumentException( "Wei denomination must be given as a literal string", node.args[1] ) try: - denom = next(v for k, v in self.wei_denoms.items() if node.args[1].value in k) + denom = next(v for k, v in self.wei_denoms.items() if value.value in k) except StopIteration: - raise ArgumentException( - f"Unknown denomination: {node.args[1].value}", node.args[1] - ) from None + raise ArgumentException(f"Unknown denomination: {value.value}", node.args[1]) from None return denom - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 2) denom = self.get_denomination(node) - if not isinstance(node.args[0], (vy_ast.Decimal, vy_ast.Int)): + value = node.args[0].get_folded_value() + if not isinstance(value, (vy_ast.Decimal, vy_ast.Int)): raise UnfoldableNode - value = node.args[0].value + value = value.value if value < 0: raise InvalidLiteral("Negative wei value not allowed", node.args[0]) - if isinstance(value, int) and value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", node.args[0]) - if isinstance(value, Decimal) and value > SizeLimits.MAX_AST_DECIMAL: - raise InvalidLiteral("Value out of range for decimal", node.args[0]) - return vy_ast.Int.from_node(node, value=int(value * denom)) def fetch_call_return(self, node): self.infer_arg_types(node) return self._return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type instead of abstract type value_type = get_possible_types_from_node(node.args[0]).pop() @@ -1074,8 +1068,14 @@ def fetch_call_return(self, node): kwargz = {i.arg: i.value for i in node.keywords} outsize = kwargz.get("max_outsize") + if outsize is not None: + outsize = outsize.get_folded_value() + revert_on_failure = kwargz.get("revert_on_failure") - revert_on_failure = revert_on_failure.value if revert_on_failure is not None else True + if revert_on_failure is not None: + revert_on_failure = revert_on_failure.get_folded_value().value + else: + revert_on_failure = True if outsize is None or outsize.value == 0: if revert_on_failure: @@ -1093,7 +1093,7 @@ def fetch_call_return(self, node): return return_type return TupleT([BoolT(), return_type]) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `data` data_type = get_possible_types_from_node(node.args[1]).pop() @@ -1268,7 +1268,7 @@ class RawRevert(BuiltinFunctionT): def fetch_call_return(self, node): return None - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) data_type = get_possible_types_from_node(node.args[0]).pop() return [data_type] @@ -1288,7 +1288,7 @@ class RawLog(BuiltinFunctionT): def fetch_call_return(self, node): self.infer_arg_types(node) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) if not isinstance(node.args[0], vy_ast.List) or len(node.args[0].elements) > 4: @@ -1338,19 +1338,18 @@ class BitwiseAnd(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def _try_fold(self, node): if not self.__class__._warned: vyper_warn("`bitwise_and()` is deprecated! Please use the & operator instead.") self.__class__._warned = True validate_call_args(node, 2) - for arg in node.args: - if not isinstance(arg, vy_ast.Int): + values = [i.get_folded_value() for i in node.args] + for val in values: + if not isinstance(val, vy_ast.Int): raise UnfoldableNode - if arg.value < 0 or arg.value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", arg) - value = node.args[0].value & node.args[1].value + value = values[0].value & values[1].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1364,19 +1363,18 @@ class BitwiseOr(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def _try_fold(self, node): if not self.__class__._warned: vyper_warn("`bitwise_or()` is deprecated! Please use the | operator instead.") self.__class__._warned = True validate_call_args(node, 2) - for arg in node.args: - if not isinstance(arg, vy_ast.Int): + values = [i.get_folded_value() for i in node.args] + for val in values: + if not isinstance(val, vy_ast.Int): raise UnfoldableNode - if arg.value < 0 or arg.value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", arg) - value = node.args[0].value | node.args[1].value + value = values[0].value | values[1].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1390,19 +1388,18 @@ class BitwiseXor(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def _try_fold(self, node): if not self.__class__._warned: vyper_warn("`bitwise_xor()` is deprecated! Please use the ^ operator instead.") self.__class__._warned = True validate_call_args(node, 2) - for arg in node.args: - if not isinstance(arg, vy_ast.Int): + values = [i.get_folded_value() for i in node.args] + for val in values: + if not isinstance(val, vy_ast.Int): raise UnfoldableNode - if arg.value < 0 or arg.value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", arg) - value = node.args[0].value ^ node.args[1].value + value = values[0].value ^ values[1].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1416,18 +1413,17 @@ class BitwiseNot(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def _try_fold(self, node): if not self.__class__._warned: vyper_warn("`bitwise_not()` is deprecated! Please use the ~ operator instead.") self.__class__._warned = True validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Int): + value = node.args[0].get_folded_value() + if not isinstance(value, vy_ast.Int): raise UnfoldableNode - value = node.args[0].value - if value < 0 or value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", node.args[0]) + value = value.value value = (2**256 - 1) - value return vy_ast.Int.from_node(node, value=value) @@ -1443,17 +1439,16 @@ class Shift(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def _try_fold(self, node): if not self.__class__._warned: vyper_warn("`shift()` is deprecated! Please use the << or >> operator instead.") self.__class__._warned = True validate_call_args(node, 2) - if [i for i in node.args if not isinstance(i, vy_ast.Int)]: + args = [i.get_folded_value() for i in node.args] + if any(not isinstance(i, vy_ast.Int) for i in args): raise UnfoldableNode - value, shift = [i.value for i in node.args] - if value < 0 or value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", node.args[0]) + value, shift = [i.value for i in args] if shift < -256 or shift > 256: # this validation is performed to prevent the compiler from hanging # rather than for correctness because the post-folded constant would @@ -1470,7 +1465,7 @@ def fetch_call_return(self, node): # return type is the type of the first argument return self.infer_arg_types(node)[0] - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type instead of SignedIntegerAbstractType arg_ty = get_possible_types_from_node(node.args[0])[0] @@ -1495,17 +1490,16 @@ class _AddMulMod(BuiltinFunctionT): _inputs = [("a", UINT256_T), ("b", UINT256_T), ("c", UINT256_T)] _return_type = UINT256_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 3) - if isinstance(node.args[2], vy_ast.Int) and node.args[2].value == 0: + args = [i.get_folded_value() for i in node.args] + if isinstance(args[2], vy_ast.Int) and args[2].value == 0: raise ZeroDivisionException("Modulo by 0", node.args[2]) - for arg in node.args: + for arg in args: if not isinstance(arg, vy_ast.Int): raise UnfoldableNode - if arg.value < 0 or arg.value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", arg) - value = self._eval_fn(node.args[0].value, node.args[1].value) % node.args[2].value + value = self._eval_fn(args[0].value, args[1].value) % args[2].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1537,15 +1531,13 @@ class PowMod256(BuiltinFunctionT): _inputs = [("a", UINT256_T), ("b", UINT256_T)] _return_type = UINT256_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 2) - if next((i for i in node.args if not isinstance(i, vy_ast.Int)), None): - raise UnfoldableNode - - left, right = node.args - if left.value < 0 or right.value < 0: + values = [i.get_folded_value() for i in node.args] + if any(not isinstance(i, vy_ast.Int) for i in values): raise UnfoldableNode + left, right = values value = pow(left.value, right.value, 2**256) return vy_ast.Int.from_node(node, value=value) @@ -1560,18 +1552,13 @@ class Abs(BuiltinFunctionT): _inputs = [("value", INT256_T)] _return_type = INT256_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Int): + value = node.args[0].get_folded_value() + if not isinstance(value, vy_ast.Int): raise UnfoldableNode - value = node.args[0].value - if not SizeLimits.MIN_INT256 <= value <= SizeLimits.MAX_INT256: - raise OverflowException("Literal is outside of allowable range for int256") - value = abs(value) - if not SizeLimits.MIN_INT256 <= value <= SizeLimits.MAX_INT256: - raise OverflowException("Absolute literal value is outside allowable range for int256") - + value = abs(value.value) return vy_ast.Int.from_node(node, value=value) def build_IR(self, expr, context): @@ -1946,7 +1933,7 @@ def fetch_call_return(self, node): return_type = self.infer_arg_types(node).pop() return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) types_list = get_common_types(*node.args, filter_fn=lambda x: isinstance(x, IntegerT)) @@ -2004,34 +1991,26 @@ class UnsafeDiv(_UnsafeMath): class _MinMax(BuiltinFunctionT): _inputs = [("a", (DecimalT(), IntegerT.any())), ("b", (DecimalT(), IntegerT.any()))] - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 2) - if not isinstance(node.args[0], type(node.args[1])): + + left = node.args[0].get_folded_value() + right = node.args[1].get_folded_value() + if not isinstance(left, type(right)): raise UnfoldableNode - if not isinstance(node.args[0], (vy_ast.Decimal, vy_ast.Int)): + if not isinstance(left, (vy_ast.Decimal, vy_ast.Int)): raise UnfoldableNode - left, right = (i.value for i in node.args) - if isinstance(left, Decimal) and ( - min(left, right) < SizeLimits.MIN_AST_DECIMAL - or max(left, right) > SizeLimits.MAX_AST_DECIMAL - ): - raise InvalidType("Decimal value is outside of allowable range", node) - types_list = get_common_types( - *node.args, filter_fn=lambda x: isinstance(x, (IntegerT, DecimalT)) + *(left, right), filter_fn=lambda x: isinstance(x, (IntegerT, DecimalT)) ) if not types_list: raise TypeMismatch("Cannot perform action between dislike numeric types", node) - value = self._eval_fn(left, right) - return type(node.args[0]).from_node(node, value=value) + value = self._eval_fn(left.value, right.value) + return type(left).from_node(node, value=value) def fetch_call_return(self, node): - return_type = self.infer_arg_types(node).pop() - return return_type - - def infer_arg_types(self, node): self._validate_arg_types(node) types_list = get_common_types( @@ -2040,8 +2019,13 @@ def infer_arg_types(self, node): if not types_list: raise TypeMismatch("Cannot perform action between dislike numeric types", node) - type_ = types_list.pop() - return [type_, type_] + return types_list + + def infer_arg_types(self, node, expected_return_typ=None): + types_list = self.fetch_call_return(node) + # type mismatch should have been caught in `fetch_call_return` + assert expected_return_typ in types_list + return [expected_return_typ, expected_return_typ] @process_inputs def build_IR(self, expr, args, kwargs, context): @@ -2085,18 +2069,19 @@ def fetch_call_return(self, node): len_needed = math.ceil(bits * math.log(2) / math.log(10)) return StringT(len_needed) - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Int): + value = node.args[0].get_folded_value() + if not isinstance(value, vy_ast.Int): raise UnfoldableNode - value = node.args[0].value + value = value.value if value < 0: raise InvalidType("Only unsigned ints allowed", node) value = str(value) return vy_ast.Str.from_node(node, value=value) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) input_type = get_possible_types_from_node(node.args[0]).pop() return [input_type] @@ -2493,7 +2478,7 @@ def fetch_call_return(self, node): _, output_type = self.infer_arg_types(node) return output_type.typedef - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) validate_call_args(node, 2, ["unwrap_tuple"]) @@ -2572,7 +2557,7 @@ def build_IR(self, expr, args, kwargs, context): class _MinMaxValue(TypenameFoldedFunctionT): - def evaluate(self, node): + def _try_fold(self, node): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) @@ -2590,6 +2575,10 @@ def evaluate(self, node): ret._metadata["type"] = input_type return ret + def infer_arg_types(self, node, expected_return_typ=None): + input_typedef = TYPE_T(type_from_annotation(node.args[0])) + return [input_typedef] + class MinValue(_MinMaxValue): _id = "min_value" @@ -2608,7 +2597,7 @@ def _eval(self, type_): class Epsilon(TypenameFoldedFunctionT): _id = "epsilon" - def evaluate(self, node): + def _try_fold(self, node): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 3063a289ab..d6ba9e180a 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -33,7 +33,8 @@ devdoc - Natspec developer documentation combined_json - All of the above format options combined as single JSON output layout - Storage layout of a Vyper contract -ast - AST in JSON format +ast - AST (not yet annotated) in JSON format +annotated_ast - Annotated AST in JSON format interface - Vyper interface of a contract external_interface - External interface of a contract, used for outside contract calls opcodes - List of opcodes as a string @@ -255,7 +256,13 @@ def compile_files( output_formats = combined_json_outputs show_version = True - translate_map = {"abi_python": "abi", "json": "abi", "ast": "ast_dict", "ir_json": "ir_dict"} + translate_map = { + "abi_python": "abi", + "json": "abi", + "ast": "ast_dict", + "annotated_ast": "annotated_ast_dict", + "ir_json": "ir_dict", + } final_formats = [translate_map.get(i, i) for i in output_formats] if storage_layout_paths: diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 4c7c3afaed..577660b883 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -69,9 +69,6 @@ class Expr: # TODO: Once other refactors are made reevaluate all inline imports def __init__(self, node, context): - self.expr = node - self.context = context - if isinstance(node, IRnode): # this is a kludge for parse_AugAssign to pass in IRnodes # directly. @@ -79,6 +76,13 @@ def __init__(self, node, context): self.ir_node = node return + assert isinstance(node, vy_ast.VyperNode) + if node.has_folded_value: + node = node.get_folded_value() + + self.expr = node + self.context = context + fn_name = f"parse_{type(node).__name__}" with tag_exceptions(node, fallback_exception_type=CodegenPanic, note=fn_name): fn = getattr(self, fn_name) @@ -184,6 +188,13 @@ def parse_Name(self): # TODO: use self.expr._expr_info elif self.expr.id in self.context.globals: varinfo = self.context.globals[self.expr.id] + + if varinfo.is_constant: + # non-struct constants should have already gotten propagated + # during constant folding + assert isinstance(varinfo.typ, StructT) + return Expr.parse_value_expr(varinfo.decl_node.value, self.context) + assert varinfo.is_immutable, "not an immutable!" ofst = varinfo.position.offset diff --git a/vyper/compiler/README.md b/vyper/compiler/README.md index eb70750a2b..abb8c6ee91 100644 --- a/vyper/compiler/README.md +++ b/vyper/compiler/README.md @@ -25,8 +25,6 @@ The compilation process includes the following broad phases: 1. In [`vyper.ast`](../ast), the source code is parsed and converted to an abstract syntax tree. -1. In [`vyper.ast.folding`](../ast/folding.py), literal Vyper AST nodes are -evaluated and replaced with the resulting values. 1. The [`GlobalContext`](../codegen/global_context.py) object is generated from the Vyper AST, analyzing and organizing the nodes prior to IR generation. 1. In [`vyper.codegen.module`](../codegen/module.py), the contextualized nodes are diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index c87814ba15..0f7d7a8014 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -14,6 +14,8 @@ OUTPUT_FORMATS = { # requires vyper_module "ast_dict": output.build_ast_dict, + # requires annotated_vyper_module + "annotated_ast_dict": output.build_annotated_ast_dict, "layout": output.build_layout_output, # requires global_ctx "devdoc": output.build_devdoc, diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index dc2a43720e..8ccf6abee1 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -23,18 +23,26 @@ def build_ast_dict(compiler_data: CompilerData) -> dict: return ast_dict +def build_annotated_ast_dict(compiler_data: CompilerData) -> dict: + annotated_ast_dict = { + "contract_name": str(compiler_data.contract_path), + "ast": ast_to_dict(compiler_data.annotated_vyper_module), + } + return annotated_ast_dict + + def build_devdoc(compiler_data: CompilerData) -> dict: - userdoc, devdoc = parse_natspec(compiler_data.vyper_module_folded) + userdoc, devdoc = parse_natspec(compiler_data.annotated_vyper_module) return devdoc def build_userdoc(compiler_data: CompilerData) -> dict: - userdoc, devdoc = parse_natspec(compiler_data.vyper_module_folded) + userdoc, devdoc = parse_natspec(compiler_data.annotated_vyper_module) return userdoc def build_external_interface_output(compiler_data: CompilerData) -> str: - interface = compiler_data.vyper_module_folded._metadata["type"].interface + interface = compiler_data.annotated_vyper_module._metadata["type"].interface stem = PurePath(compiler_data.contract_path).stem # capitalize words separated by '_' # ex: test_interface.vy -> TestInterface @@ -53,7 +61,7 @@ def build_external_interface_output(compiler_data: CompilerData) -> str: def build_interface_output(compiler_data: CompilerData) -> str: - interface = compiler_data.vyper_module_folded._metadata["type"].interface + interface = compiler_data.annotated_vyper_module._metadata["type"].interface out = "" if interface.events: @@ -158,7 +166,7 @@ def _to_dict(func_t): def build_method_identifiers_output(compiler_data: CompilerData) -> dict: - module_t = compiler_data.vyper_module_folded._metadata["type"] + module_t = compiler_data.annotated_vyper_module._metadata["type"] functions = module_t.function_defs return { @@ -167,7 +175,7 @@ def build_method_identifiers_output(compiler_data: CompilerData) -> dict: def build_abi_output(compiler_data: CompilerData) -> list: - module_t = compiler_data.vyper_module_folded._metadata["type"] + module_t = compiler_data.annotated_vyper_module._metadata["type"] _ = compiler_data.ir_runtime # ensure _ir_info is generated abi = module_t.interface.to_toplevel_abi_dict() diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index b9b2df6ae8..8cbcfb1da9 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -2,7 +2,7 @@ import warnings from functools import cached_property from pathlib import Path, PurePath -from typing import Optional, Tuple +from typing import Optional from vyper import ast as vy_ast from vyper.codegen import module @@ -53,8 +53,8 @@ class CompilerData: ---------- vyper_module : vy_ast.Module Top-level Vyper AST node - vyper_module_folded : vy_ast.Module - Folded Vyper AST + annotated_vyper_module: vy_ast.Module + Annotated+analysed Vyper AST global_ctx : ModuleT Sorted, contextualized representation of the Vyper AST ir_nodes : IRnode @@ -152,31 +152,24 @@ def vyper_module(self): return self._generate_ast @cached_property - def vyper_module_unfolded(self) -> vy_ast.Module: - # This phase is intended to generate an AST for tooling use, and is not - # used in the compilation process. - - return generate_unfolded_ast(self.vyper_module, self.input_bundle) - - @cached_property - def _folded_module(self): - return generate_folded_ast( + def _annotated_module(self): + return generate_annotated_ast( self.vyper_module, self.input_bundle, self.storage_layout_override ) @property - def vyper_module_folded(self) -> vy_ast.Module: - module, storage_layout = self._folded_module + def annotated_vyper_module(self) -> vy_ast.Module: + module, storage_layout = self._annotated_module return module @property def storage_layout(self) -> StorageLayout: - module, storage_layout = self._folded_module + module, storage_layout = self._annotated_module return storage_layout @property def global_ctx(self) -> ModuleT: - return self.vyper_module_folded._metadata["type"] + return self.annotated_vyper_module._metadata["type"] @cached_property def _ir_output(self): @@ -205,7 +198,7 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: # ensure codegen is run: _ = self._ir_output - fs = self.vyper_module_folded.get_children(vy_ast.FunctionDef) + fs = self.annotated_vyper_module.get_children(vy_ast.FunctionDef) return {f.name: f._metadata["func_type"] for f in fs} @cached_property @@ -247,25 +240,13 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -# destructive -- mutates module in place! -def generate_unfolded_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: - vy_ast.validation.validate_literal_nodes(vyper_module) - vy_ast.folding.replace_builtin_functions(vyper_module) - - with input_bundle.search_path(Path(vyper_module.resolved_path).parent): - # note: validate_semantics does type inference on the AST - validate_semantics(vyper_module, input_bundle) - - return vyper_module - - -def generate_folded_ast( +def generate_annotated_ast( vyper_module: vy_ast.Module, input_bundle: InputBundle, storage_layout_overrides: StorageLayout = None, -) -> Tuple[vy_ast.Module, StorageLayout]: +) -> tuple[vy_ast.Module, StorageLayout]: """ - Perform constant folding operations on the Vyper AST. + Validates and annotates the Vyper AST. Arguments --------- @@ -275,22 +256,18 @@ def generate_folded_ast( Returns ------- vy_ast.Module - Folded Vyper AST + Annotated Vyper AST StorageLayout Layout of variables in storage """ - - vy_ast.validation.validate_literal_nodes(vyper_module) - - vyper_module_folded = copy.deepcopy(vyper_module) - vy_ast.folding.fold(vyper_module_folded) - + vyper_module = copy.deepcopy(vyper_module) with input_bundle.search_path(Path(vyper_module.resolved_path).parent): - validate_semantics(vyper_module_folded, input_bundle) + # note: validate_semantics does type inference on the AST + validate_semantics(vyper_module, input_bundle) - symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides) + symbol_tables = set_data_positions(vyper_module, storage_layout_overrides) - return vyper_module_folded, symbol_tables + return vyper_module, symbol_tables def generate_ir_nodes( diff --git a/vyper/semantics/README.md b/vyper/semantics/README.md index 1d81a0979b..36519bba29 100644 --- a/vyper/semantics/README.md +++ b/vyper/semantics/README.md @@ -25,6 +25,7 @@ Vyper abstract syntax tree (AST). * [`data_positions`](analysis/data_positions.py): Functions for tracking storage variables and allocating storage slots * [`levenhtein_utils.py`](analysis/levenshtein_utils.py): Helper for better error messages * [`local.py`](analysis/local.py): Validates the local namespace of each function within a contract + * [`pre_typecheck.py`](analysis/pre_typecheck.py): Evaluate foldable nodes and populate their metadata with the replacement nodes. * [`module.py`](analysis/module.py): Validates the module namespace of a contract. * [`utils.py`](analysis/utils.py): Functions for comparing and validating types * [`data_locations.py`](data_locations.py): `DataLocation` object for type location information @@ -35,13 +36,23 @@ Vyper abstract syntax tree (AST). The [`analysis`](analysis) subpackage contains the top-level `validate_semantics` function. This function is used to verify and type-check a contract. The process -consists of three steps: +consists of four steps: -1. Preparing the builtin namespace -2. Validating the module-level scope -3. Annotating and validating local scopes +1. Populating the metadata of foldable nodes with their replacement nodes +2. Preparing the builtin namespace +3. Validating the module-level scope +4. Annotating and validating local scopes -### 1. Preparing the builtin namespace +### 1. Populating the metadata of foldable nodes with their replacement nodes + +[`analysis/pre_typecheck.py`](analysis/pre_typecheck.py) populates the metadata of foldable nodes with their replacement nodes. + +This process includes: +1. Foldable node classes and builtin functions are evaluated via their `fold` method, which attempts to create a new `Constant` from the content of the given node. +2. Replacement nodes are generated using the `from_node` class method within the new +node class. + +### 2. Preparing the builtin namespace The [`Namespace`](namespace.py) object represents the namespace for a contract. Builtins are added upon initialization of the object. This includes: @@ -51,9 +62,9 @@ Builtins are added upon initialization of the object. This includes: * Adding builtin functions from the [`functions`](../builtins/functions.py) package * Adding / resetting `self` and `log` -### 2. Validating the Module Scope +### 3. Validating the Module Scope -[`validation/module.py`](validation/module.py) validates the module-level scope +[`analysis/module.py`](analysis/module.py) validates the module-level scope of a contract. This includes: * Generating user-defined types (e.g. structs and interfaces) @@ -61,9 +72,9 @@ of a contract. This includes: and functions * Validating import statements and function signatures -### 3. Annotating and validating the Local Scopes +### 4. Annotating and validating the Local Scopes -[`validation/local.py`](validation/local.py) validates the local scope within each +[`analysis/local.py`](analysis/local.py) validates the local scope within each function in a contract. `FunctionNodeVisitor` is used to iterate over the statement nodes in each function body, annotate them and apply appropriate checks. diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 4d1b1cdbab..bb6d9ad9f7 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -97,6 +97,27 @@ def from_abi(cls, abi_dict: Dict) -> "StateMutability": # specifying a state mutability modifier at all. Do the same here. +# classify the constancy of an expression +# CMC 2023-12-31 note that we now have three ways of classifying mutability in +# the codebase: StateMutability (for functions), Modifiability (for expressions +# and variables) and Constancy (in codegen). context.Constancy can/should +# probably be refactored away though as those kinds of checks should be done +# during analysis. +class Modifiability(enum.IntEnum): + # is writeable/can result in arbitrary state or memory changes + MODIFIABLE = enum.auto() + + # could potentially add more fine-grained here as needed, like + # CONSTANT_AFTER_DEPLOY, TX_CONSTANT, BLOCK_CONSTANT, etc. + + # things that are constant within the current message call, including + # block.*, msg.*, tx.* and immutables + RUNTIME_CONSTANT = enum.auto() + + # compile-time / always constant + CONSTANT = enum.auto() + + class DataPosition: _location: DataLocation @@ -182,21 +203,18 @@ class ImportInfo(AnalysisResult): class VarInfo: """ VarInfo are objects that represent the type of a variable, - plus associated metadata like location and constancy attributes + plus associated metadata like location and modifiability attributes Object Attributes ----------------- - is_constant : bool, optional - If `True`, this is a variable defined with the `constant()` modifier + location: DataLocation of this variable + modifiability: Modifiability of this variable """ typ: VyperType location: DataLocation = DataLocation.UNSET - is_constant: bool = False + modifiability: Modifiability = Modifiability.MODIFIABLE is_public: bool = False - is_immutable: bool = False - is_transient: bool = False - is_local_var: bool = False decl_node: Optional[vy_ast.VyperNode] = None def __hash__(self): @@ -211,10 +229,28 @@ def set_position(self, position: DataPosition) -> None: if self.location != position._location: if self.location == DataLocation.UNSET: self.location = position._location + elif self.is_transient and position._location == DataLocation.STORAGE: + # CMC 2023-12-31 - use same allocator for storage and transient + # for now, this should be refactored soon. + pass else: raise CompilerPanic("Incompatible locations") self.position = position + @property + def is_transient(self): + return self.location == DataLocation.TRANSIENT + + @property + def is_immutable(self): + return self.location == DataLocation.CODE + + @property + def is_constant(self): + res = self.location == DataLocation.UNSET + assert res == (self.modifiability == Modifiability.CONSTANT) + return res + @dataclass class ExprInfo: @@ -225,11 +261,10 @@ class ExprInfo: typ: VyperType var_info: Optional[VarInfo] = None location: DataLocation = DataLocation.UNSET - is_constant: bool = False - is_immutable: bool = False + modifiability: Modifiability = Modifiability.MODIFIABLE def __post_init__(self): - should_match = ("typ", "location", "is_constant", "is_immutable") + should_match = ("typ", "location", "modifiability") if self.var_info is not None: for attr in should_match: if getattr(self.var_info, attr) != getattr(self, attr): @@ -241,8 +276,7 @@ def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": var_info.typ, var_info=var_info, location=var_info.location, - is_constant=var_info.is_constant, - is_immutable=var_info.is_immutable, + modifiability=var_info.modifiability, ) @classmethod @@ -253,7 +287,7 @@ def copy_with_type(self, typ: VyperType) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ - to_copy = ("location", "is_constant", "is_immutable") + to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} return self.__class__(typ=typ, **fields) @@ -277,17 +311,24 @@ def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutabil if self.location == DataLocation.CALLDATA: raise ImmutableViolation("Cannot write to calldata", node) - if self.is_constant: + + if self.modifiability == Modifiability.RUNTIME_CONSTANT: + if self.location == DataLocation.CODE: + if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": + raise ImmutableViolation("Immutable value cannot be written to", node) + + # special handling for immutable variables in the ctor + # TODO: we probably want to remove this restriction. + if self.var_info._modification_count: # type: ignore + raise ImmutableViolation( + "Immutable value cannot be modified after assignment", node + ) + self.var_info._modification_count += 1 # type: ignore + else: + raise ImmutableViolation("Environment variable cannot be written to", node) + + if self.modifiability == Modifiability.CONSTANT: raise ImmutableViolation("Constant value cannot be written to", node) - if self.is_immutable: - if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": - raise ImmutableViolation("Immutable value cannot be written to", node) - # TODO: we probably want to remove this restriction. - if self.var_info._modification_count: # type: ignore - raise ImmutableViolation( - "Immutable value cannot be modified after assignment", node - ) - self.var_info._modification_count += 1 # type: ignore if isinstance(node, vy_ast.AugAssign): self.typ.validate_numeric_op(node) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index a3ebf85fa2..91fb2c21f0 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -18,7 +18,7 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -186,16 +186,18 @@ def __init__( self.fn_node = fn_node self.namespace = namespace self.func = fn_node._metadata["func_type"] - self.expr_visitor = _ExprVisitor(self.func) + self.expr_visitor = ExprVisitor(self.func) def analyze(self): # allow internal function params to be mutable - location, is_immutable = ( - (DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True) - ) + if self.func.is_internal: + location, modifiability = (DataLocation.MEMORY, Modifiability.MODIFIABLE) + else: + location, modifiability = (DataLocation.CALLDATA, Modifiability.RUNTIME_CONSTANT) + for arg in self.func.arguments: self.namespace[arg.name] = VarInfo( - arg.typ, location=location, is_immutable=is_immutable + arg.typ, location=location, modifiability=modifiability ) for node in self.fn_node.body: @@ -358,7 +360,8 @@ def visit_For(self, node): else: # iteration over a variable or literal list - if isinstance(node.iter, vy_ast.List) and len(node.iter.elements) == 0: + iter_val = node.iter.get_folded_value() if node.iter.has_folded_value else node.iter + if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) type_list = [ @@ -421,32 +424,35 @@ def visit_For(self, node): # type check the for loop body using each possible type for iterator value with self.namespace.enter_scope(): - self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) + self.namespace[iter_name] = VarInfo( + possible_target_type, modifiability=Modifiability.RUNTIME_CONSTANT + ) try: with NodeMetadata.enter_typechecker_speculation(): for stmt in node.body: self.visit(stmt) + + self.expr_visitor.visit(node.target, possible_target_type) + + if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): + iter_type = get_exact_type_from_node(node.iter) + # note CMC 2023-10-23: slightly redundant with how type_list is computed + validate_expected_type(node.target, iter_type.value_type) + self.expr_visitor.visit(node.iter, iter_type) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) + if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + for a in node.iter.args: + self.expr_visitor.visit(a, possible_target_type) + for a in node.iter.keywords: + if a.arg == "bound": + self.expr_visitor.visit(a.value, possible_target_type) + except (TypeMismatch, InvalidOperation) as exc: for_loop_exceptions.append(exc) else: - self.expr_visitor.visit(node.target, possible_target_type) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - iter_type = get_exact_type_from_node(node.iter) - # note CMC 2023-10-23: slightly redundant with how type_list is computed - validate_expected_type(node.target, iter_type.value_type) - self.expr_visitor.visit(node.iter, iter_type) - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, possible_target_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, possible_target_type) - # success -- do not enter error handling section return @@ -523,10 +529,10 @@ def visit_Return(self, node): self.expr_visitor.visit(node.value, self.func.return_type) -class _ExprVisitor(VyperNodeVisitorBase): +class ExprVisitor(VyperNodeVisitorBase): scope_name = "function" - def __init__(self, fn_node: ContractFunctionT): + def __init__(self, fn_node: Optional[ContractFunctionT] = None): self.func = fn_node def visit(self, node, typ): @@ -543,6 +549,12 @@ def visit(self, node, typ): # annotate node._metadata["type"] = typ + # validate and annotate folded value + if node.has_folded_value: + folded_node = node.get_folded_value() + validate_expected_type(folded_node, typ) + folded_node._metadata["type"] = typ + def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) @@ -551,10 +563,10 @@ def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: # if self.func.mutability < expr_info.mutability: # raise ... - if self.func.mutability != StateMutability.PAYABLE: + if self.func and self.func.mutability != StateMutability.PAYABLE: _validate_msg_value_access(node) - if self.func.mutability == StateMutability.PURE: + if self.func and self.func.mutability == StateMutability.PURE: _validate_pure_access(node, typ) value_type = get_exact_type_from_node(node.value) @@ -589,7 +601,7 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: if isinstance(call_type, ContractFunctionT): # function calls - if call_type.is_internal: + if self.func and call_type.is_internal: self.func.called_functions.add(call_type) for arg, typ in zip(node.args, call_type.argument_types): self.visit(arg, typ) @@ -615,7 +627,7 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: self.visit(arg, arg_type) else: # builtin functions - arg_types = call_type.infer_arg_types(node) + arg_types = call_type.infer_arg_types(node, expected_return_typ=typ) # `infer_arg_types` already calls `validate_expected_type` for arg, arg_type in zip(node.args, arg_types): self.visit(arg, arg_type) @@ -680,7 +692,7 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: self.visit(element, typ.value_type) def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: - if self.func.mutability == StateMutability.PURE: + if self.func and self.func.mutability == StateMutability.PURE: _validate_self_reference(node) if not isinstance(typ, TYPE_T): @@ -691,7 +703,7 @@ def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: # don't recurse; can't annotate AST children of type definition return - if isinstance(node.value, vy_ast.List): + if isinstance(node.value, (vy_ast.List, vy_ast.Subscript)): possible_base_types = get_possible_types_from_node(node.value) for possible_type in possible_base_types: @@ -747,6 +759,7 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args + start, end = [i.get_folded_value() if i.has_folded_value else i for i in (start, end)] all_args = (start, end, *kwargs.values()) for arg1 in all_args: @@ -758,6 +771,8 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: if "bound" in kwargs: bound = kwargs["bound"] + if bound.has_folded_value: + bound = bound.get_folded_value() if not isinstance(bound, vy_ast.Num): raise StateAccessViolation("Bound must be a literal", bound) if bound.value <= 0: diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index fb536b7ab7..8e435f870f 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -4,6 +4,7 @@ import vyper.builtins.interfaces from vyper import ast as vy_ast +from vyper.ast.validation import validate_literal_nodes from vyper.compiler.input_bundle import ABIInput, FileInput, FilesystemInputBundle, InputBundle from vyper.evm.opcodes import version_check from vyper.exceptions import ( @@ -20,12 +21,13 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import ImportInfo, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.import_graph import ImportGraph -from vyper.semantics.analysis.local import validate_functions +from vyper.semantics.analysis.local import ExprVisitor, validate_functions +from vyper.semantics.analysis.pre_typecheck import pre_typecheck from vyper.semantics.analysis.utils import ( - check_constant, + check_modifiability, get_exact_type_from_node, validate_expected_type, ) @@ -51,6 +53,10 @@ def validate_semantics_r( Analyze a Vyper module AST node, add all module-level objects to the namespace, type-check/validate semantics and annotate with type and analysis info """ + validate_literal_nodes(module_ast) + + pre_typecheck(module_ast) + # validate semantics and annotate AST with type/semantics information namespace = get_namespace() @@ -254,12 +260,19 @@ def visit_VariableDecl(self, node): if node.is_immutable else DataLocation.UNSET if node.is_constant - # XXX: needed if we want separate transient allocator - # else DataLocation.TRANSIENT - # if node.is_transient + else DataLocation.TRANSIENT + if node.is_transient else DataLocation.STORAGE ) + modifiability = ( + Modifiability.RUNTIME_CONSTANT + if node.is_immutable + else Modifiability.CONSTANT + if node.is_constant + else Modifiability.MODIFIABLE + ) + type_ = type_from_annotation(node.annotation, data_loc) if node.is_transient and not version_check(begin="cancun"): @@ -269,10 +282,8 @@ def visit_VariableDecl(self, node): type_, decl_node=node, location=data_loc, - is_constant=node.is_constant, + modifiability=modifiability, is_public=node.is_public, - is_immutable=node.is_immutable, - is_transient=node.is_transient, ) node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace node._metadata["type"] = type_ @@ -302,9 +313,11 @@ def _validate_self_namespace(): self.namespace[name] = var_info if node.is_constant: - if not node.value: - raise VariableDeclarationException("Constant must be declared with a value", node) - if not check_constant(node.value): + assert node.value is not None # checked in VariableDecl.validate() + + ExprVisitor().visit(node.value, type_) + + if not check_modifiability(node.value, Modifiability.CONSTANT): raise StateAccessViolation("Value must be a literal", node.value) validate_expected_type(node.value, type_) @@ -312,11 +325,7 @@ def _validate_self_namespace(): return _finalize() - if node.value: - var_type = "Immutable" if node.is_immutable else "Storage" - raise VariableDeclarationException( - f"{var_type} variables cannot have an initial value", node.value - ) + assert node.value is None # checked in VariableDecl.validate() if node.is_immutable: _validate_self_namespace() @@ -482,9 +491,6 @@ def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: module_path=str(file.path), resolved_path=str(file.resolved_path), ) - vy_ast.validation.validate_literal_nodes(ret) - vy_ast.folding.fold(ret) - return ret diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py new file mode 100644 index 0000000000..a1302ce9c9 --- /dev/null +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -0,0 +1,94 @@ +from vyper import ast as vy_ast +from vyper.exceptions import UnfoldableNode + + +# try to fold a node, swallowing exceptions. this function is very similar to +# `VyperNode.get_folded_value()` but additionally checks in the constants +# table if the node is a `Name` node. +# +# CMC 2023-12-30 a potential refactor would be to move this function into +# `Name._try_fold` (which would require modifying the signature of _try_fold to +# take an optional constants table as parameter). this would remove the +# need to use this function in conjunction with `get_descendants` since +# `VyperNode._try_fold()` already recurses. it would also remove the need +# for `VyperNode._set_folded_value()`. +def _fold_with_constants(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]): + if node.has_folded_value: + return + + if isinstance(node, vy_ast.Name): + # check if it's in constants table + var_name = node.id + + if var_name not in constants: + return + + res = constants[var_name] + node._set_folded_value(res) + return + + try: + # call get_folded_value for its side effects + node.get_folded_value() + except UnfoldableNode: + pass + + +def _get_constants(node: vy_ast.Module) -> dict: + constants: dict[str, vy_ast.VyperNode] = {} + const_var_decls = node.get_children(vy_ast.VariableDecl, {"is_constant": True}) + + while True: + n_processed = 0 + + for c in const_var_decls.copy(): + assert c.value is not None # guaranteed by VariableDecl.validate() + + for n in c.get_descendants(reverse=True): + _fold_with_constants(n, constants) + + try: + val = c.value.get_folded_value() + except UnfoldableNode: + # not foldable, maybe it depends on other constants + # so try again later + continue + + # note that if a constant is redefined, its value will be + # overwritten, but it is okay because the error is handled + # downstream + name = c.target.id + constants[name] = val + + n_processed += 1 + const_var_decls.remove(c) + + if n_processed == 0: + # this condition means that there are some constant vardecls + # whose values are not foldable. this can happen for struct + # and interface constants for instance. these are valid constant + # declarations, but we just can't fold them at this stage. + break + + return constants + + +# perform constant folding on a module AST +def pre_typecheck(node: vy_ast.Module) -> None: + """ + Perform pre-typechecking steps on a Module AST node. + At this point, this is limited to performing constant folding. + """ + constants = _get_constants(node) + + # note: use reverse to get descendants in leaf-first order + for n in node.get_descendants(reverse=True): + # try folding every single node. note this should be done before + # type checking because the typechecker requires literals or + # foldable nodes in type signatures and some other places (e.g. + # certain builtin kwargs). + # + # note we could limit to only folding nodes which are required + # during type checking, but it's easier to just fold everything + # and be done with it! + _fold_with_constants(n, constants) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 20ebb0f093..ba1b02b8d6 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -17,7 +17,7 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -98,12 +98,9 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: # kludge! for validate_modification in local analysis of Assign types = [self.get_expr_info(n) for n in node.elements] location = sorted((i.location for i in types), key=lambda k: k.value)[-1] - is_constant = any((getattr(i, "is_constant", False) for i in types)) - is_immutable = any((getattr(i, "is_immutable", False) for i in types)) + modifiability = sorted((i.modifiability for i in types), key=lambda k: k.value)[-1] - return ExprInfo( - t, location=location, is_constant=is_constant, is_immutable=is_immutable - ) + return ExprInfo(t, location=location, modifiability=modifiability) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): @@ -137,8 +134,7 @@ def get_exact_type_from_node(self, node, include_type_exprs=False): def get_possible_types_from_node(self, node, include_type_exprs=False): """ Find all possible types for a given node. - If the node's metadata contains type information propagated from constant folding, - then that type is returned. + If the node's metadata contains type information, then that type is returned. Arguments --------- @@ -203,10 +199,12 @@ def _raise_invalid_reference(name, node): if isinstance(s, (VyperType, TYPE_T)): # ex. foo.bar(). bar() is a ContractFunctionT return [s] + + # general case. s is a VarInfo, e.g. self.foo if is_self_reference and (s.is_constant or s.is_immutable): _raise_invalid_reference(name, node) - # general case. s is a VarInfo, e.g. self.foo return [s.typ] + except UnknownAttribute as e: if not is_self_reference: raise e from None @@ -282,6 +280,8 @@ def types_from_Call(self, node): var = self.get_exact_type_from_node(node.func, include_type_exprs=True) return_value = var.fetch_call_return(node) if return_value: + if isinstance(return_value, list): + return return_value return [return_value] raise InvalidType(f"{var} did not return a value", node) @@ -378,7 +378,7 @@ def types_from_Name(self, node): def types_from_Subscript(self, node): # index access, e.g. `foo[1]` - if isinstance(node.value, vy_ast.List): + if isinstance(node.value, (vy_ast.List, vy_ast.Subscript)): types_list = self.get_possible_types_from_node(node.value) ret = [] for t in types_list: @@ -625,54 +625,33 @@ def validate_unique_method_ids(functions: List) -> None: seen.add(method_id) -def check_kwargable(node: vy_ast.VyperNode) -> bool: +def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> bool: """ - Check if the given node can be used as a default arg + Check if the given node is not more modifiable than the given modifiability. """ - if _check_literal(node): + if node.is_literal_value or node.has_folded_value: return True - if isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(check_kwargable(item) for item in node.elements) - if isinstance(node, vy_ast.Call): - args = node.args - if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return all(check_kwargable(v) for v in args[0].values) - call_type = get_exact_type_from_node(node.func) - if getattr(call_type, "_kwargable", False): - return True + if isinstance(node, (vy_ast.BinOp, vy_ast.Compare)): + return all(check_modifiability(i, modifiability) for i in (node.left, node.right)) - value_type = get_expr_info(node) - # is_constant here actually means not_assignable, and is to be renamed - return value_type.is_constant + if isinstance(node, vy_ast.BoolOp): + return all(check_modifiability(i, modifiability) for i in node.values) + if isinstance(node, vy_ast.UnaryOp): + return check_modifiability(node.operand, modifiability) -def _check_literal(node: vy_ast.VyperNode) -> bool: - """ - Check if the given node is a literal value. - """ - if isinstance(node, vy_ast.Constant): - return True - elif isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(_check_literal(item) for item in node.elements) - return False - - -def check_constant(node: vy_ast.VyperNode) -> bool: - """ - Check if the given node is a literal or constant value. - """ - if _check_literal(node): - return True if isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(check_constant(item) for item in node.elements) + return all(check_modifiability(item, modifiability) for item in node.elements) + if isinstance(node, vy_ast.Call): args = node.args if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return all(check_constant(v) for v in args[0].values) + return all(check_modifiability(v, modifiability) for v in args[0].values) call_type = get_exact_type_from_node(node.func) - if getattr(call_type, "_kwargable", False): - return True + call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE) + return call_type_modifiability >= modifiability - return False + value_type = get_expr_info(node) + return value_type.modifiability >= modifiability diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py index 2f259b1766..cecea35a60 100644 --- a/vyper/semantics/data_locations.py +++ b/vyper/semantics/data_locations.py @@ -7,5 +7,4 @@ class DataLocation(enum.Enum): STORAGE = 2 CALLDATA = 3 CODE = 4 - # XXX: needed for separate transient storage allocator - # TRANSIENT = 5 + TRANSIENT = 5 diff --git a/vyper/semantics/environment.py b/vyper/semantics/environment.py index ad68f1103e..38bac0a63d 100644 --- a/vyper/semantics/environment.py +++ b/vyper/semantics/environment.py @@ -1,6 +1,6 @@ from typing import Dict -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.types import AddressT, BytesT, VyperType from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T @@ -52,7 +52,7 @@ def get_constant_vars() -> Dict: """ result = {} for k, v in CONSTANT_ENVIRONMENT_VARS.items(): - result[k] = VarInfo(v, is_constant=True) + result[k] = VarInfo(v, modifiability=Modifiability.RUNTIME_CONSTANT) return result diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 6ecfe78be3..429ba807e1 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -340,7 +340,7 @@ def fetch_call_return(self, node): return self.typedef._ctor_call_return(node) raise StructureException("Value is not callable", node) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): if hasattr(self.typedef, "_ctor_arg_types"): return self.typedef._ctor_arg_types(node) raise StructureException("Value is not callable", node) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 34206546fd..7c77560e49 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -16,9 +16,14 @@ StateAccessViolation, StructureException, ) -from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot +from vyper.semantics.analysis.base import ( + FunctionVisibility, + Modifiability, + StateMutability, + StorageSlot, +) from vyper.semantics.analysis.utils import ( - check_kwargable, + check_modifiability, get_exact_type_from_node, validate_expected_type, ) @@ -128,7 +133,7 @@ def __repr__(self): def __str__(self): ret_sig = "" if not self.return_type else f" -> {self.return_type}" args_sig = ",".join([str(t) for t in self.argument_types]) - return f"def {self.name} {args_sig}{ret_sig}:" + return f"def {self.name}({args_sig}){ret_sig}:" # override parent implementation. function type equality does not # make too much sense. @@ -696,7 +701,7 @@ def _parse_args( positional_args.append(PositionalArg(argname, type_, ast_source=arg)) else: value = funcdef.args.defaults[i - n_positional_args] - if not check_kwargable(value): + if not check_modifiability(value, Modifiability.RUNTIME_CONSTANT): raise StateAccessViolation("Value must be literal or environment variable", value) validate_expected_type(value, type_) keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 0c8e9fddd8..55ffc23b2f 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -288,6 +288,8 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": raise StructureException(err_msg, node.slice) length_node = node.slice.value.elements[1] + if length_node.has_folded_value: + length_node = length_node.get_folded_value() if not isinstance(length_node, vy_ast.Int): raise StructureException(err_msg, length_node) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 8d68a9fa01..eb96375404 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -179,7 +179,11 @@ def get_index_value(node: vy_ast.Index) -> int: # TODO: revisit this! from vyper.semantics.analysis.utils import get_possible_types_from_node - if not isinstance(node.get("value"), vy_ast.Int): + value = node.get("value") + if value.has_folded_value: + value = value.get_folded_value() + + if not isinstance(value, vy_ast.Int): if hasattr(node, "value"): # even though the subscript is an invalid type, first check if it's a valid _something_ # this gives a more accurate error in case of e.g. a typo in a constant variable name @@ -191,7 +195,7 @@ def get_index_value(node: vy_ast.Index) -> int: raise InvalidType("Subscript must be a literal integer", node) - if node.value.value <= 0: + if value.value <= 0: raise ArrayIndexException("Subscript must be greater than 0", node) - return node.value.value + return value.value