diff --git a/setup.py b/setup.py index 40efb436c5..431c50b74b 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,6 @@ "pytest-instafail>=0.4,<1.0", "pytest-xdist>=2.5,<3.0", "pytest-split>=0.7.0,<1.0", - "pytest-rerunfailures>=10.2,<11", "eth-tester[py-evm]>=0.9.0b1,<0.10", "py-evm>=0.7.0a1,<0.8", "web3==6.0.0", diff --git a/tests/conftest.py b/tests/conftest.py index 22f8544beb..925a025a4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ from vyper import compiler from vyper.ast.grammar import parse_vyper_source from vyper.codegen.ir_node import IRnode -from vyper.compiler.input_bundle import FilesystemInputBundle +from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle from vyper.compiler.settings import OptimizationLevel, Settings, _set_debug_mode from vyper.ir import compile_ir, optimizer @@ -103,6 +103,12 @@ def fn(sources_dict): return fn +# for tests which just need an input bundle, doesn't matter what it is +@pytest.fixture +def dummy_input_bundle(): + return InputBundle([]) + + # TODO: remove me, this is just string.encode("utf-8").ljust() # only used in test_logging.py. @pytest.fixture @@ -255,9 +261,11 @@ def ir_compiler(ir, *args, **kwargs): ir = IRnode.from_list(ir) if optimize != OptimizationLevel.NONE: ir = optimizer.optimize(ir) + bytecode, _ = compile_ir.assembly_to_evm( compile_ir.compile_to_assembly(ir, optimize=optimize) ) + abi = kwargs.get("abi") or [] c = w3.eth.contract(abi=abi, bytecode=bytecode) deploy_transaction = c.constructor() diff --git a/tests/functional/codegen/test_call_graph_stability.py b/tests/functional/codegen/test_call_graph_stability.py index 4c85c330f3..2d8ad59791 100644 --- a/tests/functional/codegen/test_call_graph_stability.py +++ b/tests/functional/codegen/test_call_graph_stability.py @@ -55,7 +55,7 @@ def foo(): # check the .called_functions data structure on foo() directly foo = t.vyper_module_folded.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0] - foo_t = foo._metadata["type"] + foo_t = foo._metadata["func_type"] assert [f.name for f in foo_t.called_functions] == func_names # now for sanity, ensure the order that the function definitions appear diff --git a/tests/functional/builtins/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py similarity index 84% rename from tests/functional/builtins/codegen/test_interfaces.py rename to tests/functional/codegen/test_interfaces.py index 8cb0124f29..3544f4a965 100644 --- a/tests/functional/builtins/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -6,9 +6,9 @@ from vyper.compiler import compile_code from vyper.exceptions import ( ArgumentException, + DuplicateImport, InterfaceViolation, NamespaceCollision, - StructureException, ) @@ -31,7 +31,7 @@ def allowance(_owner: address, _spender: address) -> (uint256, uint256): out = compile_code(code, output_formats=["interface"]) out = out["interface"] - code_pass = "\n".join(code.split("\n")[:-2] + [" pass"]) # replace with a pass statement. + code_pass = "\n".join(code.split("\n")[:-2] + [" ..."]) # replace with a pass statement. assert code_pass.strip() == out.strip() @@ -60,7 +60,7 @@ def allowance(_owner: address, _spender: address) -> (uint256, uint256): view def test(_owner: address): nonpayable """ - out = compile_code(code, contract_name="One.vy", output_formats=["external_interface"])[ + out = compile_code(code, contract_path="One.vy", output_formats=["external_interface"])[ "external_interface" ] @@ -85,14 +85,14 @@ def test_external_interface_parsing(make_input_bundle, assert_compile_failed): interface_code = """ @external def foo() -> uint256: - pass + ... @external def bar() -> uint256: - pass + ... """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) code = """ import a as FooBarInterface @@ -121,9 +121,8 @@ def foo() -> uint256: """ - assert_compile_failed( - lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation - ) + with pytest.raises(InterfaceViolation): + compile_code(not_implemented_code, input_bundle=input_bundle) def test_missing_event(make_input_bundle, assert_compile_failed): @@ -132,7 +131,7 @@ def test_missing_event(make_input_bundle, assert_compile_failed): a: uint256 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -156,7 +155,7 @@ def test_malformed_event(make_input_bundle, assert_compile_failed): a: uint256 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -183,7 +182,7 @@ def test_malformed_events_indexed(make_input_bundle, assert_compile_failed): a: uint256 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -211,7 +210,7 @@ def test_malformed_events_indexed2(make_input_bundle, assert_compile_failed): a: indexed(uint256) """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -234,13 +233,13 @@ def bar() -> uint256: VALID_IMPORT_CODE = [ # import statement, import path without suffix - ("import a as Foo", "a.vy"), - ("import b.a as Foo", "b/a.vy"), - ("import Foo as Foo", "Foo.vy"), - ("from a import Foo", "a/Foo.vy"), - ("from b.a import Foo", "b/a/Foo.vy"), - ("from .a import Foo", "./a/Foo.vy"), - ("from ..a import Foo", "../a/Foo.vy"), + ("import a as Foo", "a.vyi"), + ("import b.a as Foo", "b/a.vyi"), + ("import Foo as Foo", "Foo.vyi"), + ("from a import Foo", "a/Foo.vyi"), + ("from b.a import Foo", "b/a/Foo.vyi"), + ("from .a import Foo", "./a/Foo.vyi"), + ("from ..a import Foo", "../a/Foo.vyi"), ] @@ -252,11 +251,12 @@ def test_extract_file_interface_imports(code, filename, make_input_bundle): BAD_IMPORT_CODE = [ - ("import a", StructureException), # must alias absolute imports - ("import a as A\nimport a as A", NamespaceCollision), + ("import a as A\nimport a as A", DuplicateImport), + ("import a as A\nimport a as a", DuplicateImport), + ("from . import a\nimport a as a", DuplicateImport), + ("import a as a\nfrom . import a", DuplicateImport), ("from b import a\nfrom . import a", NamespaceCollision), - ("from . import a\nimport a as a", NamespaceCollision), - ("import a as a\nfrom . import a", NamespaceCollision), + ("import a\nimport c as a", NamespaceCollision), ] @@ -264,34 +264,50 @@ def test_extract_file_interface_imports(code, filename, make_input_bundle): def test_extract_file_interface_imports_raises( code, exception_type, assert_compile_failed, make_input_bundle ): - input_bundle = make_input_bundle({"a.vy": "", "b/a.vy": ""}) # dummy - assert_compile_failed(lambda: compile_code(code, input_bundle=input_bundle), exception_type) + input_bundle = make_input_bundle({"a.vyi": "", "b/a.vyi": "", "c.vyi": ""}) + with pytest.raises(exception_type): + compile_code(code, input_bundle=input_bundle) def test_external_call_to_interface(w3, get_contract, make_input_bundle): + token_interface = """ +@view +@external +def balanceOf(addr: address) -> uint256: + ... + +@external +def transfer(to: address, amount: uint256): + ... + """ + token_code = """ +import itoken as IToken + +implements: IToken + balanceOf: public(HashMap[address, uint256]) @external -def transfer(to: address, _value: uint256): - self.balanceOf[to] += _value +def transfer(to: address, amount: uint256): + self.balanceOf[to] += amount """ - input_bundle = make_input_bundle({"one.vy": token_code}) + input_bundle = make_input_bundle({"token.vy": token_code, "itoken.vyi": token_interface}) code = """ -import one as TokenCode +import itoken as IToken interface EPI: def test() -> uint256: view -token_address: TokenCode +token_address: IToken @external def __init__(_token_address: address): - self.token_address = TokenCode(_token_address) + self.token_address = IToken(_token_address) @external @@ -299,14 +315,15 @@ def test(): self.token_address.transfer(msg.sender, 1000) """ - erc20 = get_contract(token_code) - test_c = get_contract(code, *[erc20.address], input_bundle=input_bundle) + token = get_contract(token_code, input_bundle=input_bundle) + + test_c = get_contract(code, *[token.address], input_bundle=input_bundle) sender = w3.eth.accounts[0] - assert erc20.balanceOf(sender) == 0 + assert token.balanceOf(sender) == 0 test_c.test(transact={}) - assert erc20.balanceOf(sender) == 1000 + assert token.balanceOf(sender) == 1000 @pytest.mark.parametrize( @@ -320,26 +337,36 @@ def test(): ], ) def test_external_call_to_interface_kwarg(get_contract, kwarg, typ, expected, make_input_bundle): - code_a = f""" + interface_code = f""" +@external +@view +def foo(_max: {typ} = {kwarg}) -> {typ}: + ... + """ + code1 = f""" +import one as IContract + +implements: IContract + @external @view def foo(_max: {typ} = {kwarg}) -> {typ}: return _max """ - input_bundle = make_input_bundle({"one.vy": code_a}) + input_bundle = make_input_bundle({"one.vyi": interface_code}) - code_b = f""" -import one as ContractA + code2 = f""" +import one as IContract @external @view def bar(a_address: address) -> {typ}: - return ContractA(a_address).foo() + return IContract(a_address).foo() """ - contract_a = get_contract(code_a) - contract_b = get_contract(code_b, *[contract_a.address], input_bundle=input_bundle) + contract_a = get_contract(code1, input_bundle=input_bundle) + contract_b = get_contract(code2, *[contract_a.address], input_bundle=input_bundle) assert contract_b.bar(contract_a.address) == expected @@ -349,8 +376,8 @@ def test_external_call_to_builtin_interface(w3, get_contract): balanceOf: public(HashMap[address, uint256]) @external -def transfer(to: address, _value: uint256) -> bool: - self.balanceOf[to] += _value +def transfer(to: address, amount: uint256) -> bool: + self.balanceOf[to] += amount return True """ @@ -510,14 +537,14 @@ def returns_Bytes3() -> Bytes[3]: """ should_not_compile = """ -import BadJSONInterface as BadJSONInterface +import BadJSONInterface @external def foo(x: BadJSONInterface) -> Bytes[2]: return slice(x.returns_Bytes3(), 0, 2) """ code = """ -import BadJSONInterface as BadJSONInterface +import BadJSONInterface foo: BadJSONInterface @@ -578,10 +605,10 @@ def balanceOf(owner: address) -> uint256: @external @view def balanceOf(owner: address) -> uint256: - pass + ... """ - input_bundle = make_input_bundle({"balanceof.vy": interface_code}) + input_bundle = make_input_bundle({"balanceof.vyi": interface_code}) c = get_contract(code, input_bundle=input_bundle) @@ -592,7 +619,7 @@ def test_simple_implements(make_input_bundle): interface_code = """ @external def foo() -> uint256: - pass + ... """ code = """ @@ -605,7 +632,7 @@ def foo() -> uint256: return 1 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) assert compile_code(code, input_bundle=input_bundle) is not None diff --git a/tests/functional/codegen/test_selector_table_stability.py b/tests/functional/codegen/test_selector_table_stability.py index 3302ff5009..27f82416d6 100644 --- a/tests/functional/codegen/test_selector_table_stability.py +++ b/tests/functional/codegen/test_selector_table_stability.py @@ -14,7 +14,7 @@ def test_dense_jumptable_stability(): # test that the selector table data is stable across different runs # (tox should provide different PYTHONHASHSEEDs). - expected_asm = """{ DATA _sym_BUCKET_HEADERS b'\\x0bB' _sym_bucket_0 b'\\n' b'+\\x8d' _sym_bucket_1 b'\\x0c' b'\\x00\\x85' _sym_bucket_2 b'\\x08' } { DATA _sym_bucket_1 b'\\xd8\\xee\\xa1\\xe8' _sym_external_foo6___3639517672 b'\\x05' b'\\xd2\\x9e\\xe0\\xf9' _sym_external_foo0___3533627641 b'\\x05' b'\\x05\\xf1\\xe0_' _sym_external_foo2___99737695 b'\\x05' b'\\x91\\t\\xb4{' _sym_external_foo23___2433332347 b'\\x05' b'np3\\x7f' _sym_external_foo11___1852846975 b'\\x05' b'&\\xf5\\x96\\xf9' _sym_external_foo13___653629177 b'\\x05' b'\\x04ga\\xeb' _sym_external_foo14___73884139 b'\\x05' b'\\x89\\x06\\xad\\xc6' _sym_external_foo17___2298916294 b'\\x05' b'\\xe4%\\xac\\xd1' _sym_external_foo4___3827674321 b'\\x05' b'yj\\x01\\xac' _sym_external_foo7___2036990380 b'\\x05' b'\\xf1\\xe6K\\xe5' _sym_external_foo29___4058401765 b'\\x05' b'\\xd2\\x89X\\xb8' _sym_external_foo3___3532216504 b'\\x05' } { DATA _sym_bucket_2 b'\\x06p\\xffj' _sym_external_foo25___108068714 b'\\x05' b'\\x964\\x99I' _sym_external_foo24___2520029513 b'\\x05' b's\\x81\\xe7\\xc1' _sym_external_foo10___1937893313 b'\\x05' b'\\x85\\xad\\xc11' _sym_external_foo28___2242756913 b'\\x05' b'\\xfa"\\xb1\\xed' _sym_external_foo5___4196577773 b'\\x05' b'A\\xe7[\\x05' _sym_external_foo22___1105681157 b'\\x05' b'\\xd3\\x89U\\xe8' _sym_external_foo1___3548993000 b'\\x05' b'hL\\xf8\\xf3' _sym_external_foo20___1749874931 b'\\x05' } { DATA _sym_bucket_0 b'\\xee\\xd9\\x1d\\xe3' _sym_external_foo9___4007206371 b'\\x05' b'a\\xbc\\x1ch' _sym_external_foo16___1639717992 b'\\x05' b'\\xd3*\\xa7\\x0c' _sym_external_foo21___3542787852 b'\\x05' b'\\x18iG\\xd9' _sym_external_foo19___409552857 b'\\x05' b'\\n\\xf1\\xf9\\x7f' _sym_external_foo18___183630207 b'\\x05' b')\\xda\\xd7`' _sym_external_foo27___702207840 b'\\x05' b'2\\xf6\\xaa\\xda' _sym_external_foo12___855026394 b'\\x05' b'\\xbe\\xb5\\x05\\xf5' _sym_external_foo15___3199534581 b'\\x05' b'\\xfc\\xa7_\\xe6' _sym_external_foo8___4238827494 b'\\x05' b'\\x1b\\x12C8' _sym_external_foo26___454181688 b'\\x05' } }""" # noqa: E501 + expected_asm = """{ DATA _sym_BUCKET_HEADERS b\'\\x0bB\' _sym_bucket_0 b\'\\n\' b\'+\\x8d\' _sym_bucket_1 b\'\\x0c\' b\'\\x00\\x85\' _sym_bucket_2 b\'\\x08\' } { DATA _sym_bucket_1 b\'\\xd8\\xee\\xa1\\xe8\' _sym_external 6 foo6()3639517672 b\'\\x05\' b\'\\xd2\\x9e\\xe0\\xf9\' _sym_external 0 foo0()3533627641 b\'\\x05\' b\'\\x05\\xf1\\xe0_\' _sym_external 2 foo2()99737695 b\'\\x05\' b\'\\x91\\t\\xb4{\' _sym_external 23 foo23()2433332347 b\'\\x05\' b\'np3\\x7f\' _sym_external 11 foo11()1852846975 b\'\\x05\' b\'&\\xf5\\x96\\xf9\' _sym_external 13 foo13()653629177 b\'\\x05\' b\'\\x04ga\\xeb\' _sym_external 14 foo14()73884139 b\'\\x05\' b\'\\x89\\x06\\xad\\xc6\' _sym_external 17 foo17()2298916294 b\'\\x05\' b\'\\xe4%\\xac\\xd1\' _sym_external 4 foo4()3827674321 b\'\\x05\' b\'yj\\x01\\xac\' _sym_external 7 foo7()2036990380 b\'\\x05\' b\'\\xf1\\xe6K\\xe5\' _sym_external 29 foo29()4058401765 b\'\\x05\' b\'\\xd2\\x89X\\xb8\' _sym_external 3 foo3()3532216504 b\'\\x05\' } { DATA _sym_bucket_2 b\'\\x06p\\xffj\' _sym_external 25 foo25()108068714 b\'\\x05\' b\'\\x964\\x99I\' _sym_external 24 foo24()2520029513 b\'\\x05\' b\'s\\x81\\xe7\\xc1\' _sym_external 10 foo10()1937893313 b\'\\x05\' b\'\\x85\\xad\\xc11\' _sym_external 28 foo28()2242756913 b\'\\x05\' b\'\\xfa"\\xb1\\xed\' _sym_external 5 foo5()4196577773 b\'\\x05\' b\'A\\xe7[\\x05\' _sym_external 22 foo22()1105681157 b\'\\x05\' b\'\\xd3\\x89U\\xe8\' _sym_external 1 foo1()3548993000 b\'\\x05\' b\'hL\\xf8\\xf3\' _sym_external 20 foo20()1749874931 b\'\\x05\' } { DATA _sym_bucket_0 b\'\\xee\\xd9\\x1d\\xe3\' _sym_external 9 foo9()4007206371 b\'\\x05\' b\'a\\xbc\\x1ch\' _sym_external 16 foo16()1639717992 b\'\\x05\' b\'\\xd3*\\xa7\\x0c\' _sym_external 21 foo21()3542787852 b\'\\x05\' b\'\\x18iG\\xd9\' _sym_external 19 foo19()409552857 b\'\\x05\' b\'\\n\\xf1\\xf9\\x7f\' _sym_external 18 foo18()183630207 b\'\\x05\' b\')\\xda\\xd7`\' _sym_external 27 foo27()702207840 b\'\\x05\' b\'2\\xf6\\xaa\\xda\' _sym_external 12 foo12()855026394 b\'\\x05\' b\'\\xbe\\xb5\\x05\\xf5\' _sym_external 15 foo15()3199534581 b\'\\x05\' b\'\\xfc\\xa7_\\xe6\' _sym_external 8 foo8()4238827494 b\'\\x05\' b\'\\x1b\\x12C8\' _sym_external 26 foo26()454181688 b\'\\x05\' } }""" # noqa: E501 assert expected_asm in output["asm"] diff --git a/tests/functional/codegen/test_stateless_modules.py b/tests/functional/codegen/test_stateless_modules.py new file mode 100644 index 0000000000..8e634e5868 --- /dev/null +++ b/tests/functional/codegen/test_stateless_modules.py @@ -0,0 +1,335 @@ +import hypothesis.strategies as st +import pytest +from hypothesis import given, settings + +from vyper import compiler +from vyper.exceptions import ( + CallViolation, + DuplicateImport, + ImportCycle, + StructureException, + TypeMismatch, +) + +# test modules which have no variables - "libraries" + + +def test_simple_library(get_contract, make_input_bundle, w3): + library_source = """ +@internal +def foo() -> uint256: + return block.number + 1 + """ + main = """ +import library + +@external +def bar() -> uint256: + return library.foo() - 1 + """ + input_bundle = make_input_bundle({"library.vy": library_source}) + + c = get_contract(main, input_bundle=input_bundle) + + assert c.bar() == w3.eth.block_number + + +# is this the best place for this? +def test_import_cycle(make_input_bundle): + code_a = "import b\n" + code_b = "import a\n" + + input_bundle = make_input_bundle({"a.vy": code_a, "b.vy": code_b}) + + with pytest.raises(ImportCycle): + compiler.compile_code(code_a, input_bundle=input_bundle) + + +# test we can have a function in the library with the same name as +# in the main contract +def test_library_function_same_name(get_contract, make_input_bundle): + library = """ +@internal +def foo() -> uint256: + return 10 + """ + + main = """ +import library + +@internal +def foo() -> uint256: + return 100 + +@external +def self_foo() -> uint256: + return self.foo() + +@external +def library_foo() -> uint256: + return library.foo() + """ + + input_bundle = make_input_bundle({"library.vy": library}) + + c = get_contract(main, input_bundle=input_bundle) + + assert c.self_foo() == 100 + assert c.library_foo() == 10 + + +def test_transitive_import(get_contract, make_input_bundle): + a = """ +@internal +def foo() -> uint256: + return 1 + """ + b = """ +import a + +@internal +def bar() -> uint256: + return a.foo() + 1 + """ + c = """ +import b + +@external +def baz() -> uint256: + return b.bar() + 1 + """ + # more complicated call graph, with `a` imported twice. + d = """ +import b +import a + +@external +def qux() -> uint256: + s: uint256 = a.foo() + return s + b.bar() + 1 + """ + input_bundle = make_input_bundle({"a.vy": a, "b.vy": b, "c.vy": c, "d.vy": d}) + + contract = get_contract(c, input_bundle=input_bundle) + assert contract.baz() == 3 + contract = get_contract(d, input_bundle=input_bundle) + assert contract.qux() == 4 + + +def test_cannot_call_library_external_functions(make_input_bundle): + library_source = """ +@external +def foo(): + pass + """ + contract_source = """ +import library + +@external +def bar(): + library.foo() + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(CallViolation): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_library_external_functions_not_in_abi(get_contract, make_input_bundle): + library_source = """ +@external +def foo(): + pass + """ + contract_source = """ +import library + +@external +def bar(): + pass + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + c = get_contract(contract_source, input_bundle=input_bundle) + assert not hasattr(c, "foo") + + +def test_library_structs(get_contract, make_input_bundle): + library_source = """ +struct SomeStruct: + x: uint256 + +@internal +def foo() -> SomeStruct: + return SomeStruct({x: 1}) + """ + contract_source = """ +import library + +@external +def bar(s: library.SomeStruct): + pass + +@external +def baz() -> library.SomeStruct: + return library.SomeStruct({x: 2}) + +@external +def qux() -> library.SomeStruct: + return library.foo() + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + c = get_contract(contract_source, input_bundle=input_bundle) + + assert c.bar((1,)) == [] + + assert c.baz() == (2,) + assert c.qux() == (1,) + + +# test calls to library functions in statement position +def test_library_statement_calls(get_contract, make_input_bundle, assert_tx_failed): + library_source = """ +from vyper.interfaces import ERC20 +@internal +def check_adds_to_ten(x: uint256, y: uint256): + assert x + y == 10 + """ + contract_source = """ +import library + +counter: public(uint256) + +@external +def foo(x: uint256): + library.check_adds_to_ten(3, x) + self.counter = x + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + + c = get_contract(contract_source, input_bundle=input_bundle) + + c.foo(7, transact={}) + + assert c.counter() == 7 + + assert_tx_failed(lambda: c.foo(8)) + + +def test_library_is_typechecked(make_input_bundle): + library_source = """ +@internal +def foo(): + asdlkfjasdflkajsdf + """ + contract_source = """ +import library + """ + + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(StructureException): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_library_is_typechecked2(make_input_bundle): + # check that we typecheck against imported function signatures + library_source = """ +@internal +def foo() -> uint256: + return 1 + """ + contract_source = """ +import library + +@external +def foo() -> bytes32: + return library.foo() + """ + + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(TypeMismatch): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_reject_duplicate_imports(make_input_bundle): + library_source = """ + """ + + contract_source = """ +import library +import library as library2 + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(DuplicateImport): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_nested_module_access(get_contract, make_input_bundle): + lib1 = """ +import lib2 + +@internal +def lib2_foo() -> uint256: + return lib2.foo() + """ + lib2 = """ +@internal +def foo() -> uint256: + return 1337 + """ + + main = """ +import lib1 +import lib2 + +@external +def lib1_foo() -> uint256: + return lib1.lib2_foo() + +@external +def lib2_foo() -> uint256: + return lib1.lib2.foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.lib1_foo() == c.lib2_foo() == 1337 + + +_int_127 = st.integers(min_value=0, max_value=127) +_bytes_128 = st.binary(min_size=0, max_size=128) + + +def test_slice_builtin(get_contract, make_input_bundle): + lib = """ +@internal +def slice_input(x: Bytes[128], start: uint256, length: uint256) -> Bytes[128]: + return slice(x, start, length) + """ + + main = """ +import lib +@external +def lib_slice_input(x: Bytes[128], start: uint256, length: uint256) -> Bytes[128]: + return lib.slice_input(x, start, length) + +@external +def slice_input(x: Bytes[128], start: uint256, length: uint256) -> Bytes[128]: + return slice(x, start, length) + """ + input_bundle = make_input_bundle({"lib.vy": lib}) + c = get_contract(main, input_bundle=input_bundle) + + # use an inner test so that we can cache the result of get_contract() + @given(start=_int_127, length=_int_127, bytesdata=_bytes_128) + @settings(max_examples=100) + def _test(bytesdata, start, length): + # surjectively map start into allowable range + if start > len(bytesdata): + start = start % (len(bytesdata) or 1) + # surjectively map length into allowable range + if length > (len(bytesdata) - start): + length = length % ((len(bytesdata) - start) or 1) + main_result = c.slice_input(bytesdata, start, length) + library_result = c.lib_slice_input(bytesdata, start, length) + assert main_result == library_result == bytesdata[start : start + length] + + _test() diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index aa0286cfa5..7dd8c35929 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -92,7 +92,7 @@ def from_grammar() -> st.SearchStrategy[str]: # Avoid examples with *only* single or double quote docstrings -# because they trigger a trivial compiler bug +# because they trigger a trivial parser bug SINGLE_QUOTE_DOCSTRING = re.compile(r"^'''.*'''$") DOUBLE_QUOTE_DOCSTRING = re.compile(r'^""".*"""$') diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index 9100389dbd..a672ed7b88 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -376,17 +376,12 @@ def test_interfaces_success(good_code): def test_imports_and_implements_within_interface(make_input_bundle): interface_code = """ -from vyper.interfaces import ERC20 -import foo.bar as Baz - -implements: Baz - @external def foobar(): - pass + ... """ - input_bundle = make_input_bundle({"foo.vy": interface_code}) + input_bundle = make_input_bundle({"foo.vyi": interface_code}) code = """ import foo as Foo diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index 47483c493c..d413340083 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -37,9 +37,9 @@ def foo(): @pytest.mark.parametrize("code", code_invalid_checksum) -def test_invalid_checksum(code): +def test_invalid_checksum(code, dummy_input_bundle): vyper_module = vy_ast.parse_to_ast(code) with pytest.raises(InvalidLiteral): vy_ast.validation.validate_literal_nodes(vyper_module) - semantics.validate_semantics(vyper_module, {}) + semantics.validate_semantics(vyper_module, dummy_input_bundle) diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 68a07178bb..16ce6fe631 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -1,7 +1,6 @@ import ast as python_ast -from vyper.ast.annotation import annotate_python_ast -from vyper.ast.pre_parser import pre_parse +from vyper.ast.parse import annotate_python_ast, pre_parse class AssertionVisitor(python_ast.NodeVisitor): diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index 1f60c9ac8b..dc49f72561 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -1,7 +1,8 @@ import json from vyper import compiler -from vyper.ast.utils import ast_to_dict, dict_to_ast, parse_to_ast +from vyper.ast.parse import parse_to_ast +from vyper.ast.utils import ast_to_dict, dict_to_ast def get_node_ids(ast_struct, ids=None): @@ -40,7 +41,7 @@ def test_basic_ast(): code = """ a: int128 """ - dict_out = compiler.compile_code(code, output_formats=["ast_dict"]) + dict_out = compiler.compile_code(code, output_formats=["ast_dict"], source_id=0) assert dict_out["ast_dict"]["ast"]["body"][0] == { "annotation": { "ast_type": "Name", @@ -89,7 +90,7 @@ def foo() -> uint256: view def foo() -> uint256: return 1 """ - dict_out = compiler.compile_code(code, output_formats=["ast_dict"]) + dict_out = compiler.compile_code(code, output_formats=["ast_dict"], source_id=0) assert dict_out["ast_dict"]["ast"]["body"][1] == { "col_offset": 0, "annotation": { diff --git a/tests/unit/ast/test_parser.py b/tests/unit/ast/test_parser.py index c47bf40bfa..e0bfcbc2ef 100644 --- a/tests/unit/ast/test_parser.py +++ b/tests/unit/ast/test_parser.py @@ -1,4 +1,4 @@ -from vyper.ast.utils import parse_to_ast +from vyper.ast.parse import parse_to_ast def test_ast_equal(): diff --git a/tests/unit/cli/outputs/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py similarity index 100% rename from tests/unit/cli/outputs/test_storage_layout.py rename to tests/unit/cli/storage_layout/test_storage_layout.py diff --git a/tests/unit/cli/outputs/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py similarity index 98% rename from tests/unit/cli/outputs/test_storage_layout_overrides.py rename to tests/unit/cli/storage_layout/test_storage_layout_overrides.py index 94e0faeb37..f4c11b7ae6 100644 --- a/tests/unit/cli/outputs/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py @@ -103,7 +103,7 @@ def test_overflow(): storage_layout_override = {"x": {"slot": 2**256 - 1, "type": "uint256[2]"}} with pytest.raises( - StorageLayoutException, match=f"Invalid storage slot for var x, out of bounds: {2**256}\n" + StorageLayoutException, match=f"Invalid storage slot for var x, out of bounds: {2**256}" ): compile_code( code, output_formats=["layout"], storage_layout_override=storage_layout_override diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index 2a16efa777..f6e3a51a4b 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -30,93 +30,100 @@ def test_invalid_root_path(): compile_files([], [], root_folder="path/that/does/not/exist") -FOO_CODE = """ -{} - -struct FooStruct: - foo_: uint256 +CONTRACT_CODE = """ +{import_stmt} @external -def foo() -> FooStruct: - return FooStruct({{foo_: 13}}) +def foo() -> {alias}.FooStruct: + return {alias}.FooStruct({{foo_: 13}}) @external -def bar(a: address) -> FooStruct: - return {}(a).bar() +def bar(a: address) -> {alias}.FooStruct: + return {alias}(a).bar() """ -BAR_CODE = """ +INTERFACE_CODE = """ struct FooStruct: foo_: uint256 + +@external +def foo() -> FooStruct: + ... + @external def bar() -> FooStruct: - return FooStruct({foo_: 13}) + ... """ SAME_FOLDER_IMPORT_STMT = [ - ("import Bar as Bar", "Bar"), - ("import contracts.Bar as Bar", "Bar"), - ("from . import Bar", "Bar"), - ("from contracts import Bar", "Bar"), - ("from ..contracts import Bar", "Bar"), - ("from . import Bar as FooBar", "FooBar"), - ("from contracts import Bar as FooBar", "FooBar"), - ("from ..contracts import Bar as FooBar", "FooBar"), + ("import IFoo as IFoo", "IFoo"), + ("import contracts.IFoo as IFoo", "IFoo"), + ("from . import IFoo", "IFoo"), + ("from contracts import IFoo", "IFoo"), + ("from ..contracts import IFoo", "IFoo"), + ("from . import IFoo as FooBar", "FooBar"), + ("from contracts import IFoo as FooBar", "FooBar"), + ("from ..contracts import IFoo as FooBar", "FooBar"), ] @pytest.mark.parametrize("import_stmt,alias", SAME_FOLDER_IMPORT_STMT) def test_import_same_folder(import_stmt, alias, tmp_path, make_file): foo = "contracts/foo.vy" - make_file("contracts/foo.vy", FOO_CODE.format(import_stmt, alias)) - make_file("contracts/Bar.vy", BAR_CODE) + make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) + make_file("contracts/IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) SUBFOLDER_IMPORT_STMT = [ - ("import other.Bar as Bar", "Bar"), - ("import contracts.other.Bar as Bar", "Bar"), - ("from other import Bar", "Bar"), - ("from contracts.other import Bar", "Bar"), - ("from .other import Bar", "Bar"), - ("from ..contracts.other import Bar", "Bar"), - ("from other import Bar as FooBar", "FooBar"), - ("from contracts.other import Bar as FooBar", "FooBar"), - ("from .other import Bar as FooBar", "FooBar"), - ("from ..contracts.other import Bar as FooBar", "FooBar"), + ("import other.IFoo as IFoo", "IFoo"), + ("import contracts.other.IFoo as IFoo", "IFoo"), + ("from other import IFoo", "IFoo"), + ("from contracts.other import IFoo", "IFoo"), + ("from .other import IFoo", "IFoo"), + ("from ..contracts.other import IFoo", "IFoo"), + ("from other import IFoo as FooBar", "FooBar"), + ("from contracts.other import IFoo as FooBar", "FooBar"), + ("from .other import IFoo as FooBar", "FooBar"), + ("from ..contracts.other import IFoo as FooBar", "FooBar"), ] @pytest.mark.parametrize("import_stmt, alias", SUBFOLDER_IMPORT_STMT) def test_import_subfolder(import_stmt, alias, tmp_path, make_file): - foo = make_file("contracts/foo.vy", (FOO_CODE.format(import_stmt, alias))) - make_file("contracts/other/Bar.vy", BAR_CODE) + foo = make_file( + "contracts/foo.vy", (CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) + ) + make_file("contracts/other/IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) OTHER_FOLDER_IMPORT_STMT = [ - ("import interfaces.Bar as Bar", "Bar"), - ("from interfaces import Bar", "Bar"), - ("from ..interfaces import Bar", "Bar"), - ("from interfaces import Bar as FooBar", "FooBar"), - ("from ..interfaces import Bar as FooBar", "FooBar"), + ("import interfaces.IFoo as IFoo", "IFoo"), + ("from interfaces import IFoo", "IFoo"), + ("from ..interfaces import IFoo", "IFoo"), + ("from interfaces import IFoo as FooBar", "FooBar"), + ("from ..interfaces import IFoo as FooBar", "FooBar"), ] @pytest.mark.parametrize("import_stmt, alias", OTHER_FOLDER_IMPORT_STMT) def test_import_other_folder(import_stmt, alias, tmp_path, make_file): - foo = make_file("contracts/foo.vy", FOO_CODE.format(import_stmt, alias)) - make_file("interfaces/Bar.vy", BAR_CODE) + foo = make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) + make_file("interfaces/IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) def test_import_parent_folder(tmp_path, make_file): - foo = make_file("contracts/baz/foo.vy", FOO_CODE.format("from ... import Bar", "Bar")) - make_file("Bar.vy", BAR_CODE) + foo = make_file( + "contracts/baz/foo.vy", + CONTRACT_CODE.format(import_stmt="from ... import IFoo", alias="IFoo"), + ) + make_file("IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) @@ -125,62 +132,60 @@ def test_import_parent_folder(tmp_path, make_file): META_IMPORT_STMT = [ - "import Meta as Meta", - "import contracts.Meta as Meta", - "from . import Meta", - "from contracts import Meta", + "import ISelf as ISelf", + "import contracts.ISelf as ISelf", + "from . import ISelf", + "from contracts import ISelf", ] @pytest.mark.parametrize("import_stmt", META_IMPORT_STMT) def test_import_self_interface(import_stmt, tmp_path, make_file): - # a contract can access its derived interface by importing itself - code = f""" -{import_stmt} - + interface_code = """ struct FooStruct: foo_: uint256 @external def know_thyself(a: address) -> FooStruct: - return Meta(a).be_known() + ... @external def be_known() -> FooStruct: - return FooStruct({{foo_: 42}}) + ... """ - meta = make_file("contracts/Meta.vy", code) - - assert compile_files([meta], ["combined_json"], root_folder=tmp_path) + code = f""" +{import_stmt} +@external +def know_thyself(a: address) -> ISelf.FooStruct: + return ISelf(a).be_known() -DERIVED_IMPORT_STMT_BAZ = ["import Foo as Foo", "from . import Foo"] +@external +def be_known() -> ISelf.FooStruct: + return ISelf.FooStruct({{foo_: 42}}) + """ + make_file("contracts/ISelf.vyi", interface_code) + meta = make_file("contracts/Self.vy", code) -DERIVED_IMPORT_STMT_FOO = ["import Bar as Bar", "from . import Bar"] + assert compile_files([meta], ["combined_json"], root_folder=tmp_path) -@pytest.mark.parametrize("import_stmt_baz", DERIVED_IMPORT_STMT_BAZ) -@pytest.mark.parametrize("import_stmt_foo", DERIVED_IMPORT_STMT_FOO) -def test_derived_interface_imports(import_stmt_baz, import_stmt_foo, tmp_path, make_file): - # contracts-as-interfaces should be able to contain import statements +# implement IFoo in another contract for fun +@pytest.mark.parametrize("import_stmt_foo,alias", SAME_FOLDER_IMPORT_STMT) +def test_another_interface_implementation(import_stmt_foo, alias, tmp_path, make_file): baz_code = f""" -{import_stmt_baz} - -struct FooStruct: - foo_: uint256 +{import_stmt_foo} @external -def foo(a: address) -> FooStruct: - return Foo(a).foo() +def foo(a: address) -> {alias}.FooStruct: + return {alias}(a).foo() @external -def bar(_foo: address, _bar: address) -> FooStruct: - return Foo(_foo).bar(_bar) +def bar(_foo: address) -> {alias}.FooStruct: + return {alias}(_foo).bar() """ - - make_file("Foo.vy", FOO_CODE.format(import_stmt_foo, "Bar")) - make_file("Bar.vy", BAR_CODE) - baz = make_file("Baz.vy", baz_code) + make_file("contracts/IFoo.vyi", INTERFACE_CODE) + baz = make_file("contracts/Baz.vy", baz_code) assert compile_files([baz], ["combined_json"], root_folder=tmp_path) @@ -207,15 +212,36 @@ def test_local_namespace(make_file, tmp_path): make_file(filename, code) paths.append(filename) - for file_name in ("foo.vy", "bar.vy"): - make_file(file_name, BAR_CODE) + for file_name in ("foo.vyi", "bar.vyi"): + make_file(file_name, INTERFACE_CODE) assert compile_files(paths, ["combined_json"], root_folder=tmp_path) def test_compile_outside_root_path(tmp_path, make_file): # absolute paths relative to "." - foo = make_file("foo.vy", FOO_CODE.format("import bar as Bar", "Bar")) - bar = make_file("bar.vy", BAR_CODE) + make_file("ifoo.vyi", INTERFACE_CODE) + foo = make_file("foo.vy", CONTRACT_CODE.format(import_stmt="import ifoo as IFoo", alias="IFoo")) + + assert compile_files([foo], ["combined_json"], root_folder=".") + + +def test_import_library(tmp_path, make_file): + library_source = """ +@internal +def foo() -> uint256: + return block.number + 1 + """ + + contract_source = """ +import lib + +@external +def foo() -> uint256: + return lib.foo() + """ + + make_file("lib.vy", library_source) + contract_file = make_file("contract.vy", contract_source) - assert compile_files([foo, bar], ["combined_json"], root_folder=".") + assert compile_files([contract_file], ["combined_json"], root_folder=tmp_path) is not None diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index 732762d72b..a50946ba21 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -1,30 +1,55 @@ import json +from pathlib import PurePath import pytest import vyper -from vyper.cli.vyper_json import compile_from_input_dict, compile_json, exc_handler_to_dict -from vyper.compiler import OUTPUT_FORMATS, compile_code +from vyper.cli.vyper_json import ( + compile_from_input_dict, + compile_json, + exc_handler_to_dict, + get_inputs, +) +from vyper.compiler import OUTPUT_FORMATS, compile_code, compile_from_file_input +from vyper.compiler.input_bundle import JSONInputBundle from vyper.exceptions import InvalidType, JSONError, SyntaxException FOO_CODE = """ -import contracts.bar as Bar +import contracts.ibar as IBar + +import contracts.library as library @external def foo(a: address) -> bool: - return Bar(a).bar(1) + return IBar(a).bar(1) @external def baz() -> uint256: - return self.balance + return self.balance + library.foo() """ BAR_CODE = """ +import contracts.ibar as IBar + +implements: IBar + @external def bar(a: uint256) -> bool: return True """ +BAR_VYI = """ +@external +def bar(a: uint256) -> bool: + ... +""" + +LIBRARY_CODE = """ +@internal +def foo() -> uint256: + return block.number + 1 +""" + BAD_SYNTAX_CODE = """ def bar()>: """ @@ -52,6 +77,7 @@ def input_json(): "language": "Vyper", "sources": { "contracts/foo.vy": {"content": FOO_CODE}, + "contracts/library.vy": {"content": LIBRARY_CODE}, "contracts/bar.vy": {"content": BAR_CODE}, }, "interfaces": {"contracts/ibar.json": {"abi": BAR_ABI}}, @@ -59,6 +85,14 @@ def input_json(): } +@pytest.fixture(scope="function") +def input_bundle(input_json): + # CMC 2023-12-11 maybe input_json -> JSONInputBundle should be a helper + # function in `vyper_json.py`. + sources = get_inputs(input_json) + return JSONInputBundle(sources, search_paths=[PurePath(".")]) + + # test string and dict inputs both work def test_string_input(input_json): assert compile_json(input_json) == compile_json(json.dumps(input_json)) @@ -77,29 +111,39 @@ def test_keyerror_becomes_jsonerror(input_json): compile_json(input_json) -def test_compile_json(input_json, make_input_bundle): - input_bundle = make_input_bundle({"contracts/bar.vy": BAR_CODE}) +def test_compile_json(input_json, input_bundle): + foo_input = input_bundle.load_file("contracts/foo.vy") + foo = compile_from_file_input( + foo_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + ) - foo = compile_code( - FOO_CODE, - source_id=0, - contract_name="contracts/foo.vy", - output_formats=OUTPUT_FORMATS, - input_bundle=input_bundle, + library_input = input_bundle.load_file("contracts/library.vy") + library = compile_from_file_input( + library_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle ) - bar = compile_code( - BAR_CODE, source_id=1, contract_name="contracts/bar.vy", output_formats=OUTPUT_FORMATS + + bar_input = input_bundle.load_file("contracts/bar.vy") + bar = compile_from_file_input( + bar_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle ) - compile_code_results = {"contracts/bar.vy": bar, "contracts/foo.vy": foo} + compile_code_results = { + "contracts/bar.vy": bar, + "contracts/library.vy": library, + "contracts/foo.vy": foo, + } output_json = compile_json(input_json) - assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + assert list(output_json["contracts"].keys()) == [ + "contracts/foo.vy", + "contracts/library.vy", + "contracts/bar.vy", + ] assert sorted(output_json.keys()) == ["compiler", "contracts", "sources"] assert output_json["compiler"] == f"vyper-{vyper.__version__}" - for source_id, contract_name in enumerate(["foo", "bar"]): + for source_id, contract_name in [(0, "foo"), (2, "library"), (3, "bar")]: path = f"contracts/{contract_name}.vy" data = compile_code_results[path] assert output_json["sources"][path] == {"id": source_id, "ast": data["ast_dict"]["ast"]} @@ -123,13 +167,28 @@ def test_compile_json(input_json, make_input_bundle): } -def test_different_outputs(make_input_bundle, input_json): +def test_compilation_targets(input_json): + output_json = compile_json(input_json) + assert list(output_json["contracts"].keys()) == [ + "contracts/foo.vy", + "contracts/library.vy", + "contracts/bar.vy", + ] + + # omit library.vy + input_json["settings"]["outputSelection"] = {"contracts/foo.vy": "*", "contracts/bar.vy": "*"} + output_json = compile_json(input_json) + + assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + + +def test_different_outputs(input_bundle, input_json): input_json["settings"]["outputSelection"] = { "contracts/bar.vy": "*", "contracts/foo.vy": ["evm.methodIdentifiers"], } output_json = compile_json(input_json) - assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + assert list(output_json["contracts"].keys()) == ["contracts/bar.vy", "contracts/foo.vy"] assert sorted(output_json.keys()) == ["compiler", "contracts", "sources"] assert output_json["compiler"] == f"vyper-{vyper.__version__}" @@ -143,10 +202,9 @@ def test_different_outputs(make_input_bundle, input_json): assert sorted(foo.keys()) == ["evm"] # check method_identifiers - input_bundle = make_input_bundle({"contracts/bar.vy": BAR_CODE}) method_identifiers = compile_code( FOO_CODE, - contract_name="contracts/foo.vy", + contract_path="contracts/foo.vy", output_formats=["method_identifiers"], input_bundle=input_bundle, )["method_identifiers"] @@ -204,11 +262,12 @@ def get(filename, contractname): return result["contracts"][filename][contractname]["evm"]["deployedBytecode"]["sourceMap"] assert get("contracts/foo.vy", "foo").startswith("-1:-1:0") - assert get("contracts/bar.vy", "bar").startswith("-1:-1:1") + assert get("contracts/library.vy", "library").startswith("-1:-1:2") + assert get("contracts/bar.vy", "bar").startswith("-1:-1:3") def test_relative_import_paths(input_json): - input_json["sources"]["contracts/potato/baz/baz.vy"] = {"content": """from ... import foo"""} - input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": """from . import baz"""} - input_json["sources"]["contracts/potato/footato.vy"] = {"content": """from baz import baz"""} + input_json["sources"]["contracts/potato/baz/baz.vy"] = {"content": "from ... import foo"} + input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": "from . import baz"} + input_json["sources"]["contracts/potato/footato.vy"] = {"content": "from baz import baz"} compile_from_input_dict(input_json) diff --git a/tests/unit/cli/vyper_json/test_get_inputs.py b/tests/unit/cli/vyper_json/test_get_inputs.py index 6e323a91bd..c91cc750f2 100644 --- a/tests/unit/cli/vyper_json/test_get_inputs.py +++ b/tests/unit/cli/vyper_json/test_get_inputs.py @@ -2,7 +2,7 @@ import pytest -from vyper.cli.vyper_json import get_compilation_targets, get_inputs +from vyper.cli.vyper_json import get_inputs from vyper.exceptions import JSONError from vyper.utils import keccak256 @@ -122,9 +122,6 @@ def test_interfaces_output(): "interface.folder/bar2.vy": {"content": BAR_CODE}, }, } - targets = get_compilation_targets(input_json) - assert targets == [PurePath("foo.vy")] - result = get_inputs(input_json) assert result == { PurePath("foo.vy"): {"content": FOO_CODE}, diff --git a/tests/unit/cli/vyper_json/test_output_selection.py b/tests/unit/cli/vyper_json/test_output_selection.py index 78ad7404f2..5383190a66 100644 --- a/tests/unit/cli/vyper_json/test_output_selection.py +++ b/tests/unit/cli/vyper_json/test_output_selection.py @@ -8,53 +8,61 @@ def test_no_outputs(): with pytest.raises(KeyError): - get_output_formats({}, {}) + get_output_formats({}) def test_invalid_output(): - input_json = {"settings": {"outputSelection": {"foo.vy": ["abi", "foobar"]}}} - targets = [PurePath("foo.vy")] + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": ["abi", "foobar"]}}, + } with pytest.raises(JSONError): - get_output_formats(input_json, targets) + get_output_formats(input_json) def test_unknown_contract(): - input_json = {"settings": {"outputSelection": {"bar.vy": ["abi"]}}} - targets = [PurePath("foo.vy")] + input_json = {"sources": {}, "settings": {"outputSelection": {"bar.vy": ["abi"]}}} with pytest.raises(JSONError): - get_output_formats(input_json, targets) + get_output_formats(input_json) @pytest.mark.parametrize("output", TRANSLATE_MAP.items()) def test_translate_map(output): - input_json = {"settings": {"outputSelection": {"foo.vy": [output[0]]}}} - targets = [PurePath("foo.vy")] - assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): [output[1]]} + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": [output[0]]}}, + } + assert get_output_formats(input_json) == {PurePath("foo.vy"): [output[1]]} def test_star(): - input_json = {"settings": {"outputSelection": {"*": ["*"]}}} - targets = [PurePath("foo.vy"), PurePath("bar.vy")] + input_json = { + "sources": {"foo.vy": "", "bar.vy": ""}, + "settings": {"outputSelection": {"*": ["*"]}}, + } expected = sorted(set(TRANSLATE_MAP.values())) - result = get_output_formats(input_json, targets) + result = get_output_formats(input_json) assert result == {PurePath("foo.vy"): expected, PurePath("bar.vy"): expected} def test_evm(): - input_json = {"settings": {"outputSelection": {"foo.vy": ["abi", "evm"]}}} - targets = [PurePath("foo.vy")] + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": ["abi", "evm"]}}, + } expected = ["abi"] + sorted(v for k, v in TRANSLATE_MAP.items() if k.startswith("evm")) - result = get_output_formats(input_json, targets) + result = get_output_formats(input_json) assert result == {PurePath("foo.vy"): expected} def test_solc_style(): - input_json = {"settings": {"outputSelection": {"foo.vy": {"": ["abi"], "foo.vy": ["ir"]}}}} - targets = [PurePath("foo.vy")] - assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): ["abi", "ir_dict"]} + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": {"": ["abi"], "foo.vy": ["ir"]}}}, + } + assert get_output_formats(input_json) == {PurePath("foo.vy"): ["abi", "ir_dict"]} def test_metadata(): - input_json = {"settings": {"outputSelection": {"*": ["metadata"]}}} - targets = [PurePath("foo.vy")] - assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): ["metadata"]} + input_json = {"sources": {"foo.vy": ""}, "settings": {"outputSelection": {"*": ["metadata"]}}} + assert get_output_formats(input_json) == {PurePath("foo.vy"): ["metadata"]} diff --git a/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py b/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py index 3b0f700c7e..6b509dd3ef 100644 --- a/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py +++ b/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py @@ -9,11 +9,11 @@ from vyper.exceptions import JSONError FOO_CODE = """ -import contracts.bar as Bar +import contracts.ibar as IBar @external def foo(a: address) -> bool: - return Bar(a).bar(1) + return IBar(a).bar(1) """ BAR_CODE = """ diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index 47b70a8c70..44b823757c 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -1,5 +1,6 @@ import pytest +from vyper.compiler import compile_code from vyper.compiler.phases import CompilerData from vyper.compiler.settings import OptimizationLevel, Settings @@ -71,33 +72,61 @@ def __init__(): ] +# check dead code eliminator works on unreachable functions @pytest.mark.parametrize("code", codes) def test_dead_code_eliminator(code): c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.NONE)) - initcode_asm = [i for i in c.assembly if not isinstance(i, list)] - runtime_asm = c.assembly_runtime - ctor_only_label = "_sym_internal_ctor_only___" - runtime_only_label = "_sym_internal_runtime_only___" + # get the labels + initcode_asm = [i for i in c.assembly if isinstance(i, str)] + runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] + + ctor_only = "ctor_only()" + runtime_only = "runtime_only()" # qux reachable from unoptimized initcode, foo not reachable. - assert ctor_only_label + "_deploy" in initcode_asm - assert runtime_only_label + "_deploy" not in initcode_asm + assert any(ctor_only in instr for instr in initcode_asm) + assert all(runtime_only not in instr for instr in initcode_asm) # all labels should be in unoptimized runtime asm - for s in (ctor_only_label, runtime_only_label): - assert s + "_runtime" in runtime_asm + for s in (ctor_only, runtime_only): + assert any(s in instr for instr in runtime_asm) c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.GAS)) - initcode_asm = [i for i in c.assembly if not isinstance(i, list)] - runtime_asm = c.assembly_runtime + initcode_asm = [i for i in c.assembly if isinstance(i, str)] + runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] # ctor only label should not be in runtime code - for instr in runtime_asm: - if isinstance(instr, str): - assert not instr.startswith(ctor_only_label), instr + assert all(ctor_only not in instr for instr in runtime_asm) # runtime only label should not be in initcode asm - for instr in initcode_asm: - if isinstance(instr, str): - assert not instr.startswith(runtime_only_label), instr + assert all(runtime_only not in instr for instr in initcode_asm) + + +def test_library_code_eliminator(make_input_bundle): + library = """ +@internal +def unused1(): + pass + +@internal +def unused2(): + self.unused1() + +@internal +def some_function(): + pass + """ + code = """ +import library + +@external +def foo(): + library.some_function() + """ + input_bundle = make_input_bundle({"library.vy": library}) + res = compile_code(code, input_bundle=input_bundle, output_formats=["asm"]) + asm = res["asm"] + assert "some_function()" in asm + assert "unused1()" not in asm + assert "unused2()" not in asm diff --git a/tests/unit/compiler/test_input_bundle.py b/tests/unit/compiler/test_input_bundle.py index c49c81219b..e26555b169 100644 --- a/tests/unit/compiler/test_input_bundle.py +++ b/tests/unit/compiler/test_input_bundle.py @@ -1,4 +1,6 @@ +import contextlib import json +import os from pathlib import Path, PurePath import pytest @@ -12,19 +14,19 @@ def input_bundle(tmp_path): return FilesystemInputBundle([tmp_path]) -def test_load_file(make_file, input_bundle, tmp_path): - make_file("foo.vy", "contents") +def test_load_file(make_file, input_bundle): + filepath = make_file("foo.vy", "contents") file = input_bundle.load_file(Path("foo.vy")) assert isinstance(file, FileInput) - assert file == FileInput(0, tmp_path / Path("foo.vy"), "contents") + assert file == FileInput(0, Path("foo.vy"), filepath, "contents") def test_search_path_context_manager(make_file, tmp_path): ib = FilesystemInputBundle([]) - make_file("foo.vy", "contents") + filepath = make_file("foo.vy", "contents") with pytest.raises(FileNotFoundError): # no search path given @@ -34,7 +36,7 @@ def test_search_path_context_manager(make_file, tmp_path): file = ib.load_file(Path("foo.vy")) assert isinstance(file, FileInput) - assert file == FileInput(0, tmp_path / Path("foo.vy"), "contents") + assert file == FileInput(0, Path("foo.vy"), filepath, "contents") def test_search_path_precedence(make_file, tmp_path, tmp_path_factory, input_bundle): @@ -43,59 +45,85 @@ def test_search_path_precedence(make_file, tmp_path, tmp_path_factory, input_bun tmpdir = tmp_path_factory.mktemp("some_directory") tmpdir2 = tmp_path_factory.mktemp("some_other_directory") + filepaths = [] for i, directory in enumerate([tmp_path, tmpdir, tmpdir2]): - with (directory / "foo.vy").open("w") as f: + path = directory / "foo.vy" + with path.open("w") as f: f.write(f"contents {i}") + filepaths.append(path) ib = FilesystemInputBundle([tmp_path, tmpdir, tmpdir2]) file = ib.load_file("foo.vy") assert isinstance(file, FileInput) - assert file == FileInput(0, tmpdir2 / "foo.vy", "contents 2") + assert file == FileInput(0, "foo.vy", filepaths[2], "contents 2") with ib.search_path(tmpdir): file = ib.load_file("foo.vy") assert isinstance(file, FileInput) - assert file == FileInput(1, tmpdir / "foo.vy", "contents 1") + assert file == FileInput(1, "foo.vy", filepaths[1], "contents 1") # special rules for handling json files def test_load_abi(make_file, input_bundle, tmp_path): contents = json.dumps("some string") - make_file("foo.json", contents) + path = make_file("foo.json", contents) file = input_bundle.load_file("foo.json") assert isinstance(file, ABIInput) - assert file == ABIInput(0, tmp_path / "foo.json", "some string") + assert file == ABIInput(0, "foo.json", path, "some string") # suffix doesn't matter - make_file("foo.txt", contents) - + path = make_file("foo.txt", contents) file = input_bundle.load_file("foo.txt") assert isinstance(file, ABIInput) - assert file == ABIInput(1, tmp_path / "foo.txt", "some string") + assert file == ABIInput(1, "foo.txt", path, "some string") + + +@contextlib.contextmanager +def working_directory(directory): + tmp = os.getcwd() + try: + os.chdir(directory) + yield + finally: + os.chdir(tmp) # check that unique paths give unique source ids def test_source_id_file_input(make_file, input_bundle, tmp_path): - make_file("foo.vy", "contents") - make_file("bar.vy", "contents 2") + foopath = make_file("foo.vy", "contents") + barpath = make_file("bar.vy", "contents 2") file = input_bundle.load_file("foo.vy") assert file.source_id == 0 - assert file == FileInput(0, tmp_path / "foo.vy", "contents") + assert file == FileInput(0, "foo.vy", foopath, "contents") file2 = input_bundle.load_file("bar.vy") # source id increments assert file2.source_id == 1 - assert file2 == FileInput(1, tmp_path / "bar.vy", "contents 2") + assert file2 == FileInput(1, "bar.vy", barpath, "contents 2") file3 = input_bundle.load_file("foo.vy") assert file3.source_id == 0 - assert file3 == FileInput(0, tmp_path / "foo.vy", "contents") + assert file3 == FileInput(0, "foo.vy", foopath, "contents") + + # test source id is stable across different search paths + with working_directory(tmp_path): + with input_bundle.search_path(Path(".")): + file4 = input_bundle.load_file("foo.vy") + assert file4.source_id == 0 + assert file4 == FileInput(0, "foo.vy", foopath, "contents") + + # test source id is stable even when requested filename is different + with working_directory(tmp_path.parent): + with input_bundle.search_path(Path(".")): + file5 = input_bundle.load_file(Path(tmp_path.stem) / "foo.vy") + assert file5.source_id == 0 + assert file5 == FileInput(0, Path(tmp_path.stem) / "foo.vy", foopath, "contents") # check that unique paths give unique source ids @@ -103,37 +131,51 @@ def test_source_id_json_input(make_file, input_bundle, tmp_path): contents = json.dumps("some string") contents2 = json.dumps(["some list"]) - make_file("foo.json", contents) + foopath = make_file("foo.json", contents) - make_file("bar.json", contents2) + barpath = make_file("bar.json", contents2) file = input_bundle.load_file("foo.json") assert isinstance(file, ABIInput) - assert file == ABIInput(0, tmp_path / "foo.json", "some string") + assert file == ABIInput(0, "foo.json", foopath, "some string") file2 = input_bundle.load_file("bar.json") assert isinstance(file2, ABIInput) - assert file2 == ABIInput(1, tmp_path / "bar.json", ["some list"]) + assert file2 == ABIInput(1, "bar.json", barpath, ["some list"]) file3 = input_bundle.load_file("foo.json") - assert isinstance(file3, ABIInput) - assert file3 == ABIInput(0, tmp_path / "foo.json", "some string") + assert file3.source_id == 0 + assert file3 == ABIInput(0, "foo.json", foopath, "some string") + + # test source id is stable across different search paths + with working_directory(tmp_path): + with input_bundle.search_path(Path(".")): + file4 = input_bundle.load_file("foo.json") + assert file4.source_id == 0 + assert file4 == ABIInput(0, "foo.json", foopath, "some string") + + # test source id is stable even when requested filename is different + with working_directory(tmp_path.parent): + with input_bundle.search_path(Path(".")): + file5 = input_bundle.load_file(Path(tmp_path.stem) / "foo.json") + assert file5.source_id == 0 + assert file5 == ABIInput(0, Path(tmp_path.stem) / "foo.json", foopath, "some string") # test some pathological case where the file changes underneath def test_mutating_file_source_id(make_file, input_bundle, tmp_path): - make_file("foo.vy", "contents") + foopath = make_file("foo.vy", "contents") file = input_bundle.load_file("foo.vy") assert file.source_id == 0 - assert file == FileInput(0, tmp_path / "foo.vy", "contents") + assert file == FileInput(0, "foo.vy", foopath, "contents") - make_file("foo.vy", "new contents") + foopath = make_file("foo.vy", "new contents") file = input_bundle.load_file("foo.vy") # source id hasn't changed, even though contents have assert file.source_id == 0 - assert file == FileInput(0, tmp_path / "foo.vy", "new contents") + assert file == FileInput(0, "foo.vy", foopath, "new contents") # test the os.normpath behavior of symlink @@ -147,10 +189,12 @@ def test_load_file_symlink(make_file, input_bundle, tmp_path, tmp_path_factory): dir2.mkdir() symlink.symlink_to(dir2, target_is_directory=True) - with (tmp_path / "foo.vy").open("w") as f: - f.write("contents of the upper directory") + outer_path = tmp_path / "foo.vy" + with outer_path.open("w") as f: + f.write("contents of the outer directory") - with (dir1 / "foo.vy").open("w") as f: + inner_path = dir1 / "foo.vy" + with inner_path.open("w") as f: f.write("contents of the inner directory") # symlink rules would be: @@ -159,9 +203,10 @@ def test_load_file_symlink(make_file, input_bundle, tmp_path, tmp_path_factory): # base/first/foo.vy # normpath would be base/symlink/../foo.vy => # base/foo.vy - file = input_bundle.load_file(symlink / ".." / "foo.vy") + to_load = symlink / ".." / "foo.vy" + file = input_bundle.load_file(to_load) - assert file == FileInput(0, tmp_path / "foo.vy", "contents of the upper directory") + assert file == FileInput(0, to_load, outer_path.resolve(), "contents of the outer directory") def test_json_input_bundle_basic(): @@ -169,40 +214,42 @@ def test_json_input_bundle_basic(): input_bundle = JSONInputBundle(files, [PurePath(".")]) file = input_bundle.load_file(PurePath("foo.vy")) - assert file == FileInput(0, PurePath("foo.vy"), "some text") + assert file == FileInput(0, PurePath("foo.vy"), PurePath("foo.vy"), "some text") def test_json_input_bundle_normpath(): - files = {PurePath("foo/../bar.vy"): {"content": "some text"}} + contents = "some text" + files = {PurePath("foo/../bar.vy"): {"content": contents}} input_bundle = JSONInputBundle(files, [PurePath(".")]) - expected = FileInput(0, PurePath("bar.vy"), "some text") + barpath = PurePath("bar.vy") + + expected = FileInput(0, barpath, barpath, contents) file = input_bundle.load_file(PurePath("bar.vy")) assert file == expected file = input_bundle.load_file(PurePath("baz/../bar.vy")) - assert file == expected + assert file == FileInput(0, PurePath("baz/../bar.vy"), barpath, contents) file = input_bundle.load_file(PurePath("./bar.vy")) - assert file == expected + assert file == FileInput(0, PurePath("./bar.vy"), barpath, contents) with input_bundle.search_path(PurePath("foo")): file = input_bundle.load_file(PurePath("../bar.vy")) - assert file == expected + assert file == FileInput(0, PurePath("../bar.vy"), barpath, contents) def test_json_input_abi(): some_abi = ["some abi"] some_abi_str = json.dumps(some_abi) - files = { - PurePath("foo.json"): {"abi": some_abi}, - PurePath("bar.txt"): {"content": some_abi_str}, - } + foopath = PurePath("foo.json") + barpath = PurePath("bar.txt") + files = {foopath: {"abi": some_abi}, barpath: {"content": some_abi_str}} input_bundle = JSONInputBundle(files, [PurePath(".")]) - file = input_bundle.load_file(PurePath("foo.json")) - assert file == ABIInput(0, PurePath("foo.json"), some_abi) + file = input_bundle.load_file(foopath) + assert file == ABIInput(0, foopath, foopath, some_abi) - file = input_bundle.load_file(PurePath("bar.txt")) - assert file == ABIInput(1, PurePath("bar.txt"), some_abi) + file = input_bundle.load_file(barpath) + assert file == ABIInput(1, barpath, barpath, some_abi) diff --git a/tests/unit/semantics/analysis/test_array_index.py b/tests/unit/semantics/analysis/test_array_index.py index 27c0634cf8..5ea373fc19 100644 --- a/tests/unit/semantics/analysis/test_array_index.py +++ b/tests/unit/semantics/analysis/test_array_index.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("value", ["address", "Bytes[10]", "decimal", "bool"]) -def test_type_mismatch(namespace, value): +def test_type_mismatch(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -23,11 +23,11 @@ def foo(b: {value}): """ vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["1.0", "0.0", "'foo'", "0x00", "b'\x01'", "False"]) -def test_invalid_literal(namespace, value): +def test_invalid_literal(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -38,11 +38,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(InvalidType): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", [-1, 3, -(2**127), 2**127 - 1, 2**256 - 1]) -def test_out_of_bounds(namespace, value): +def test_out_of_bounds(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -53,11 +53,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ArrayIndexException): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["b", "self.b"]) -def test_undeclared_definition(namespace, value): +def test_undeclared_definition(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -68,11 +68,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(UndeclaredDefinition): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["a", "foo", "int128"]) -def test_invalid_reference(namespace, value): +def test_invalid_reference(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -83,4 +83,4 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(InvalidReference): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) diff --git a/tests/unit/semantics/analysis/test_cyclic_function_calls.py b/tests/unit/semantics/analysis/test_cyclic_function_calls.py index 2a09bd5ed5..c31146b16f 100644 --- a/tests/unit/semantics/analysis/test_cyclic_function_calls.py +++ b/tests/unit/semantics/analysis/test_cyclic_function_calls.py @@ -3,22 +3,20 @@ from vyper.ast import parse_to_ast from vyper.exceptions import CallViolation, StructureException from vyper.semantics.analysis import validate_semantics -from vyper.semantics.analysis.module import ModuleAnalyzer -def test_self_function_call(namespace): +def test_self_function_call(dummy_input_bundle): code = """ @internal def foo(): self.foo() """ vyper_module = parse_to_ast(code) - with namespace.enter_scope(): - with pytest.raises(CallViolation): - ModuleAnalyzer(vyper_module, {}, namespace) + with pytest.raises(CallViolation): + validate_semantics(vyper_module, dummy_input_bundle) -def test_cyclic_function_call(namespace): +def test_cyclic_function_call(dummy_input_bundle): code = """ @internal def foo(): @@ -29,12 +27,11 @@ def bar(): self.foo() """ vyper_module = parse_to_ast(code) - with namespace.enter_scope(): - with pytest.raises(CallViolation): - ModuleAnalyzer(vyper_module, {}, namespace) + with pytest.raises(CallViolation): + validate_semantics(vyper_module, dummy_input_bundle) -def test_multi_cyclic_function_call(namespace): +def test_multi_cyclic_function_call(dummy_input_bundle): code = """ @internal def foo(): @@ -53,12 +50,11 @@ def potato(): self.foo() """ vyper_module = parse_to_ast(code) - with namespace.enter_scope(): - with pytest.raises(CallViolation): - ModuleAnalyzer(vyper_module, {}, namespace) + with pytest.raises(CallViolation): + validate_semantics(vyper_module, dummy_input_bundle) -def test_global_ann_assign_callable_no_crash(): +def test_global_ann_assign_callable_no_crash(dummy_input_bundle): code = """ balanceOf: public(HashMap[address, uint256]) @@ -68,5 +64,5 @@ def foo(to : address): """ vyper_module = parse_to_ast(code) with pytest.raises(StructureException) as excinfo: - validate_semantics(vyper_module, {}) - assert excinfo.value.message == "Value is not callable" + validate_semantics(vyper_module, dummy_input_bundle) + assert excinfo.value.message == "HashMap[address, uint256] is not callable" diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index 0d61a8f8f8..e2c0f555af 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -10,7 +10,7 @@ from vyper.semantics.analysis import validate_semantics -def test_modify_iterator_function_outside_loop(namespace): +def test_modify_iterator_function_outside_loop(dummy_input_bundle): code = """ a: uint256[3] @@ -26,10 +26,10 @@ def bar(): pass """ vyper_module = parse_to_ast(code) - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_pass_memory_var_to_other_function(namespace): +def test_pass_memory_var_to_other_function(dummy_input_bundle): code = """ @internal @@ -46,10 +46,10 @@ def bar(): self.foo(a) """ vyper_module = parse_to_ast(code) - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_modify_iterator(namespace): +def test_modify_iterator(dummy_input_bundle): code = """ a: uint256[3] @@ -61,10 +61,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_bad_keywords(namespace): +def test_bad_keywords(dummy_input_bundle): code = """ @internal @@ -75,10 +75,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(ArgumentException): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_bad_bound(namespace): +def test_bad_bound(dummy_input_bundle): code = """ @internal @@ -89,10 +89,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(StateAccessViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_modify_iterator_function_call(namespace): +def test_modify_iterator_function_call(dummy_input_bundle): code = """ a: uint256[3] @@ -108,10 +108,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_modify_iterator_recursive_function_call(namespace): +def test_modify_iterator_recursive_function_call(dummy_input_bundle): code = """ a: uint256[3] @@ -131,7 +131,7 @@ def baz(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) iterator_inference_codes = [ @@ -169,7 +169,7 @@ def foo(): @pytest.mark.parametrize("code", iterator_inference_codes) -def test_iterator_type_inference_checker(namespace, code): +def test_iterator_type_inference_checker(code, dummy_input_bundle): vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index d390fe9a39..002ee38cd2 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -110,6 +110,6 @@ def test_allocator_overflow(get_contract): """ with pytest.raises( StorageLayoutException, - match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}\n", + match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}", ): get_contract(code) diff --git a/tox.ini b/tox.ini index c949354dfe..f9d4c3b60b 100644 --- a/tox.ini +++ b/tox.ini @@ -53,4 +53,4 @@ commands = basepython = python3 extras = lint commands = - mypy --install-types --non-interactive --follow-imports=silent --ignore-missing-imports --disallow-incomplete-defs -p vyper + mypy --install-types --non-interactive --follow-imports=silent --ignore-missing-imports --implicit-optional -p vyper diff --git a/vyper/__init__.py b/vyper/__init__.py index 482d5c3a60..5bb6469757 100644 --- a/vyper/__init__.py +++ b/vyper/__init__.py @@ -1,6 +1,6 @@ from pathlib import Path as _Path -from vyper.compiler import compile_code # noqa: F401 +from vyper.compiler import compile_code, compile_from_file_input try: from importlib.metadata import PackageNotFoundError # type: ignore diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py index e5b81f1e7f..4b46801153 100644 --- a/vyper/ast/__init__.py +++ b/vyper/ast/__init__.py @@ -6,7 +6,8 @@ from . import nodes, validation from .natspec import parse_natspec from .nodes import compare_nodes -from .utils import ast_to_dict, parse_to_ast, parse_to_ast_with_settings +from .utils import ast_to_dict +from .parse import parse_to_ast, parse_to_ast_with_settings # adds vyper.ast.nodes classes into the local namespace for name, obj in ( diff --git a/vyper/ast/__init__.pyi b/vyper/ast/__init__.pyi index d349e804d6..eac8ffdef5 100644 --- a/vyper/ast/__init__.pyi +++ b/vyper/ast/__init__.pyi @@ -4,5 +4,5 @@ from typing import Any, Optional, Union from . import expansion, folding, nodes, validation from .natspec import parse_natspec as parse_natspec from .nodes import * +from .parse import parse_to_ast as parse_to_ast from .utils import ast_to_dict as ast_to_dict -from .utils import parse_to_ast as parse_to_ast diff --git a/vyper/ast/expansion.py b/vyper/ast/expansion.py index 5471b971a4..1536f39165 100644 --- a/vyper/ast/expansion.py +++ b/vyper/ast/expansion.py @@ -5,22 +5,9 @@ from vyper.semantics.types.function import ContractFunctionT -def expand_annotated_ast(vyper_module: vy_ast.Module) -> None: - """ - Perform expansion / simplification operations on an annotated Vyper AST. - - This pass uses annotated type information to modify the AST, simplifying - logic and expanding subtrees to reduce the compexity during codegen. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node that has been type-checked and annotated. - """ - generate_public_variable_getters(vyper_module) - remove_unused_statements(vyper_module) - - +# TODO: remove this function. it causes correctness/performance problems +# because of copying and mutating the AST - getter generation should be handled +# during code generation. def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: """ Create getter functions for public variables. @@ -32,7 +19,7 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: """ for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_public": True}): - func_type = node._metadata["func_type"] + func_type = node._metadata["getter_type"] input_types, return_type = node._metadata["type"].getter_signature input_nodes = [] @@ -86,31 +73,11 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: returns=return_node, ) - with vyper_module.namespace(): - func_type = ContractFunctionT.from_FunctionDef(expanded) - - expanded._metadata["type"] = func_type - return_node.set_parent(expanded) + # update pointers vyper_module.add_to_body(expanded) + return_node.set_parent(expanded) + with vyper_module.namespace(): + func_type = ContractFunctionT.from_FunctionDef(expanded) -def remove_unused_statements(vyper_module: vy_ast.Module) -> None: - """ - Remove statement nodes that are unused after type checking. - - Once type checking is complete, we can remove now-meaningless statements to - simplify the AST prior to IR generation. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - """ - - # constant declarations - values were substituted within the AST during folding - for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_constant": True}): - vyper_module.remove_from_body(node) - - # `implements: interface` statements - validated during type checking - for node in vyper_module.get_children(vy_ast.ImplementsDecl): - vyper_module.remove_from_body(node) + expanded._metadata["func_type"] = func_type diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index ca9979b2a3..15367ce94a 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -89,7 +89,8 @@ tuple_def: "(" ( NAME | array_def | dyn_array_def | tuple_def ) ( "," ( NAME | a // NOTE: Map takes a basic type and maps to another type (can be non-basic, including maps) _MAP: "HashMap" map_def: _MAP "[" ( NAME | array_def ) "," type "]" -type: ( NAME | array_def | tuple_def | map_def | dyn_array_def ) +imported_type: NAME "." NAME +type: ( NAME | imported_type | array_def | tuple_def | map_def | dyn_array_def ) // Structs can be composed of 1+ basic types or other custom_types _STRUCT_DECL: "struct" @@ -291,7 +292,7 @@ special_builtins: empty | abi_decode // Adapted from: https://docs.python.org/3/reference/grammar.html // Adapted by: Erez Shinan NAME: /[a-zA-Z_]\w*/ -COMMENT: /#[^\n]*/ +COMMENT: /#[^\n\r]*/ _NEWLINE: ( /\r?\n[\t ]*/ | COMMENT )+ @@ -312,8 +313,10 @@ _number: DEC_NUMBER BOOL.2: "True" | "False" +ELLIPSIS: "..." + // TODO: Remove Docstring from here, and add to first part of body -?literal: ( _number | STRING | DOCSTRING | BOOL ) +?literal: ( _number | STRING | DOCSTRING | BOOL | ELLIPSIS) %ignore /[\t \f]+/ // WS %ignore /\\[\t \f]*\r?\n/ // LINE_CONT diff --git a/vyper/ast/natspec.py b/vyper/ast/natspec.py index c25fc423f8..41905b178a 100644 --- a/vyper/ast/natspec.py +++ b/vyper/ast/natspec.py @@ -43,7 +43,7 @@ def parse_natspec(vyper_module_folded: vy_ast.Module) -> Tuple[dict, dict]: for node in [i for i in vyper_module_folded.body if i.get("doc_string.value")]: docstring = node.doc_string.value - func_type = node._metadata["type"] + func_type = node._metadata["func_type"] if func_type.visibility != FunctionVisibility.EXTERNAL: continue diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 69bd1fed53..3bccc5f141 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -589,7 +589,8 @@ def __contains__(self, obj): class Module(TopLevel): - __slots__ = () + # metadata + __slots__ = ("path", "resolved_path", "source_id") def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: """ @@ -897,12 +898,16 @@ def validate(self): raise InvalidLiteral("Cannot have an empty tuple", self) -class Dict(ExprNode): - __slots__ = ("keys", "values") +class NameConstant(Constant): + __slots__ = () -class NameConstant(Constant): - __slots__ = ("value",) +class Ellipsis(Constant): + __slots__ = () + + +class Dict(ExprNode): + __slots__ = ("keys", "values") class Name(ExprNode): @@ -1407,7 +1412,7 @@ class Pass(Stmt): __slots__ = () -class _Import(Stmt): +class _ImportStmt(Stmt): __slots__ = ("name", "alias") def __init__(self, *args, **kwargs): @@ -1419,11 +1424,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) -class Import(_Import): +class Import(_ImportStmt): __slots__ = () -class ImportFrom(_Import): +class ImportFrom(_ImportStmt): __slots__ = ("level", "module") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 47c9af8526..05784aed0f 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -2,9 +2,9 @@ import ast as python_ast from typing import Any, Optional, Sequence, Type, Union from .natspec import parse_natspec as parse_natspec +from .parse import parse_to_ast as parse_to_ast +from .parse import parse_to_ast_with_settings as parse_to_ast_with_settings from .utils import ast_to_dict as ast_to_dict -from .utils import parse_to_ast as parse_to_ast -from .utils import parse_to_ast_with_settings as parse_to_ast_with_settings NODE_BASE_ATTRIBUTES: Any NODE_SRC_ATTRIBUTES: Any @@ -59,6 +59,8 @@ class TopLevel(VyperNode): def __contains__(self, obj: Any) -> bool: ... class Module(TopLevel): + path: str = ... + resolved_path: str = ... def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: ... def add_to_body(self, node: VyperNode) -> None: ... def remove_from_body(self, node: VyperNode) -> None: ... @@ -121,6 +123,9 @@ class Bytes(Constant): @property def s(self): ... +class NameConstant(Constant): ... +class Ellipsis(Constant): ... + class List(VyperNode): elements: list = ... @@ -131,8 +136,6 @@ class Dict(VyperNode): keys: list = ... values: list = ... -class NameConstant(Constant): ... - class Name(VyperNode): id: str = ... _type: str = ... @@ -188,7 +191,7 @@ class NotIn(VyperNode): ... class Call(ExprNode): args: list = ... keywords: list = ... - func: Name = ... + func: VyperNode = ... class keyword(VyperNode): ... diff --git a/vyper/ast/annotation.py b/vyper/ast/parse.py similarity index 68% rename from vyper/ast/annotation.py rename to vyper/ast/parse.py index 9c7b1e063f..a2f2542179 100644 --- a/vyper/ast/annotation.py +++ b/vyper/ast/parse.py @@ -1,14 +1,114 @@ import ast as python_ast import tokenize from decimal import Decimal -from typing import Optional, cast +from typing import Any, Dict, List, Optional, Union, cast import asttokens -from vyper.exceptions import CompilerPanic, SyntaxException +from vyper.ast import nodes as vy_ast +from vyper.ast.pre_parser import pre_parse +from vyper.compiler.settings import Settings +from vyper.exceptions import CompilerPanic, ParserException, SyntaxException from vyper.typing import ModificationOffsets +def parse_to_ast(*args: Any, **kwargs: Any) -> vy_ast.Module: + _settings, ast = parse_to_ast_with_settings(*args, **kwargs) + return ast + + +def parse_to_ast_with_settings( + source_code: str, + source_id: int = 0, + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, + add_fn_node: Optional[str] = None, +) -> tuple[Settings, vy_ast.Module]: + """ + Parses a Vyper source string and generates basic Vyper AST nodes. + + Parameters + ---------- + source_code : str + The Vyper source code to parse. + source_id : int, optional + Source id to use in the `src` member of each node. + contract_name: str, optional + Name of contract. + add_fn_node: str, optional + If not None, adds a dummy Python AST FunctionDef wrapper node. + source_id: int, optional + The source ID generated for this source code. + Corresponds to FileInput.source_id + module_path: str, optional + The path of the source code + Corresponds to FileInput.path + resolved_path: str, optional + The resolved path of the source code + Corresponds to FileInput.resolved_path + + Returns + ------- + list + Untyped, unoptimized Vyper AST nodes. + """ + if "\x00" in source_code: + raise ParserException("No null bytes (\\x00) allowed in the source code.") + settings, class_types, reformatted_code = pre_parse(source_code) + try: + py_ast = python_ast.parse(reformatted_code) + except SyntaxError as e: + # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors + raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e + + # Add dummy function node to ensure local variables are treated as `AnnAssign` + # instead of state variables (`VariableDecl`) + if add_fn_node: + fn_node = python_ast.FunctionDef(add_fn_node, py_ast.body, [], []) + fn_node.body = py_ast.body + fn_node.args = python_ast.arguments(defaults=[]) + py_ast.body = [fn_node] + + annotate_python_ast( + py_ast, + source_code, + class_types, + source_id, + module_path=module_path, + resolved_path=resolved_path, + ) + + # Convert to Vyper AST. + module = vy_ast.get_node(py_ast) + assert isinstance(module, vy_ast.Module) # mypy hint + return settings, module + + +def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: + """ + Converts a Vyper AST node, or list of nodes, into a dictionary suitable for + output to the user. + """ + if isinstance(ast_struct, vy_ast.VyperNode): + return ast_struct.to_dict() + + if isinstance(ast_struct, list): + return [i.to_dict() for i in ast_struct] + + raise CompilerPanic(f'Unknown Vyper AST node provided: "{type(ast_struct)}".') + + +def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: + """ + Converts an AST dict, or list of dicts, into Vyper AST node objects. + """ + if isinstance(ast_struct, dict): + return vy_ast.get_node(ast_struct) + if isinstance(ast_struct, list): + return [vy_ast.get_node(i) for i in ast_struct] + raise CompilerPanic(f'Unknown ast_struct provided: "{type(ast_struct)}".') + + class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _modification_offsets: ModificationOffsets @@ -19,11 +119,13 @@ def __init__( modification_offsets: Optional[ModificationOffsets], tokens: asttokens.ASTTokens, source_id: int, - contract_name: Optional[str], + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, ): self._tokens = tokens self._source_id = source_id - self._contract_name = contract_name + self._module_path = module_path + self._resolved_path = resolved_path self._source_code: str = source_code self.counter: int = 0 self._modification_offsets = {} @@ -83,7 +185,9 @@ def _visit_docstring(self, node): return node def visit_Module(self, node): - node.name = self._contract_name + node.path = self._module_path + node.resolved_path = self._resolved_path + node.source_id = self._source_id return self._visit_docstring(node) def visit_FunctionDef(self, node): @@ -163,6 +267,8 @@ def visit_Constant(self, node): node.ast_type = "Str" elif isinstance(node.value, bytes): node.ast_type = "Bytes" + elif isinstance(node.value, Ellipsis.__class__): + node.ast_type = "Ellipsis" else: raise SyntaxException( "Invalid syntax (unsupported Python Constant AST node).", @@ -250,7 +356,8 @@ def annotate_python_ast( source_code: str, modification_offsets: Optional[ModificationOffsets] = None, source_id: int = 0, - contract_name: Optional[str] = None, + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, ) -> python_ast.AST: """ Annotate and optimize a Python AST in preparation conversion to a Vyper AST. @@ -270,7 +377,14 @@ def annotate_python_ast( """ tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) - visitor = AnnotatingVisitor(source_code, modification_offsets, tokens, source_id, contract_name) + visitor = AnnotatingVisitor( + source_code, + modification_offsets, + tokens, + source_id, + module_path=module_path, + resolved_path=resolved_path, + ) visitor.visit(parsed_ast) return parsed_ast diff --git a/vyper/ast/utils.py b/vyper/ast/utils.py index 4e669385ab..4c2e5394c9 100644 --- a/vyper/ast/utils.py +++ b/vyper/ast/utils.py @@ -1,64 +1,7 @@ -import ast as python_ast -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Union from vyper.ast import nodes as vy_ast -from vyper.ast.annotation import annotate_python_ast -from vyper.ast.pre_parser import pre_parse -from vyper.compiler.settings import Settings -from vyper.exceptions import CompilerPanic, ParserException, SyntaxException - - -def parse_to_ast(*args: Any, **kwargs: Any) -> vy_ast.Module: - return parse_to_ast_with_settings(*args, **kwargs)[1] - - -def parse_to_ast_with_settings( - source_code: str, - source_id: int = 0, - contract_name: Optional[str] = None, - add_fn_node: Optional[str] = None, -) -> tuple[Settings, vy_ast.Module]: - """ - Parses a Vyper source string and generates basic Vyper AST nodes. - - Parameters - ---------- - source_code : str - The Vyper source code to parse. - source_id : int, optional - Source id to use in the `src` member of each node. - contract_name: str, optional - Name of contract. - add_fn_node: str, optional - If not None, adds a dummy Python AST FunctionDef wrapper node. - - Returns - ------- - list - Untyped, unoptimized Vyper AST nodes. - """ - if "\x00" in source_code: - raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, class_types, reformatted_code = pre_parse(source_code) - try: - py_ast = python_ast.parse(reformatted_code) - except SyntaxError as e: - # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors - raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e - - # Add dummy function node to ensure local variables are treated as `AnnAssign` - # instead of state variables (`VariableDecl`) - if add_fn_node: - fn_node = python_ast.FunctionDef(add_fn_node, py_ast.body, [], []) - fn_node.body = py_ast.body - fn_node.args = python_ast.arguments(defaults=[]) - py_ast.body = [fn_node] - annotate_python_ast(py_ast, source_code, class_types, source_id, contract_name) - - # Convert to Vyper AST. - module = vy_ast.get_node(py_ast) - assert isinstance(module, vy_ast.Module) # mypy hint - return settings, module +from vyper.exceptions import CompilerPanic def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: diff --git a/vyper/builtins/_utils.py b/vyper/builtins/_utils.py index afc0987b6d..72b05f15e3 100644 --- a/vyper/builtins/_utils.py +++ b/vyper/builtins/_utils.py @@ -1,10 +1,10 @@ from vyper.ast import parse_to_ast from vyper.codegen.context import Context -from vyper.codegen.global_context import GlobalContext from vyper.codegen.stmt import parse_body from vyper.semantics.analysis.local import FunctionNodeVisitor from vyper.semantics.namespace import Namespace, override_global_namespace from vyper.semantics.types.function import ContractFunctionT, FunctionVisibility, StateMutability +from vyper.semantics.types.module import ModuleT def _strip_source_pos(ir_node): @@ -22,15 +22,16 @@ def generate_inline_function(code, variables, variables_2, memory_allocator): # Initialise a placeholder `FunctionDef` AST node and corresponding # `ContractFunctionT` type to rely on the annotation visitors in semantics # module. - ast_code.body[0]._metadata["type"] = ContractFunctionT( + ast_code.body[0]._metadata["func_type"] = ContractFunctionT( "sqrt_builtin", [], [], None, FunctionVisibility.INTERNAL, StateMutability.NONPAYABLE ) # The FunctionNodeVisitor's constructor performs semantic checks # annotate the AST as side effects. - FunctionNodeVisitor(ast_code, ast_code.body[0], namespace) + analyzer = FunctionNodeVisitor(ast_code, ast_code.body[0], namespace) + analyzer.analyze() new_context = Context( - vars_=variables, global_ctx=GlobalContext(), memory_allocator=memory_allocator + vars_=variables, module_ctx=ModuleT(ast_code), memory_allocator=memory_allocator ) generated_ir = parse_body(ast_code.body[0].body, new_context) # strip source position info from the generated_ir since diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 22931508a6..d50a31767d 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2499,9 +2499,9 @@ def infer_arg_types(self, node): validate_call_args(node, 2, ["unwrap_tuple"]) data_type = get_exact_type_from_node(node.args[0]) - output_typedef = TYPE_T(type_from_annotation(node.args[1])) + output_type = type_from_annotation(node.args[1]) - return [data_type, output_typedef] + return [data_type, TYPE_T(output_type)] @process_inputs def build_IR(self, expr, args, kwargs, context): diff --git a/vyper/builtins/interfaces/ERC165.vy b/vyper/builtins/interfaces/ERC165.vyi similarity index 88% rename from vyper/builtins/interfaces/ERC165.vy rename to vyper/builtins/interfaces/ERC165.vyi index a4ca451abd..441130f77c 100644 --- a/vyper/builtins/interfaces/ERC165.vy +++ b/vyper/builtins/interfaces/ERC165.vyi @@ -1,4 +1,4 @@ @view @external def supportsInterface(interface_id: bytes4) -> bool: - pass + ... diff --git a/vyper/builtins/interfaces/ERC20.vy b/vyper/builtins/interfaces/ERC20.vyi similarity index 68% rename from vyper/builtins/interfaces/ERC20.vy rename to vyper/builtins/interfaces/ERC20.vyi index 065ca97a9b..ee533ab326 100644 --- a/vyper/builtins/interfaces/ERC20.vy +++ b/vyper/builtins/interfaces/ERC20.vyi @@ -1,38 +1,38 @@ # Events event Transfer: - _from: indexed(address) - _to: indexed(address) - _value: uint256 + sender: indexed(address) + recipient: indexed(address) + value: uint256 event Approval: - _owner: indexed(address) - _spender: indexed(address) - _value: uint256 + owner: indexed(address) + spender: indexed(address) + value: uint256 # Functions @view @external def totalSupply() -> uint256: - pass + ... @view @external def balanceOf(_owner: address) -> uint256: - pass + ... @view @external def allowance(_owner: address, _spender: address) -> uint256: - pass + ... @external def transfer(_to: address, _value: uint256) -> bool: - pass + ... @external def transferFrom(_from: address, _to: address, _value: uint256) -> bool: - pass + ... @external def approve(_spender: address, _value: uint256) -> bool: - pass + ... diff --git a/vyper/builtins/interfaces/ERC20Detailed.vy b/vyper/builtins/interfaces/ERC20Detailed.vyi similarity index 93% rename from vyper/builtins/interfaces/ERC20Detailed.vy rename to vyper/builtins/interfaces/ERC20Detailed.vyi index 7c4f546d45..0be1c6f153 100644 --- a/vyper/builtins/interfaces/ERC20Detailed.vy +++ b/vyper/builtins/interfaces/ERC20Detailed.vyi @@ -5,14 +5,14 @@ @view @external def name() -> String[1]: - pass + ... @view @external def symbol() -> String[1]: - pass + ... @view @external def decimals() -> uint8: - pass + ... diff --git a/vyper/builtins/interfaces/ERC4626.vy b/vyper/builtins/interfaces/ERC4626.vyi similarity index 90% rename from vyper/builtins/interfaces/ERC4626.vy rename to vyper/builtins/interfaces/ERC4626.vyi index 05865406cf..6d9e4c6ef7 100644 --- a/vyper/builtins/interfaces/ERC4626.vy +++ b/vyper/builtins/interfaces/ERC4626.vyi @@ -16,75 +16,75 @@ event Withdraw: @view @external def asset() -> address: - pass + ... @view @external def totalAssets() -> uint256: - pass + ... @view @external def convertToShares(assetAmount: uint256) -> uint256: - pass + ... @view @external def convertToAssets(shareAmount: uint256) -> uint256: - pass + ... @view @external def maxDeposit(owner: address) -> uint256: - pass + ... @view @external def previewDeposit(assets: uint256) -> uint256: - pass + ... @external def deposit(assets: uint256, receiver: address=msg.sender) -> uint256: - pass + ... @view @external def maxMint(owner: address) -> uint256: - pass + ... @view @external def previewMint(shares: uint256) -> uint256: - pass + ... @external def mint(shares: uint256, receiver: address=msg.sender) -> uint256: - pass + ... @view @external def maxWithdraw(owner: address) -> uint256: - pass + ... @view @external def previewWithdraw(assets: uint256) -> uint256: - pass + ... @external def withdraw(assets: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: - pass + ... @view @external def maxRedeem(owner: address) -> uint256: - pass + ... @view @external def previewRedeem(shares: uint256) -> uint256: - pass + ... @external def redeem(shares: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: - pass + ... diff --git a/vyper/builtins/interfaces/ERC721.vy b/vyper/builtins/interfaces/ERC721.vyi similarity index 61% rename from vyper/builtins/interfaces/ERC721.vy rename to vyper/builtins/interfaces/ERC721.vyi index 464c0e255b..b8dcfd3c5f 100644 --- a/vyper/builtins/interfaces/ERC721.vy +++ b/vyper/builtins/interfaces/ERC721.vyi @@ -1,67 +1,62 @@ # Events event Transfer: - _from: indexed(address) - _to: indexed(address) - _tokenId: indexed(uint256) + sender: indexed(address) + recipient: indexed(address) + token_id: indexed(uint256) event Approval: - _owner: indexed(address) - _approved: indexed(address) - _tokenId: indexed(uint256) + owner: indexed(address) + approved: indexed(address) + token_id: indexed(uint256) event ApprovalForAll: - _owner: indexed(address) - _operator: indexed(address) - _approved: bool + owner: indexed(address) + operator: indexed(address) + approved: bool # Functions @view @external def supportsInterface(interface_id: bytes4) -> bool: - pass + ... @view @external def balanceOf(_owner: address) -> uint256: - pass + ... @view @external def ownerOf(_tokenId: uint256) -> address: - pass + ... @view @external def getApproved(_tokenId: uint256) -> address: - pass + ... @view @external def isApprovedForAll(_owner: address, _operator: address) -> bool: - pass + ... @external @payable def transferFrom(_from: address, _to: address, _tokenId: uint256): - pass + ... @external @payable -def safeTransferFrom(_from: address, _to: address, _tokenId: uint256): - pass - -@external -@payable -def safeTransferFrom(_from: address, _to: address, _tokenId: uint256, _data: Bytes[1024]): - pass +def safeTransferFrom(_from: address, _to: address, _tokenId: uint256, _data: Bytes[1024] = b""): + ... @external @payable def approve(_approved: address, _tokenId: uint256): - pass + ... @external def setApprovalForAll(_operator: address, _approved: bool): - pass + ... diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index ca1792384e..4f88812fa0 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -271,10 +271,8 @@ def compile_files( with open(storage_file_path) as sfh: storage_layout_override = json.load(sfh) - output = vyper.compile_code( - file.source_code, - contract_name=str(file.path), - source_id=file.source_id, + output = vyper.compile_from_file_input( + file, input_bundle=input_bundle, output_formats=final_formats, exc_handler=exc_handler, diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index 2720f20d23..63da2e0643 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -12,7 +12,7 @@ from vyper.compiler.settings import OptimizationLevel, Settings from vyper.evm.opcodes import EVM_VERSIONS from vyper.exceptions import JSONError -from vyper.utils import keccak256 +from vyper.utils import OrderedSet, keccak256 TRANSLATE_MAP = { "abi": "abi", @@ -151,13 +151,6 @@ def get_evm_version(input_dict: dict) -> Optional[str]: return evm_version -def get_compilation_targets(input_dict: dict) -> list[PurePath]: - # TODO: once we have modules, add optional "compilation_targets" key - # which specifies which sources we actually want to compile. - - return [PurePath(p) for p in input_dict["sources"].keys()] - - def get_inputs(input_dict: dict) -> dict[PurePath, Any]: ret = {} seen = {} @@ -218,14 +211,14 @@ def get_inputs(input_dict: dict) -> dict[PurePath, Any]: # get unique output formats for each contract, given the input_dict # NOTE: would maybe be nice to raise on duplicated output formats -def get_output_formats(input_dict: dict, targets: list[PurePath]) -> dict[PurePath, list[str]]: +def get_output_formats(input_dict: dict) -> dict[PurePath, list[str]]: output_formats: dict[PurePath, list[str]] = {} for path, outputs in input_dict["settings"]["outputSelection"].items(): if isinstance(outputs, dict): # if outputs are given in solc json format, collapse them into a single list - outputs = set(x for i in outputs.values() for x in i) + outputs = OrderedSet(x for i in outputs.values() for x in i) else: - outputs = set(outputs) + outputs = OrderedSet(outputs) for key in [i for i in ("evm", "evm.bytecode", "evm.deployedBytecode") if i in outputs]: outputs.remove(key) @@ -239,13 +232,13 @@ def get_output_formats(input_dict: dict, targets: list[PurePath]) -> dict[PurePa except KeyError as e: raise JSONError(f"Invalid outputSelection - {e}") - outputs = sorted(set(outputs)) + outputs = sorted(list(outputs)) if path == "*": - output_paths = targets + output_paths = [PurePath(path) for path in input_dict["sources"].keys()] else: output_paths = [PurePath(path)] - if output_paths[0] not in targets: + if str(output_paths[0]) not in input_dict["sources"]: raise JSONError(f"outputSelection references unknown contract '{output_paths[0]}'") for output_path in output_paths: @@ -281,9 +274,9 @@ def compile_from_input_dict( no_bytecode_metadata = not input_dict["settings"].get("bytecodeMetadata", True) - compilation_targets = get_compilation_targets(input_dict) sources = get_inputs(input_dict) - output_formats = get_output_formats(input_dict, compilation_targets) + output_formats = get_output_formats(input_dict) + compilation_targets = list(output_formats.keys()) input_bundle = JSONInputBundle(sources, search_paths=[Path(root_folder)]) @@ -295,12 +288,10 @@ def compile_from_input_dict( # use load_file to get a unique source_id file = input_bundle.load_file(contract_path) assert isinstance(file, FileInput) # mypy hint - data = vyper.compile_code( - file.source_code, - contract_name=str(file.path), + data = vyper.compile_from_file_input( + file, input_bundle=input_bundle, output_formats=output_formats[contract_path], - source_id=file.source_id, settings=settings, no_bytecode_metadata=no_bytecode_metadata, ) diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 5b79f293bd..dea30faabc 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -48,7 +48,7 @@ def __repr__(self): class Context: def __init__( self, - global_ctx, + module_ctx, memory_allocator, vars_=None, forvars=None, @@ -60,7 +60,7 @@ def __init__( self.vars = vars_ or {} # Global variables, in the form (name, storage location, type) - self.globals = global_ctx.variables + self.globals = module_ctx.variables # Variables defined in for loops, e.g. for i in range(6): ... self.forvars = forvars or {} @@ -75,8 +75,8 @@ def __init__( # Whether we are currently parsing a range expression self.in_range_expr = False - # store global context - self.global_ctx = global_ctx + # store module context + self.module_ctx = module_ctx # full function type self.func_t = func_t diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index dc0e98786f..5870e64e98 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -47,8 +47,10 @@ StringT, StructT, TupleT, + is_type_t, ) from vyper.semantics.types.bytestrings import _BytestringT +from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T from vyper.utils import ( DECIMAL_DIVISOR, @@ -79,7 +81,7 @@ def __init__(self, node, context): self.ir_node = fn() if self.ir_node is None: - raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.", node) + raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.\n", node) self.ir_node.annotation = self.expr.get("node_source_code") self.ir_node.source_pos = getpos(self.expr) @@ -662,39 +664,38 @@ def parse_Call(self): if function_name in DISPATCH_TABLE: return DISPATCH_TABLE[function_name].build_IR(self.expr, self.context) - # Struct constructors do not need `self` prefix. - elif isinstance(self.expr._metadata["type"], StructT): - args = self.expr.args - if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return Expr.struct_literals(args[0], self.context, self.expr._metadata["type"]) + func_type = self.expr.func._metadata["type"] - # Interface assignment. Bar(
). - elif isinstance(self.expr._metadata["type"], InterfaceT): - (arg0,) = self.expr.args - arg_ir = Expr(arg0, self.context).ir_node + # Struct constructor + if is_type_t(func_type, StructT): + args = self.expr.args + if len(args) == 1 and isinstance(args[0], vy_ast.Dict): + return Expr.struct_literals(args[0], self.context, self.expr._metadata["type"]) - assert arg_ir.typ == AddressT() - arg_ir.typ = self.expr._metadata["type"] + # Interface constructor. Bar(
). + if is_type_t(func_type, InterfaceT): + (arg0,) = self.expr.args + arg_ir = Expr(arg0, self.context).ir_node - return arg_ir + assert arg_ir.typ == AddressT() + arg_ir.typ = self.expr._metadata["type"] - elif isinstance(self.expr.func, vy_ast.Attribute) and self.expr.func.attr == "pop": + return arg_ir + + if isinstance(func_type, MemberFunctionT) and self.expr.func.attr == "pop": # TODO consider moving this to builtins darray = Expr(self.expr.func.value, self.context).ir_node assert len(self.expr.args) == 0 assert isinstance(darray.typ, DArrayT) return pop_dyn_array(darray, return_popped_item=True) - elif ( - # TODO use expr.func.type.is_internal once - # type annotations are consistently available - isinstance(self.expr.func, vy_ast.Attribute) - and isinstance(self.expr.func.value, vy_ast.Name) - and self.expr.func.value.id == "self" - ): - return self_call.ir_for_self_call(self.expr, self.context) - else: - return external_call.ir_for_external_call(self.expr, self.context) + if isinstance(func_type, ContractFunctionT): + if func_type.is_internal: + return self_call.ir_for_self_call(self.expr, self.context) + else: + return external_call.ir_for_external_call(self.expr, self.context) + + raise CompilerPanic("unreachable", self.expr) def parse_List(self): typ = self.expr._metadata["type"] diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index c48f1256c3..454ba9c8cd 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -7,13 +7,13 @@ from vyper.codegen.core import check_single_exit from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function -from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode from vyper.codegen.memory_allocator import MemoryAllocator from vyper.exceptions import CompilerPanic from vyper.semantics.types import VyperType from vyper.semantics.types.function import ContractFunctionT -from vyper.utils import MemoryPositions, calc_mem_gas, mkalphanum +from vyper.semantics.types.module import ModuleT +from vyper.utils import MemoryPositions, calc_mem_gas @dataclass @@ -44,7 +44,14 @@ def exit_sequence_label(self) -> str: @cached_property def ir_identifier(self) -> str: argz = ",".join([str(argtyp) for argtyp in self.func_t.argument_types]) - return mkalphanum(f"{self.visibility} {self.func_t.name} ({argz})") + + name = self.func_t.name + function_id = self.func_t._function_id + assert function_id is not None + + # include module id in the ir identifier to disambiguate functions + # with the same name but which come from different modules + return f"{self.visibility} {function_id} {name}({argz})" def set_frame_info(self, frame_info: FrameInfo) -> None: if self.frame_info is not None: @@ -94,7 +101,7 @@ class InternalFuncIR(FuncIR): # TODO: should split this into external and internal ir generation? def generate_ir_for_function( - code: vy_ast.FunctionDef, global_ctx: GlobalContext, is_ctor_context: bool = False + code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False ) -> FuncIR: """ Parse a function and produce IR code for the function, includes: @@ -103,7 +110,7 @@ def generate_ir_for_function( - Clamping and copying of arguments - Function body """ - func_t = code._metadata["type"] + func_t = code._metadata["func_type"] # generate _FuncIRInfo func_t._ir_info = _FuncIRInfo(func_t) @@ -126,7 +133,7 @@ def generate_ir_for_function( context = Context( vars_=None, - global_ctx=global_ctx, + module_ctx=module_ctx, memory_allocator=memory_allocator, constancy=Constancy.Mutable if func_t.is_mutable else Constancy.Constant, func_t=func_t, diff --git a/vyper/codegen/global_context.py b/vyper/codegen/global_context.py deleted file mode 100644 index 1f6783f6f8..0000000000 --- a/vyper/codegen/global_context.py +++ /dev/null @@ -1,32 +0,0 @@ -from functools import cached_property -from typing import Optional - -from vyper import ast as vy_ast - - -# Datatype to store all global context information. -# TODO: rename me to ModuleT -class GlobalContext: - def __init__(self, module: Optional[vy_ast.Module] = None): - self._module = module - - @cached_property - def functions(self): - return self._module.get_children(vy_ast.FunctionDef) - - @cached_property - def variables(self): - # variables that this module defines, ex. - # `x: uint256` is a private storage variable named x - if self._module is None: # TODO: make self._module never be None - return None - variable_decls = self._module.get_children(vy_ast.VariableDecl) - return {s.target.id: s.target._metadata["varinfo"] for s in variable_decls} - - @property - def immutables(self): - return [t for t in self.variables.values() if t.is_immutable] - - @cached_property - def immutable_section_bytes(self): - return sum([imm.typ.memory_bytes_required for imm in self.immutables]) diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index bfdafa8ba9..ef861e3953 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -5,49 +5,67 @@ from vyper.codegen import core, jumptable_utils from vyper.codegen.core import shr from vyper.codegen.function_definitions import generate_ir_for_function -from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import _is_debug_mode from vyper.exceptions import CompilerPanic -from vyper.utils import method_id_int +from vyper.semantics.types.module import ModuleT +from vyper.utils import OrderedSet, method_id_int -def _topsort_helper(functions, lookup): - # single pass to get a global topological sort of functions (so that each - # function comes after each of its callees). may have duplicates, which get - # filtered out in _topsort() +def _topsort(functions): + # single pass to get a global topological sort of functions (so that each + # function comes after each of its callees). + ret = OrderedSet() + for func_ast in functions: + fn_t = func_ast._metadata["func_type"] + + for reachable_t in fn_t.reachable_internal_functions: + assert reachable_t.ast_def is not None + ret.add(reachable_t.ast_def) + + ret.add(func_ast) + + # create globally unique IDs for each function + for idx, func in enumerate(ret): + func._metadata["func_type"]._function_id = idx + + return list(ret) + - ret = [] +# calculate globally reachable functions to see which +# ones should make it into the final bytecode. +# TODO: in the future, this should get obsolesced by IR dead code eliminator. +def _globally_reachable_functions(functions): + ret = OrderedSet() for f in functions: - # called_functions is a list of ContractFunctions, need to map - # back to FunctionDefs. - callees = [lookup[t.name] for t in f._metadata["type"].called_functions] - ret.extend(_topsort_helper(callees, lookup)) - ret.append(f) + fn_t = f._metadata["func_type"] - return ret + if not fn_t.is_external: + continue + for reachable_t in fn_t.reachable_internal_functions: + assert reachable_t.ast_def is not None + ret.add(reachable_t) -def _topsort(functions): - lookup = {f.name: f for f in functions} - # strip duplicates - return list(dict.fromkeys(_topsort_helper(functions, lookup))) + ret.add(fn_t) + + return ret def _is_constructor(func_ast): - return func_ast._metadata["type"].is_constructor + return func_ast._metadata["func_type"].is_constructor def _is_fallback(func_ast): - return func_ast._metadata["type"].is_fallback + return func_ast._metadata["func_type"].is_fallback def _is_internal(func_ast): - return func_ast._metadata["type"].is_internal + return func_ast._metadata["func_type"].is_internal def _is_payable(func_ast): - return func_ast._metadata["type"].is_payable + return func_ast._metadata["func_type"].is_payable def _annotated_method_id(abi_sig): @@ -63,7 +81,7 @@ def label_for_entry_point(abi_sig, entry_point): # adapt whatever generate_ir_for_function gives us into an IR node def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): - func_t = func_ast._metadata["type"] + func_t = func_ast._metadata["func_type"] assert func_t.is_fallback or func_t.is_constructor ret = ["seq"] @@ -86,12 +104,12 @@ def _ir_for_internal_function(func_ast, *args, **kwargs): return generate_ir_for_function(func_ast, *args, **kwargs).func_ir -def _generate_external_entry_points(external_functions, global_ctx): +def _generate_external_entry_points(external_functions, module_ctx): entry_points = {} # map from ABI sigs to ir code sig_of = {} # reverse map from method ids to abi sig for code in external_functions: - func_ir = generate_ir_for_function(code, global_ctx) + func_ir = generate_ir_for_function(code, module_ctx) for abi_sig, entry_point in func_ir.entry_points.items(): method_id = method_id_int(abi_sig) assert abi_sig not in entry_points @@ -113,13 +131,13 @@ def _generate_external_entry_points(external_functions, global_ctx): # into a bucket (of about 8-10 items), and then uses perfect hash # to select the final function. # costs about 212 gas for typical function and 8 bytes of code (+ ~87 bytes of global overhead) -def _selector_section_dense(external_functions, global_ctx): +def _selector_section_dense(external_functions, module_ctx): function_irs = [] if len(external_functions) == 0: return IRnode.from_list(["seq"]) - entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) # generate the label so the jumptable works for abi_sig, entry_point in entry_points.items(): @@ -264,13 +282,13 @@ def _selector_section_dense(external_functions, global_ctx): # a bucket, and then descends into linear search from there. # costs about 126 gas for typical (nonpayable, >0 args, avg bucket size 1.5) # function and 24 bytes of code (+ ~23 bytes of global overhead) -def _selector_section_sparse(external_functions, global_ctx): +def _selector_section_sparse(external_functions, module_ctx): ret = ["seq"] if len(external_functions) == 0: return ret - entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) n_buckets, buckets = jumptable_utils.generate_sparse_jumptable_buckets(entry_points.keys()) @@ -367,14 +385,14 @@ def _selector_section_sparse(external_functions, global_ctx): # O(n) linear search for the method id # mainly keep this in for backends which cannot handle the indirect jump # in selector_section_dense and selector_section_sparse -def _selector_section_linear(external_functions, global_ctx): +def _selector_section_linear(external_functions, module_ctx): ret = ["seq"] if len(external_functions) == 0: return ret ret.append(["if", ["lt", "calldatasize", 4], ["goto", "fallback"]]) - entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) dispatcher = ["seq"] @@ -402,10 +420,11 @@ def _selector_section_linear(external_functions, global_ctx): return ret -# take a GlobalContext, and generate the runtime and deploy IR -def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: +# take a ModuleT, and generate the runtime and deploy IR +def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: # order functions so that each function comes after all of its callees - function_defs = _topsort(global_ctx.functions) + function_defs = _topsort(module_ctx.function_defs) + reachable = _globally_reachable_functions(module_ctx.function_defs) runtime_functions = [f for f in function_defs if not _is_constructor(f)] init_function = next((f for f in function_defs if _is_constructor(f)), None) @@ -421,20 +440,26 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: # compile internal functions first so we have the function info for func_ast in internal_functions: - func_ir = _ir_for_internal_function(func_ast, global_ctx, False) - internal_functions_ir.append(IRnode.from_list(func_ir)) + # compile it so that _ir_info is populated (whether or not it makes + # it into the final IR artifact) + func_ir = _ir_for_internal_function(func_ast, module_ctx, False) + + # only include it in the IR if it is reachable from an external + # function. + if func_ast._metadata["func_type"] in reachable: + internal_functions_ir.append(IRnode.from_list(func_ir)) if core._opt_none(): - selector_section = _selector_section_linear(external_functions, global_ctx) + selector_section = _selector_section_linear(external_functions, module_ctx) # dense vs sparse global overhead is amortized after about 4 methods. # (--debug will force dense selector table anyway if _opt_codesize is selected.) elif core._opt_codesize() and (len(external_functions) > 4 or _is_debug_mode()): - selector_section = _selector_section_dense(external_functions, global_ctx) + selector_section = _selector_section_dense(external_functions, module_ctx) else: - selector_section = _selector_section_sparse(external_functions, global_ctx) + selector_section = _selector_section_sparse(external_functions, module_ctx) if default_function: - fallback_ir = _ir_for_fallback_or_ctor(default_function, global_ctx) + fallback_ir = _ir_for_fallback_or_ctor(default_function, module_ctx) else: fallback_ir = IRnode.from_list( ["revert", 0, 0], annotation="Default function", error_msg="fallback function" @@ -447,29 +472,30 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: runtime.extend(internal_functions_ir) deploy_code: List[Any] = ["seq"] - immutables_len = global_ctx.immutable_section_bytes + immutables_len = module_ctx.immutable_section_bytes if init_function: # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` + init_func_t = init_function._metadata["func_type"] ctor_internal_func_irs = [] internal_functions = [f for f in runtime_functions if _is_internal(f)] for f in internal_functions: - init_func_t = init_function._metadata["type"] - if f.name not in init_func_t.recursive_calls: + func_t = f._metadata["func_type"] + if func_t not in init_func_t.reachable_internal_functions: # unreachable code, delete it continue - func_ir = _ir_for_internal_function(f, global_ctx, is_ctor_context=True) + func_ir = _ir_for_internal_function(f, module_ctx, is_ctor_context=True) ctor_internal_func_irs.append(func_ir) # generate init_func_ir after callees to ensure they have analyzed # memory usage. # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_fallback_or_ctor(init_function, global_ctx, is_ctor_context=True) + init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx, is_ctor_context=True) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables # note: (deploy mem_ofst, code, extra_padding) - init_mem_used = init_function._metadata["type"]._ir_info.frame_info.mem_used + init_mem_used = init_function._metadata["func_type"]._ir_info.frame_info.mem_used # force msize to be initialized past the end of immutables section # so that builtins which use `msize` for "dynamic" memory diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index f03f2eb9c8..f53e4a81b4 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -4,15 +4,6 @@ from vyper.exceptions import StateAccessViolation from vyper.semantics.types.subscriptable import TupleT -_label_counter = 0 - - -# TODO a more general way of doing this -def _generate_label(name: str) -> str: - global _label_counter - _label_counter += 1 - return f"label{_label_counter}" - def _align_kwargs(func_t, args_ir): """ @@ -63,7 +54,7 @@ def ir_for_self_call(stmt_expr, context): # note: internal_function_label asserts `func_t.is_internal`. _label = func_t._ir_info.internal_function_label(context.is_ctor_context) - return_label = _generate_label(f"{_label}_call") + return_label = _freshname(f"{_label}_call") # allocate space for the return buffer # TODO allocate in stmt and/or expr.py diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 254cad32e6..cc7a603b7c 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -26,6 +26,7 @@ from vyper.evm.address_space import MEMORY, STORAGE from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure from vyper.semantics.types import DArrayT, MemberFunctionT +from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.shortcuts import INT256_T, UINT256_T @@ -117,44 +118,32 @@ def parse_Log(self): return events.ir_node_for_log(self.stmt, event, topic_ir, data_ir, self.context) def parse_Call(self): - # TODO use expr.func.type.is_internal once type annotations - # are consistently available. - is_self_function = ( - (isinstance(self.stmt.func, vy_ast.Attribute)) - and isinstance(self.stmt.func.value, vy_ast.Name) - and self.stmt.func.value.id == "self" - ) - if isinstance(self.stmt.func, vy_ast.Name): funcname = self.stmt.func.id return STMT_DISPATCH_TABLE[funcname].build_IR(self.stmt, self.context) - elif isinstance(self.stmt.func, vy_ast.Attribute) and self.stmt.func.attr in ( - "append", - "pop", - ): - func_type = self.stmt.func._metadata["type"] - if isinstance(func_type, MemberFunctionT): - darray = Expr(self.stmt.func.value, self.context).ir_node - args = [Expr(x, self.context).ir_node for x in self.stmt.args] - if self.stmt.func.attr == "append": - # sanity checks - assert len(args) == 1 - arg = args[0] - assert isinstance(darray.typ, DArrayT) - check_assign( - dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ) - ) - - return append_dyn_array(darray, arg) - else: - assert len(args) == 0 - return pop_dyn_array(darray, return_popped_item=False) - - if is_self_function: - return self_call.ir_for_self_call(self.stmt, self.context) - else: - return external_call.ir_for_external_call(self.stmt, self.context) + func_type = self.stmt.func._metadata["type"] + + if isinstance(func_type, MemberFunctionT) and self.stmt.func.attr in ("append", "pop"): + darray = Expr(self.stmt.func.value, self.context).ir_node + args = [Expr(x, self.context).ir_node for x in self.stmt.args] + if self.stmt.func.attr == "append": + (arg,) = args + assert isinstance(darray.typ, DArrayT) + check_assign( + dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ) + ) + + return append_dyn_array(darray, arg) + else: + assert len(args) == 0 + return pop_dyn_array(darray, return_popped_item=False) + + if isinstance(func_type, ContractFunctionT): + if func_type.is_internal: + return self_call.ir_for_self_call(self.stmt, self.context) + else: + return external_call.ir_for_external_call(self.stmt, self.context) def _assert_reason(self, test_expr, msg): # from parse_Raise: None passed as the assert condition diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 61d7a7c229..026c8369c5 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -5,7 +5,7 @@ import vyper.ast as vy_ast # break an import cycle import vyper.codegen.core as codegen import vyper.compiler.output as output -from vyper.compiler.input_bundle import InputBundle, PathLike +from vyper.compiler.input_bundle import FileInput, InputBundle, PathLike from vyper.compiler.phases import CompilerData from vyper.compiler.settings import Settings from vyper.evm.opcodes import DEFAULT_EVM_VERSION, anchor_evm_version @@ -44,10 +44,8 @@ UNKNOWN_CONTRACT_NAME = "" -def compile_code( - contract_source: str, - contract_name: str = UNKNOWN_CONTRACT_NAME, - source_id: int = 0, +def compile_from_file_input( + file_input: FileInput, input_bundle: InputBundle = None, settings: Settings = None, output_formats: Optional[OutputFormats] = None, @@ -58,6 +56,8 @@ def compile_code( experimental_codegen: bool = False, ) -> dict: """ + Main entry point into the compiler. + Generate consumable compiler output(s) from a single contract source code. Basically, a wrapper around CompilerData which munges the output data into the requested output formats. @@ -72,6 +72,8 @@ def compile_code( evm_version: str, optional The target EVM ruleset to compile for. If not given, defaults to the latest implemented ruleset. + source_id: int, optional + source_id to tag AST nodes with. -1 if not provided. settings: Settings, optional Compiler settings. show_gas_estimates: bool, optional @@ -96,11 +98,11 @@ def compile_code( # make IR output the same between runs codegen.reset_names() + # TODO: maybe at this point we might as well just pass a `FileInput` + # directly to `CompilerData`. compiler_data = CompilerData( - contract_source, + file_input, input_bundle, - Path(contract_name), - source_id, settings, storage_layout_override, show_gas_estimates, @@ -118,8 +120,33 @@ def compile_code( ret[output_format] = formatter(compiler_data) except Exception as exc: if exc_handler is not None: - exc_handler(contract_name, exc) + exc_handler(str(file_input.path), exc) else: raise exc return ret + + +def compile_code( + source_code: str, + contract_path: str | PathLike = UNKNOWN_CONTRACT_NAME, + source_id: int = -1, + resolved_path: PathLike | None = None, + *args, + **kwargs, +): + # this function could be renamed to compile_from_string + """ + Do the same thing as compile_from_file_input but takes a string for source + code. This was previously the main entry point into the compiler + # (`compile_from_file_input()` is newer) + """ + if isinstance(contract_path, str): + contract_path = Path(contract_path) + file_input = FileInput( + source_id=source_id, + source_code=source_code, + path=contract_path, + resolved_path=resolved_path or contract_path, # type: ignore + ) + return compile_from_file_input(file_input, *args, **kwargs) diff --git a/vyper/compiler/input_bundle.py b/vyper/compiler/input_bundle.py index 1e41c3f137..27170f0a56 100644 --- a/vyper/compiler/input_bundle.py +++ b/vyper/compiler/input_bundle.py @@ -15,15 +15,11 @@ class CompilerInput: # an input to the compiler, basically an abstraction for file contents source_id: int - path: PathLike + path: PathLike # the path that was asked for - @staticmethod - def from_string(source_id: int, path: PathLike, file_contents: str) -> "CompilerInput": - try: - s = json.loads(file_contents) - return ABIInput(source_id, path, s) - except (ValueError, TypeError): - return FileInput(source_id, path, file_contents) + # resolved_path is the real path that was resolved to. + # mainly handy for debugging at this point + resolved_path: PathLike @dataclass @@ -40,13 +36,16 @@ class ABIInput(CompilerInput): abi: Any # something that json.load() returns -class _NotFound(Exception): - pass +def try_parse_abi(file_input: FileInput) -> CompilerInput: + try: + s = json.loads(file_input.source_code) + return ABIInput(file_input.source_id, file_input.path, file_input.resolved_path, s) + except (ValueError, TypeError): + return file_input -# wrap os.path.normpath, but return the same type as the input -def _normpath(path): - return path.__class__(os.path.normpath(path)) +class _NotFound(Exception): + pass # an "input bundle" to the compiler, representing the files which are @@ -60,20 +59,31 @@ class InputBundle: # a list of search paths search_paths: list[PathLike] + _cache: Any + def __init__(self, search_paths): self.search_paths = search_paths self._source_id_counter = 0 self._source_ids: dict[PathLike, int] = {} - def _load_from_path(self, path): + # this is a little bit cursed, but it allows consumers to cache data that + # share the same lifetime as this input bundle. + self._cache = lambda: None + + def _normalize_path(self, path): + raise NotImplementedError(f"not implemented! {self.__class__}._normalize_path()") + + def _load_from_path(self, resolved_path, path): raise NotImplementedError(f"not implemented! {self.__class__}._load_from_path()") - def _generate_source_id(self, path: PathLike) -> int: - if path not in self._source_ids: - self._source_ids[path] = self._source_id_counter + def _generate_source_id(self, resolved_path: PathLike) -> int: + # Note: it is possible for a file to get in here more than once, + # e.g. by symlink + if resolved_path not in self._source_ids: + self._source_ids[resolved_path] = self._source_id_counter self._source_id_counter += 1 - return self._source_ids[path] + return self._source_ids[resolved_path] def load_file(self, path: PathLike | str) -> CompilerInput: # search path precedence @@ -84,12 +94,9 @@ def load_file(self, path: PathLike | str) -> CompilerInput: # Path("/a") / Path("/b") => Path("/b") to_try = sp / path - # normalize the path with os.path.normpath, to break down - # things like "foo/bar/../x.vy" => "foo/x.vy", with all - # the caveats around symlinks that os.path.normpath comes with. - to_try = _normpath(to_try) try: - res = self._load_from_path(to_try) + to_try = self._normalize_path(to_try) + res = self._load_from_path(to_try, path) break except _NotFound: tried.append(to_try) @@ -104,7 +111,7 @@ def load_file(self, path: PathLike | str) -> CompilerInput: # try to parse from json, so that return types are consistent # across FilesystemInputBundle and JSONInputBundle. if isinstance(res, FileInput): - return CompilerInput.from_string(res.source_id, res.path, res.source_code) + res = try_parse_abi(res) return res @@ -126,20 +133,45 @@ def search_path(self, path: Optional[PathLike]) -> Iterator[None]: finally: self.search_paths.pop() + # temporarily modify the top of the search path (within the + # scope of the context manager) with highest precedence to something else + @contextlib.contextmanager + def poke_search_path(self, path: PathLike) -> Iterator[None]: + tmp = self.search_paths[-1] + self.search_paths[-1] = path + try: + yield + finally: + self.search_paths[-1] = tmp + # regular input. takes a search path(s), and `load_file()` will search all # search paths for the file and read it from the filesystem class FilesystemInputBundle(InputBundle): - def _load_from_path(self, path: Path) -> CompilerInput: + def _normalize_path(self, path: Path) -> Path: + # normalize the path with os.path.normpath, to break down + # things like "foo/bar/../x.vy" => "foo/x.vy", with all + # the caveats around symlinks that os.path.normpath comes with. try: - with path.open() as f: - code = f.read() - except FileNotFoundError: + return path.resolve(strict=True) + except (FileNotFoundError, NotADirectoryError): raise _NotFound(path) - source_id = super()._generate_source_id(path) + def _load_from_path(self, resolved_path: Path, original_path: Path) -> CompilerInput: + try: + with resolved_path.open() as f: + code = f.read() + except (FileNotFoundError, NotADirectoryError): + raise _NotFound(resolved_path) + + source_id = super()._generate_source_id(resolved_path) + + return FileInput(source_id, original_path, resolved_path, code) - return FileInput(source_id, path, code) + +# wrap os.path.normpath, but return the same type as the input +def _normpath(path): + return path.__class__(os.path.normpath(path)) # fake filesystem for JSON inputs. takes a base path, and `load_file()` @@ -156,25 +188,28 @@ def __init__(self, input_json, search_paths): # should be checked by caller assert path not in self.input_json - self.input_json[_normpath(path)] = item + self.input_json[path] = item + + def _normalize_path(self, path: PurePath) -> PurePath: + return _normpath(path) - def _load_from_path(self, path: PurePath) -> CompilerInput: + def _load_from_path(self, resolved_path: PurePath, original_path: PurePath) -> CompilerInput: try: - value = self.input_json[path] + value = self.input_json[resolved_path] except KeyError: - raise _NotFound(path) + raise _NotFound(resolved_path) - source_id = super()._generate_source_id(path) + source_id = super()._generate_source_id(resolved_path) if "content" in value: - return FileInput(source_id, path, value["content"]) + return FileInput(source_id, original_path, resolved_path, value["content"]) if "abi" in value: - return ABIInput(source_id, path, value["abi"]) + return ABIInput(source_id, original_path, resolved_path, value["abi"]) # TODO: ethPM support # if isinstance(contents, dict) and "contractTypes" in contents: # unreachable, based on how JSONInputBundle is constructed in # the codebase. - raise JSONError(f"Unexpected type in file: '{path}'") # pragma: nocover + raise JSONError(f"Unexpected type in file: '{resolved_path}'") # pragma: nocover diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index e47f300ba9..6d1e7ef70f 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -1,5 +1,6 @@ import warnings from collections import OrderedDict, deque +from pathlib import PurePath import asttokens @@ -33,8 +34,8 @@ def build_userdoc(compiler_data: CompilerData) -> dict: def build_external_interface_output(compiler_data: CompilerData) -> str: - interface = compiler_data.vyper_module_folded._metadata["type"] - stem = compiler_data.contract_path.stem + interface = compiler_data.vyper_module_folded._metadata["type"].interface + stem = PurePath(compiler_data.contract_path).stem # capitalize words separated by '_' # ex: test_interface.vy -> TestInterface name = "".join([x.capitalize() for x in stem.split("_")]) @@ -52,7 +53,7 @@ def build_external_interface_output(compiler_data: CompilerData) -> str: def build_interface_output(compiler_data: CompilerData) -> str: - interface = compiler_data.vyper_module_folded._metadata["type"] + interface = compiler_data.vyper_module_folded._metadata["type"].interface out = "" if interface.events: @@ -70,7 +71,7 @@ def build_interface_output(compiler_data: CompilerData) -> str: out = f"{out}@{func.mutability.value}\n" args = ", ".join([f"{arg.name}: {arg.typ}" for arg in func.arguments]) return_value = f" -> {func.return_type}" if func.return_type is not None else "" - out = f"{out}@external\ndef {func.name}({args}){return_value}:\n pass\n\n" + out = f"{out}@external\ndef {func.name}({args}){return_value}:\n ...\n\n" return out @@ -154,14 +155,19 @@ def _to_dict(func_t): def build_method_identifiers_output(compiler_data: CompilerData) -> dict: - interface = compiler_data.vyper_module_folded._metadata["type"] - functions = interface.functions.values() + module_t = compiler_data.vyper_module_folded._metadata["type"] + functions = module_t.function_defs - return {k: hex(v) for func in functions for k, v in func.method_ids.items()} + return { + k: hex(v) for func in functions for k, v in func._metadata["func_type"].method_ids.items() + } def build_abi_output(compiler_data: CompilerData) -> list: - abi = compiler_data.vyper_module_folded._metadata["type"].to_toplevel_abi_dict() + module_t = compiler_data.vyper_module_folded._metadata["type"] + _ = compiler_data.ir_runtime # ensure _ir_info is generated + + abi = module_t.interface.to_toplevel_abi_dict() if compiler_data.show_gas_estimates: # Add gas estimates for each function to ABI gas_estimates = build_gas_estimates(compiler_data.function_signatures) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 4e32812fee..edffa9a85e 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -7,18 +7,18 @@ from vyper import ast as vy_ast from vyper.codegen import module from vyper.codegen.core import anchor_opt_level -from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode -from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle +from vyper.compiler.input_bundle import FileInput, FilesystemInputBundle, InputBundle from vyper.compiler.settings import OptimizationLevel, Settings from vyper.exceptions import StructureException from vyper.ir import compile_ir, optimizer from vyper.semantics import set_data_positions, validate_semantics from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout from vyper.venom import generate_assembly_experimental, generate_ir -DEFAULT_CONTRACT_NAME = PurePath("VyperContract.vy") +DEFAULT_CONTRACT_PATH = PurePath("VyperContract.vy") class CompilerData: @@ -35,7 +35,7 @@ class CompilerData: Top-level Vyper AST node vyper_module_folded : vy_ast.Module Folded Vyper AST - global_ctx : GlobalContext + global_ctx : ModuleT Sorted, contextualized representation of the Vyper AST ir_nodes : IRnode IR used to generate deployment bytecode @@ -53,10 +53,8 @@ class CompilerData: def __init__( self, - source_code: str, + file_input: FileInput | str, input_bundle: InputBundle = None, - contract_path: Path | PurePath = DEFAULT_CONTRACT_NAME, - source_id: int = 0, settings: Settings = None, storage_layout: StorageLayout = None, show_gas_estimates: bool = False, @@ -68,12 +66,10 @@ def __init__( Arguments --------- - source_code: str - Vyper source code. - contract_path: Path, optional - The name of the contract being compiled. - source_id: int, optional - ID number used to identify this contract in the source map. + file_input: FileInput | str + A FileInput or string representing the input to the compiler. + FileInput is preferred, but `str` is accepted as a convenience + method (and also for backwards compatibility reasons) settings: Settings Set optimization mode. show_gas_estimates: bool, optional @@ -85,9 +81,15 @@ def __init__( """ # to force experimental codegen, uncomment: # experimental_codegen = True - self.contract_path = contract_path - self.source_code = source_code - self.source_id = source_id + + if isinstance(file_input, str): + file_input = FileInput( + source_code=file_input, + source_id=-1, + path=DEFAULT_CONTRACT_PATH, + resolved_path=DEFAULT_CONTRACT_PATH, + ) + self.file_input = file_input self.storage_layout_override = storage_layout self.show_gas_estimates = show_gas_estimates self.no_bytecode_metadata = no_bytecode_metadata @@ -97,10 +99,26 @@ def __init__( _ = self._generate_ast # force settings to be calculated + @cached_property + def source_code(self): + return self.file_input.source_code + + @cached_property + def source_id(self): + return self.file_input.source_id + + @cached_property + def contract_path(self): + return self.file_input.path + @cached_property def _generate_ast(self): - contract_name = str(self.contract_path) - settings, ast = generate_ast(self.source_code, self.source_id, contract_name) + settings, ast = vy_ast.parse_to_ast_with_settings( + self.source_code, + self.source_id, + module_path=str(self.contract_path), + resolved_path=str(self.file_input.resolved_path), + ) # validate the compiler settings # XXX: this is a bit ugly, clean up later @@ -141,12 +159,12 @@ def vyper_module_unfolded(self) -> vy_ast.Module: # This phase is intended to generate an AST for tooling use, and is not # used in the compilation process. - return generate_unfolded_ast(self.contract_path, self.vyper_module, self.input_bundle) + return generate_unfolded_ast(self.vyper_module, self.input_bundle) @cached_property def _folded_module(self): return generate_folded_ast( - self.contract_path, self.vyper_module, self.input_bundle, self.storage_layout_override + self.vyper_module, self.input_bundle, self.storage_layout_override ) @property @@ -160,8 +178,8 @@ def storage_layout(self) -> StorageLayout: return storage_layout @property - def global_ctx(self) -> GlobalContext: - return GlobalContext(self.vyper_module_folded) + def global_ctx(self) -> ModuleT: + return self.vyper_module_folded._metadata["type"] @cached_property def _ir_output(self): @@ -189,7 +207,7 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: _ = self._ir_output fs = self.vyper_module_folded.get_children(vy_ast.FunctionDef) - return {f.name: f._metadata["type"] for f in fs} + return {f.name: f._metadata["func_type"] for f in fs} @cached_property def assembly(self) -> list: @@ -230,37 +248,12 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -def generate_ast( - source_code: str, source_id: int, contract_name: str -) -> tuple[Settings, vy_ast.Module]: - """ - Generate a Vyper AST from source code. - - Arguments - --------- - source_code : str - Vyper source code. - source_id : int - ID number used to identify this contract in the source map. - contract_name: str - Name of the contract. - - Returns - ------- - vy_ast.Module - Top-level Vyper AST node - """ - return vy_ast.parse_to_ast_with_settings(source_code, source_id, contract_name) - - # destructive -- mutates module in place! -def generate_unfolded_ast( - contract_path: Path | PurePath, vyper_module: vy_ast.Module, input_bundle: InputBundle -) -> vy_ast.Module: +def generate_unfolded_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: vy_ast.validation.validate_literal_nodes(vyper_module) vy_ast.folding.replace_builtin_functions(vyper_module) - with input_bundle.search_path(contract_path.parent): + with input_bundle.search_path(Path(vyper_module.resolved_path).parent): # note: validate_semantics does type inference on the AST validate_semantics(vyper_module, input_bundle) @@ -268,7 +261,6 @@ def generate_unfolded_ast( def generate_folded_ast( - contract_path: Path, vyper_module: vy_ast.Module, input_bundle: InputBundle, storage_layout_overrides: StorageLayout = None, @@ -294,7 +286,7 @@ def generate_folded_ast( vyper_module_folded = copy.deepcopy(vyper_module) vy_ast.folding.fold(vyper_module_folded) - with input_bundle.search_path(contract_path.parent): + with input_bundle.search_path(Path(vyper_module.resolved_path).parent): validate_semantics(vyper_module_folded, input_bundle) symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides) @@ -302,9 +294,7 @@ def generate_folded_ast( return vyper_module_folded, symbol_tables -def generate_ir_nodes( - global_ctx: GlobalContext, optimize: OptimizationLevel -) -> tuple[IRnode, IRnode]: +def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. @@ -315,7 +305,7 @@ def generate_ir_nodes( Arguments --------- - global_ctx : GlobalContext + global_ctx: ModuleT Contextualized Vyper AST Returns diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 3bde20356e..993c0a85eb 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -49,6 +49,7 @@ def __init__(self, message="Error Message not found.", *items): self.message = message self.lineno = None self.col_offset = None + self.annotations = None if len(items) == 1 and isinstance(items[0], tuple) and isinstance(items[0][0], int): # support older exceptions that don't annotate - remove this in the future! @@ -79,7 +80,7 @@ def __str__(self): from vyper import ast as vy_ast from vyper.utils import annotate_source_code - if not hasattr(self, "annotations"): + if not self.annotations: if self.lineno is not None and self.col_offset is not None: return f"line {self.lineno}:{self.col_offset} {self.message}" else: @@ -105,8 +106,9 @@ def __str__(self): if isinstance(node, vy_ast.VyperNode): module_node = node.get_ancestor(vy_ast.Module) - if module_node.get("name") not in (None, ""): - node_msg = f'{node_msg}contract "{module_node.name}:{node.lineno}", ' + + if module_node.get("path") not in (None, ""): + node_msg = f'{node_msg}contract "{module_node.path}:{node.lineno}", ' fn_node = node.get_ancestor(vy_ast.FunctionDef) if fn_node: @@ -229,6 +231,18 @@ class CallViolation(VyperException): """Illegal function call.""" +class ImportCycle(VyperException): + """An import cycle""" + + +class DuplicateImport(VyperException): + """A module was imported twice from the same module""" + + +class ModuleNotFound(VyperException): + """Module was not found""" + + class ImmutableViolation(VyperException): """Modifying an immutable variable, constant, or definition.""" diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 7db230167e..7b52a68e92 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -1,17 +1,4 @@ -import vyper.ast as vy_ast - from .. import types # break a dependency cycle. -from ..namespace import get_namespace -from .local import validate_functions -from .module import add_module_namespace -from .utils import _ExprAnalyser - - -def validate_semantics(vyper_ast, input_bundle): - # validate semantics and annotate AST with type/semantics information - namespace = get_namespace() +from .module import validate_semantics - with namespace.enter_scope(): - add_module_namespace(vyper_ast, input_bundle) - vy_ast.expansion.expand_annotated_ast(vyper_ast) - validate_functions(vyper_ast) +__all__ = ["validate_semantics"] diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 449e6ca338..4d1b1cdbab 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,8 +1,9 @@ import enum from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Union from vyper import ast as vy_ast +from vyper.compiler.input_bundle import InputBundle from vyper.exceptions import ( CompilerPanic, ImmutableViolation, @@ -12,6 +13,9 @@ from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType +if TYPE_CHECKING: + from vyper.semantics.types.module import InterfaceT, ModuleT + class _StringEnum(enum.Enum): @staticmethod @@ -145,6 +149,35 @@ def __repr__(self): return f"" +# base class for things that are the "result" of analysis +class AnalysisResult: + pass + + +@dataclass +class ModuleInfo(AnalysisResult): + module_t: "ModuleT" + + @property + def module_node(self): + return self.module_t._module + + # duck type, conform to interface of VarInfo and ExprInfo + @property + def typ(self): + return self.module_t + + +@dataclass +class ImportInfo(AnalysisResult): + typ: Union[ModuleInfo, "InterfaceT"] + alias: str # the name in the namespace + qualified_module_name: str # for error messages + # source_id: int + input_bundle: InputBundle + node: vy_ast.VyperNode + + @dataclass class VarInfo: """ @@ -212,6 +245,10 @@ def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": is_immutable=var_info.is_immutable, ) + @classmethod + def from_moduleinfo(cls, module_info: ModuleInfo) -> "ExprInfo": + return cls(module_info.module_t) + def copy_with_type(self, typ: VyperType) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else diff --git a/vyper/semantics/analysis/common.py b/vyper/semantics/analysis/common.py index 507eb0a570..9d35aef2bd 100644 --- a/vyper/semantics/analysis/common.py +++ b/vyper/semantics/analysis/common.py @@ -1,6 +1,17 @@ +import contextlib from typing import Tuple -from vyper.exceptions import StructureException +from vyper.exceptions import StructureException, VyperException + + +@contextlib.contextmanager +def tag_exceptions(node): + try: + yield + except VyperException as e: + if not e.annotations and not e.lineno: + raise e.with_annotation(node) from None + raise e from None class VyperNodeVisitorBase: @@ -16,9 +27,11 @@ def visit(self, node, *args): # node types with a shared parent for class_ in node.__class__.mro(): ast_type = class_.__name__ - visitor_fn = getattr(self, f"visit_{ast_type}", None) - if visitor_fn: - return visitor_fn(node, *args) + + with tag_exceptions(node): + visitor_fn = getattr(self, f"visit_{ast_type}", None) + if visitor_fn: + return visitor_fn(node, *args) node_type = type(node).__name__ raise StructureException( diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 87ec45c40d..88679a4b09 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -79,7 +79,7 @@ def set_storage_slots_with_overrides( # Search through function definitions to find non-reentrant functions for node in vyper_module.get_children(vy_ast.FunctionDef): - type_ = node._metadata["type"] + type_ = node._metadata["func_type"] # Ignore functions without non-reentrant if type_.nonreentrant is None: @@ -165,7 +165,7 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: ret: Dict[str, Dict] = {} for node in vyper_module.get_children(vy_ast.FunctionDef): - type_ = node._metadata["type"] + type_ = node._metadata["func_type"] if type_.nonreentrant is None: continue diff --git a/vyper/semantics/analysis/import_graph.py b/vyper/semantics/analysis/import_graph.py new file mode 100644 index 0000000000..e406878194 --- /dev/null +++ b/vyper/semantics/analysis/import_graph.py @@ -0,0 +1,37 @@ +import contextlib +from dataclasses import dataclass, field +from typing import Iterator + +from vyper import ast as vy_ast +from vyper.exceptions import CompilerPanic, ImportCycle + +""" +data structure for collecting import statements and validating the +import graph +""" + + +@dataclass +class ImportGraph: + # the current path in the import graph traversal + _path: list[vy_ast.Module] = field(default_factory=list) + + def push_path(self, module_ast: vy_ast.Module) -> None: + if module_ast in self._path: + cycle = self._path + [module_ast] + raise ImportCycle(" imports ".join(f'"{t.path}"' for t in cycle)) + + self._path.append(module_ast) + + def pop_path(self, expected: vy_ast.Module) -> None: + popped = self._path.pop() + if expected != popped: + raise CompilerPanic("unreachable") + + @contextlib.contextmanager + def enter_path(self, module_ast: vy_ast.Module) -> Iterator[None]: + self.push_path(module_ast) + try: + yield + finally: + self.pop_path(module_ast) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 647f01c299..974c14f261 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -55,14 +55,15 @@ def validate_functions(vy_module: vy_ast.Module) -> None: - """Analyzes a vyper ast and validates the function-level namespaces.""" + """Analyzes a vyper ast and validates the function bodies""" err_list = ExceptionList() namespace = get_namespace() for node in vy_module.get_children(vy_ast.FunctionDef): with namespace.enter_scope(): try: - FunctionNodeVisitor(vy_module, node, namespace) + analyzer = FunctionNodeVisitor(vy_module, node, namespace) + analyzer.analyze() except VyperException as e: err_list.append(e) @@ -185,26 +186,31 @@ def __init__( self.vyper_module = vyper_module self.fn_node = fn_node self.namespace = namespace - self.func = fn_node._metadata["type"] + self.func = fn_node._metadata["func_type"] self.expr_visitor = _ExprVisitor(self.func) + def analyze(self): # allow internal function params to be mutable location, is_immutable = ( (DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True) ) for arg in self.func.arguments: - namespace[arg.name] = VarInfo(arg.typ, location=location, is_immutable=is_immutable) + self.namespace[arg.name] = VarInfo( + arg.typ, location=location, is_immutable=is_immutable + ) - for node in fn_node.body: + for node in self.fn_node.body: self.visit(node) + if self.func.return_type: - if not check_for_terminus(fn_node.body): + if not check_for_terminus(self.fn_node.body): raise FunctionDeclarationException( - f"Missing or unmatched return statements in function '{fn_node.name}'", fn_node + f"Missing or unmatched return statements in function '{self.fn_node.name}'", + self.fn_node, ) # visit default args - assert self.func.n_keyword_args == len(fn_node.args.defaults) + assert self.func.n_keyword_args == len(self.fn_node.args.defaults) for kwarg in self.func.keyword_args: self.expr_visitor.visit(kwarg.default_value, kwarg.typ) @@ -224,10 +230,7 @@ def visit_AnnAssign(self, node): typ = type_from_annotation(node.annotation, DataLocation.MEMORY) validate_expected_type(node.value, typ) - try: - self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY) - except VyperException as exc: - raise exc.with_annotation(node) from None + self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY) self.expr_visitor.visit(node.target, typ) self.expr_visitor.visit(node.value, typ) @@ -290,6 +293,13 @@ def visit_Continue(self, node): raise StructureException("`continue` must be enclosed in a `for` loop", node) def visit_Expr(self, node): + if isinstance(node.value, vy_ast.Ellipsis): + raise StructureException( + "`...` is not allowed in `.vy` files! " + "Did you mean to import me as a `.vyi` file?", + node, + ) + if not isinstance(node.value, vy_ast.Call): raise StructureException("Expressions without assignment are disallowed", node) @@ -433,6 +443,7 @@ def visit_For(self, node): # Check if `iter` is a storage variable. get_descendants` is used to check for # nested `self` (e.g. structs) + # NOTE: this analysis will be borked once stateful modules are allowed! iter_is_storage_var = ( isinstance(node.iter, vy_ast.Attribute) and len(node.iter.get_descendants(vy_ast.Name, {"id": "self"})) > 0 @@ -453,8 +464,11 @@ def visit_For(self, node): call_node, ) - for name in self.namespace["self"].typ.members[fn_name].recursive_calls: + for reachable_t in ( + self.namespace["self"].typ.members[fn_name].reachable_internal_functions + ): # check for indirect modification + name = reachable_t.name fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": name})[0] if _check_iterator_modification(node.iter, fn_node): raise ImmutableViolation( @@ -472,10 +486,7 @@ def visit_For(self, node): # type check the for loop body using each possible type for iterator value with self.namespace.enter_scope(): - try: - self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) - except VyperException as exc: - raise exc.with_annotation(node) from None + self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) try: with NodeMetadata.enter_typechecker_speculation(): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 239438f35b..7aa661aec3 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -1,6 +1,6 @@ import os from pathlib import Path, PurePath -from typing import Optional +from typing import Any, Optional import vyper.builtins.interfaces from vyper import ast as vy_ast @@ -8,9 +8,11 @@ from vyper.evm.opcodes import version_check from vyper.exceptions import ( CallViolation, + DuplicateImport, ExceptionList, InvalidLiteral, InvalidType, + ModuleNotFound, NamespaceCollision, StateAccessViolation, StructureException, @@ -18,128 +20,200 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import ImportInfo, ModuleInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase -from vyper.semantics.analysis.utils import check_constant, validate_expected_type +from vyper.semantics.analysis.import_graph import ImportGraph +from vyper.semantics.analysis.local import validate_functions +from vyper.semantics.analysis.utils import ( + check_constant, + get_exact_type_from_node, + validate_expected_type, +) from vyper.semantics.data_locations import DataLocation -from vyper.semantics.namespace import Namespace, get_namespace +from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace from vyper.semantics.types import EnumT, EventT, InterfaceT, StructT from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation -def add_module_namespace(vy_module: vy_ast.Module, input_bundle: InputBundle) -> None: +def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: + return validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) + + +def validate_semantics_r( + module_ast: vy_ast.Module, + input_bundle: InputBundle, + import_graph: ImportGraph, + is_interface: bool, +) -> ModuleT: """ Analyze a Vyper module AST node, add all module-level objects to the - namespace and validate top-level correctness + namespace, type-check/validate semantics and annotate with type and analysis info """ - + # validate semantics and annotate AST with type/semantics information namespace = get_namespace() - ModuleAnalyzer(vy_module, input_bundle, namespace) + with namespace.enter_scope(), import_graph.enter_path(module_ast): + analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace, import_graph, is_interface) + ret = analyzer.analyze() + + vy_ast.expansion.generate_public_variable_getters(module_ast) + + # if this is an interface, the function is already validated + # in `ContractFunction.from_vyi()` + if not is_interface: + validate_functions(module_ast) + + return ret + + +# compute reachable set and validate the call graph (detect cycles) +def _compute_reachable_set(fn_t: ContractFunctionT, path: list[ContractFunctionT] = None) -> None: + path = path or [] + + path.append(fn_t) + root = path[0] -def _find_cyclic_call(fn_names: list, self_members: dict) -> Optional[list]: - if fn_names[-1] not in self_members: - return None - internal_calls = self_members[fn_names[-1]].internal_calls - for name in internal_calls: - if name in fn_names: - return fn_names + [name] - sequence = _find_cyclic_call(fn_names + [name], self_members) - if sequence: - return sequence - return None + for g in fn_t.called_functions: + if g == root: + message = " -> ".join([f.name for f in path]) + raise CallViolation(f"Contract contains cyclic function call: {message}") + + _compute_reachable_set(g, path=path) + + for h in g.reachable_internal_functions: + assert h != fn_t # sanity check + + fn_t.reachable_internal_functions.add(h) + + fn_t.reachable_internal_functions.add(g) + + path.pop() class ModuleAnalyzer(VyperNodeVisitorBase): scope_name = "module" def __init__( - self, module_node: vy_ast.Module, input_bundle: InputBundle, namespace: Namespace + self, + module_node: vy_ast.Module, + input_bundle: InputBundle, + namespace: Namespace, + import_graph: ImportGraph, + is_interface: bool = False, ) -> None: self.ast = module_node self.input_bundle = input_bundle self.namespace = namespace + self._import_graph = import_graph + self.is_interface = is_interface - # TODO: Move computation out of constructor - module_nodes = module_node.body.copy() - while module_nodes: - count = len(module_nodes) + # keep track of imported modules to prevent duplicate imports + self._imported_modules: dict[PurePath, vy_ast.VyperNode] = {} + + self.module_t: Optional[ModuleT] = None + + # ast cache, hitchhike onto the input_bundle object + if not hasattr(self.input_bundle._cache, "_ast_of"): + self.input_bundle._cache._ast_of: dict[int, vy_ast.Module] = {} # type: ignore + + def analyze(self) -> ModuleT: + # generate a `ModuleT` from the top-level node + # note: also validates unique method ids + if "type" in self.ast._metadata: + assert isinstance(self.ast._metadata["type"], ModuleT) + # we don't need to analyse again, skip out + self.module_t = self.ast._metadata["type"] + return self.module_t + + to_visit = self.ast.body.copy() + + # handle imports linearly + # (do this instead of handling in the next block so that + # `self._imported_modules` does not end up with garbage in it after + # exception swallowing). + import_stmts = self.ast.get_children((vy_ast.Import, vy_ast.ImportFrom)) + for node in import_stmts: + self.visit(node) + to_visit.remove(node) + + # keep trying to process all the nodes until we finish or can + # no longer progress. this makes it so we don't need to + # calculate a dependency tree between top-level items. + while len(to_visit) > 0: + count = len(to_visit) err_list = ExceptionList() - for node in list(module_nodes): + for node in to_visit.copy(): try: self.visit(node) - module_nodes.remove(node) - except (InvalidLiteral, InvalidType, VariableDeclarationException): + to_visit.remove(node) + except (InvalidLiteral, InvalidType, VariableDeclarationException) as e: # these exceptions cannot be caused by another statement not yet being # parsed, so we raise them immediately - raise + raise e from None except VyperException as e: err_list.append(e) # Only raise if no nodes were successfully processed. This allows module # level logic to parse regardless of the ordering of code elements. - if count == len(module_nodes): + if count == len(to_visit): err_list.raise_if_not_empty() - # generate an `InterfaceT` from the top-level node - used for building the ABI - # note: also validates unique method ids - interface = InterfaceT.from_ast(module_node) - module_node._metadata["type"] = interface - self.interface = interface # this is useful downstream + self.module_t = ModuleT(self.ast) + self.ast._metadata["type"] = self.module_t # attach namespace to the module for downstream use. _ns = Namespace() # note that we don't just copy the namespace because # there are constructor issues. - _ns.update({k: namespace[k] for k in namespace._scopes[-1]}) # type: ignore - module_node._metadata["namespace"] = _ns + _ns.update({k: self.namespace[k] for k in self.namespace._scopes[-1]}) # type: ignore + self.ast._metadata["namespace"] = _ns + + self.analyze_call_graph() - self_members = namespace["self"].typ.members + return self.module_t + def analyze_call_graph(self): # get list of internal function calls made by each function - function_defs = self.ast.get_children(vy_ast.FunctionDef) - function_names = set(node.name for node in function_defs) - for node in function_defs: - calls_to_self = set( - i.func.attr for i in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}) - ) - # anything that is not a function call will get semantically checked later - calls_to_self = calls_to_self.intersection(function_names) - self_members[node.name].internal_calls = calls_to_self - - for fn_name in sorted(function_names): - if fn_name not in self_members: - # the referenced function does not exist - this is an issue, but we'll report - # it later when parsing the function so we can give more meaningful output - continue - - # check for circular function calls - sequence = _find_cyclic_call([fn_name], self_members) - if sequence is not None: - nodes = [] - for i in range(len(sequence) - 1): - fn_node = self.ast.get_children(vy_ast.FunctionDef, {"name": sequence[i]})[0] - call_node = fn_node.get_descendants( - vy_ast.Attribute, {"value.id": "self", "attr": sequence[i + 1]} - )[0] - nodes.append(call_node) - - raise CallViolation("Contract contains cyclic function call", *nodes) - - # get complete list of functions that are reachable from this function - function_set = set(i for i in self_members[fn_name].internal_calls if i in self_members) - while True: - expanded = set(x for i in function_set for x in self_members[i].internal_calls) - expanded |= function_set - if expanded == function_set: - break - function_set = expanded - - self_members[fn_name].recursive_calls = function_set + function_defs = self.module_t.function_defs + + for func in function_defs: + fn_t = func._metadata["func_type"] + + function_calls = func.get_descendants(vy_ast.Call) + + for call in function_calls: + try: + call_t = get_exact_type_from_node(call.func) + except VyperException: + # either there is a problem getting the call type. this is + # an issue, but it will be handled properly later. right now + # we just want to be able to construct the call graph. + continue + + if isinstance(call_t, ContractFunctionT) and call_t.is_internal: + fn_t.called_functions.add(call_t) + + for func in function_defs: + fn_t = func._metadata["func_type"] + + # compute reachable set and validate the call graph + _compute_reachable_set(fn_t) + + def _ast_from_file(self, file: FileInput) -> vy_ast.Module: + # cache ast if we have seen it before. + # this gives us the additional property of object equality on + # two ASTs produced from the same source + ast_of = self.input_bundle._cache._ast_of + if file.source_id not in ast_of: + ast_of[file.source_id] = _parse_and_fold_ast(file) + + return ast_of[file.source_id] def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) + if not isinstance(type_, InterfaceT): raise StructureException("Invalid interface name", node.annotation) @@ -153,8 +227,9 @@ def visit_VariableDecl(self, node): if node.is_public: # generate function type and add to metadata # we need this when building the public getter - node._metadata["func_type"] = ContractFunctionT.getter_from_VariableDecl(node) + node._metadata["getter_type"] = ContractFunctionT.getter_from_VariableDecl(node) + # TODO: move this check to local analysis if node.is_immutable: # mutability is checked automatically preventing assignment # outside of the constructor, here we just check a value is assigned, @@ -213,22 +288,18 @@ def _finalize(): self.namespace["self"].typ.add_member(name, var_info) node.target._metadata["type"] = type_ except NamespaceCollision: + # rewrite the error message to be slightly more helpful raise NamespaceCollision( f"Value '{name}' has already been declared", node ) from None - except VyperException as exc: - raise exc.with_annotation(node) from None def _validate_self_namespace(): # block globals if storage variable already exists - try: - if name in self.namespace["self"].typ.members: - raise NamespaceCollision( - f"Value '{name}' has already been declared", node - ) from None - self.namespace[name] = var_info - except VyperException as exc: - raise exc.with_annotation(node) from None + if name in self.namespace["self"].typ.members: + raise NamespaceCollision( + f"Value '{name}' has already been declared", node + ) from None + self.namespace[name] = var_info if node.is_constant: if not node.value: @@ -251,41 +322,50 @@ def _validate_self_namespace(): _validate_self_namespace() return _finalize() - try: - self.namespace.validate_assignment(name) - except NamespaceCollision as exc: - raise exc.with_annotation(node) from None + self.namespace.validate_assignment(name) return _finalize() def visit_EnumDef(self, node): obj = EnumT.from_EnumDef(node) - try: - self.namespace[node.name] = obj - except VyperException as exc: - raise exc.with_annotation(node) from None + self.namespace[node.name] = obj def visit_EventDef(self, node): obj = EventT.from_EventDef(node) - try: - self.namespace[node.name] = obj - except VyperException as exc: - raise exc.with_annotation(node) from None + node._metadata["event_type"] = obj + self.namespace[node.name] = obj def visit_FunctionDef(self, node): - func = ContractFunctionT.from_FunctionDef(node) + if self.is_interface: + func_t = ContractFunctionT.from_vyi(node) + if not func_t.is_external: + # TODO test me! + raise StructureException( + "Internal functions in `.vyi` files are not allowed!", node + ) + else: + func_t = ContractFunctionT.from_FunctionDef(node) - try: - self.namespace["self"].typ.add_member(func.name, func) - node._metadata["type"] = func - except VyperException as exc: - raise exc.with_annotation(node) from None + self.namespace["self"].typ.add_member(func_t.name, func_t) + node._metadata["func_type"] = func_t def visit_Import(self, node): - if not node.alias: - raise StructureException("Import requires an accompanying `as` statement", node) # import x.y[name] as y[alias] - self._add_import(node, 0, node.name, node.alias) + + alias = node.alias + + if alias is None: + alias = node.name + + # don't handle things like `import x.y` + if "." in alias: + suggested_alias = node.name[node.name.rfind(".") :] + suggestion = f"hint: try `import {node.name} as {suggested_alias}`" + raise StructureException( + f"import requires an accompanying `as` statement ({suggestion})", node + ) + + self._add_import(node, 0, node.name, alias) def visit_ImportFrom(self, node): # from m.n[module] import x[name] as y[alias] @@ -299,42 +379,87 @@ def visit_ImportFrom(self, node): self._add_import(node, node.level, qualified_module_name, alias) def visit_InterfaceDef(self, node): - obj = InterfaceT.from_ast(node) - try: - self.namespace[node.name] = obj - except VyperException as exc: - raise exc.with_annotation(node) from None + obj = InterfaceT.from_InterfaceDef(node) + self.namespace[node.name] = obj def visit_StructDef(self, node): - struct_t = StructT.from_ast_def(node) - try: - self.namespace[node.name] = struct_t - except VyperException as exc: - raise exc.with_annotation(node) from None + struct_t = StructT.from_StructDef(node) + node._metadata["struct_type"] = struct_t + self.namespace[node.name] = struct_t def _add_import( self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str ) -> None: - type_ = self._load_import(level, qualified_module_name) - - try: - self.namespace[alias] = type_ - except VyperException as exc: - raise exc.with_annotation(node) from None + module_info = self._load_import(node, level, qualified_module_name, alias) + node._metadata["import_info"] = ImportInfo( + module_info, alias, qualified_module_name, self.input_bundle, node + ) + self.namespace[alias] = module_info - # load an InterfaceT from an import. + # load an InterfaceT or ModuleInfo from an import. # raises FileNotFoundError - def _load_import(self, level: int, module_str: str) -> InterfaceT: + def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str) -> Any: + # the directory this (currently being analyzed) module is in + self_search_path = Path(self.ast.resolved_path).parent + + with self.input_bundle.poke_search_path(self_search_path): + return self._load_import_helper(node, level, module_str, alias) + + def _load_import_helper( + self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str + ) -> Any: if _is_builtin(module_str): return _load_builtin_import(level, module_str) path = _import_to_path(level, module_str) + # this could conceivably be in the ImportGraph but no need at this point + if path in self._imported_modules: + previous_import_stmt = self._imported_modules[path] + raise DuplicateImport(f"{alias} imported more than once!", previous_import_stmt, node) + + self._imported_modules[path] = node + + err = None + + try: + path_vy = path.with_suffix(".vy") + file = self.input_bundle.load_file(path_vy) + assert isinstance(file, FileInput) # mypy hint + + module_ast = self._ast_from_file(file) + + with override_global_namespace(Namespace()): + module_t = validate_semantics_r( + module_ast, + self.input_bundle, + import_graph=self._import_graph, + is_interface=False, + ) + + return ModuleInfo(module_t) + + except FileNotFoundError as e: + # escape `e` from the block scope, it can make things + # easier to debug. + err = e + try: - file = self.input_bundle.load_file(path.with_suffix(".vy")) + file = self.input_bundle.load_file(path.with_suffix(".vyi")) assert isinstance(file, FileInput) # mypy hint - interface_ast = vy_ast.parse_to_ast(file.source_code, contract_name=str(file.path)) - return InterfaceT.from_ast(interface_ast) + module_ast = self._ast_from_file(file) + + with override_global_namespace(Namespace()): + validate_semantics_r( + module_ast, + self.input_bundle, + import_graph=self._import_graph, + is_interface=True, + ) + module_t = module_ast._metadata["type"] + + return module_t.interface + except FileNotFoundError: pass @@ -343,7 +468,24 @@ def _load_import(self, level: int, module_str: str) -> InterfaceT: assert isinstance(file, ABIInput) # mypy hint return InterfaceT.from_json_abi(str(file.path), file.abi) except FileNotFoundError: - raise ModuleNotFoundError(module_str) + pass + + # copy search_paths, makes debugging a bit easier + search_paths = self.input_bundle.search_paths.copy() # noqa: F841 + raise ModuleNotFound(module_str, node) from err + + +def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: + ret = vy_ast.parse_to_ast( + file.source_code, + source_id=file.source_id, + module_path=str(file.path), + resolved_path=str(file.resolved_path), + ) + vy_ast.validation.validate_literal_nodes(ret) + vy_ast.folding.fold(ret) + + return ret # convert an import to a path (without suffix) @@ -385,7 +527,7 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: remapped_module = remapped_module.removeprefix("vyper.interfaces") remapped_module = vyper.builtins.interfaces.__package__ + remapped_module - path = _import_to_path(level, remapped_module).with_suffix(".vy") + path = _import_to_path(level, remapped_module).with_suffix(".vyi") try: file = input_bundle.load_file(path) @@ -394,5 +536,8 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: raise ModuleNotFoundError(f"Not a builtin: {module_str}") from None # TODO: it might be good to cache this computation - interface_ast = vy_ast.parse_to_ast(file.source_code, contract_name=module_str) - return InterfaceT.from_ast(interface_ast) + interface_ast = _parse_and_fold_ast(file) + + with override_global_namespace(Namespace()): + module_t = validate_semantics(interface_ast, input_bundle, is_interface=True) + return module_t.interface diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index afa6b56838..1785afd92d 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -17,7 +17,7 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, VarInfo +from vyper.semantics.analysis.base import ExprInfo, ModuleInfo, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -66,8 +66,15 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: # if it's a Name, we have varinfo for it if isinstance(node, vy_ast.Name): - varinfo = self.namespace[node.id] - return ExprInfo.from_varinfo(varinfo) + info = self.namespace[node.id] + + if isinstance(info, VarInfo): + return ExprInfo.from_varinfo(info) + + if isinstance(info, ModuleInfo): + return ExprInfo.from_moduleinfo(info) + + raise CompilerPanic("unreachable!", node) if isinstance(node, vy_ast.Attribute): # if it's an Attr, we check the parent exprinfo and @@ -192,16 +199,17 @@ def _raise_invalid_reference(name, node): try: s = t.get_member(name, node) - if isinstance(s, VyperType): + + if isinstance(s, (VyperType, TYPE_T)): # ex. foo.bar(). bar() is a ContractFunctionT return [s] if is_self_reference and (s.is_constant or s.is_immutable): _raise_invalid_reference(name, node) # general case. s is a VarInfo, e.g. self.foo return [s.typ] - except UnknownAttribute: + except UnknownAttribute as e: if not is_self_reference: - raise + raise e from None if name in self.namespace: _raise_invalid_reference(name, node) @@ -364,6 +372,7 @@ def types_from_Name(self, node): return [TYPE_T(t)] return [t.typ] + except VyperException as exc: raise exc.with_annotation(node) from None diff --git a/vyper/semantics/namespace.py b/vyper/semantics/namespace.py index 613ac0c03b..4df2511a29 100644 --- a/vyper/semantics/namespace.py +++ b/vyper/semantics/namespace.py @@ -95,7 +95,7 @@ def validate_assignment(self, attr): def get_namespace(): """ - Get the active namespace object. + Get the global namespace object. """ global _namespace try: diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index ad470718c8..1fef6a706e 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -2,9 +2,10 @@ from .base import TYPE_T, KwargSettings, VyperType, is_type_t from .bytestrings import BytesT, StringT, _BytestringT from .function import MemberFunctionT +from .module import InterfaceT from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT -from .user import EnumT, EventT, InterfaceT, StructT +from .user import EnumT, EventT, StructT def _get_primitive_types(): diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index c5af5c2a39..d22d9bfff9 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -44,6 +44,13 @@ class VyperType: A tuple of invalid `DataLocation`s for this type _is_prim_word: bool, optional This is a word type like uint256, int8, bytesM or address + _supports_external_calls: bool, optional + Whether or not this type supports external calls. Currently + limited to `InterfaceT`s + _attribute_in_annotation: bool, optional + Whether or not this type can be attributed in a type + annotation, like IFoo.SomeType. Currently limited to + `InterfaceT`s. """ _id: str @@ -58,6 +65,9 @@ class VyperType: _as_array: bool = False # rename to something like can_be_array_member _as_hashmap_key: bool = False + _supports_external_calls: bool = False + _attribute_in_annotation: bool = False + size_in_bytes = 32 # default; override for larger types def __init__(self, members: Optional[Dict] = None) -> None: @@ -261,7 +271,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional["VyperType"]: VyperType, optional Type generated as a result of the call. """ - raise StructureException("Value is not callable", node) + raise StructureException(f"{self} is not callable", node) @classmethod def get_subscripted_type(self, node: vy_ast.Index) -> None: diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index 09130626aa..e3c381ac69 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -132,7 +132,15 @@ def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT": raise UnexpectedValue("Node id does not match type name") length = get_index_value(node.slice) # type: ignore - # return cls._type(length, location, is_constant, is_public, is_immutable) + + if length is None: + raise StructureException( + f"Cannot declare {cls._id} type without a maximum length, e.g. {cls._id}[5]", node + ) + + # TODO: pass None to constructor after we redo length inference on bytestrings + length = length or 0 + return cls(length) @classmethod diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 140f73f095..ec30ac85d6 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -17,7 +17,11 @@ StructureException, ) from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot -from vyper.semantics.analysis.utils import check_kwargable, validate_expected_type +from vyper.semantics.analysis.utils import ( + check_kwargable, + get_exact_type_from_node, + validate_expected_type, +) from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import KwargSettings, VyperType from vyper.semantics.types.primitives import BoolT @@ -44,6 +48,7 @@ class KeywordArg(_FunctionArg): ast_source: Optional[vy_ast.VyperNode] = None +# TODO: refactor this into FunctionT (from an ast) and ABIFunctionT (from json) class ContractFunctionT(VyperType): """ Contract function type. @@ -81,6 +86,7 @@ def __init__( function_visibility: FunctionVisibility, state_mutability: StateMutability, nonreentrant: Optional[str] = None, + ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: super().__init__() @@ -92,11 +98,18 @@ def __init__( self.mutability = state_mutability self.nonreentrant = nonreentrant - # a list of internal functions this function calls - self.called_functions = OrderedSet[ContractFunctionT]() + self.ast_def = ast_def + + # a list of internal functions this function calls. + # to be populated during analysis + self.called_functions: OrderedSet[ContractFunctionT] = OrderedSet() + + # recursively reachable from this function + self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() # to be populated during codegen self._ir_info: Any = None + self._function_id: Optional[int] = None @cached_property def call_site_kwargs(self): @@ -126,7 +139,7 @@ def __hash__(self): return hash(id(self)) @classmethod - def from_abi(cls, abi: Dict) -> "ContractFunctionT": + def from_abi(cls, abi: dict) -> "ContractFunctionT": """ Generate a `ContractFunctionT` object from an ABI interface. @@ -157,190 +170,174 @@ def from_abi(cls, abi: Dict) -> "ContractFunctionT": ) @classmethod - def from_FunctionDef( - cls, node: vy_ast.FunctionDef, is_interface: Optional[bool] = False - ) -> "ContractFunctionT": + def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": """ - Generate a `ContractFunctionT` object from a `FunctionDef` node. + Generate a `ContractFunctionT` object from a `FunctionDef` inside + of an `InterfaceDef` Arguments --------- - node : FunctionDef + funcdef: FunctionDef Vyper ast node to generate the function definition from. - is_interface: bool, optional - Boolean indicating if the function definition is part of an interface. Returns ------- ContractFunctionT """ - kwargs: Dict[str, Any] = {} - if is_interface: - # FunctionDef with stateMutability in body (Interface defintions) - if ( - len(node.body) == 1 - and isinstance(node.body[0], vy_ast.Expr) - and isinstance(node.body[0].value, vy_ast.Name) - and StateMutability.is_valid_value(node.body[0].value.id) - ): - # Interfaces are always public - kwargs["function_visibility"] = FunctionVisibility.EXTERNAL - kwargs["state_mutability"] = StateMutability(node.body[0].value.id) - elif len(node.body) == 1 and node.body[0].get("value.id") in ("constant", "modifying"): - if node.body[0].value.id == "constant": - expected = "view or pure" - else: - expected = "payable or nonpayable" - raise StructureException( - f"State mutability should be set to {expected}", node.body[0] - ) + # FunctionDef with stateMutability in body (Interface defintions) + body = funcdef.body + if ( + len(body) == 1 + and isinstance(body[0], vy_ast.Expr) + and isinstance(body[0].value, vy_ast.Name) + and StateMutability.is_valid_value(body[0].value.id) + ): + # Interfaces are always public + function_visibility = FunctionVisibility.EXTERNAL + state_mutability = StateMutability(body[0].value.id) + # handle errors + elif len(body) == 1 and body[0].get("value.id") in ("constant", "modifying"): + if body[0].value.id == "constant": + expected = "view or pure" else: - raise StructureException( - "Body must only contain state mutability label", node.body[0] - ) - + expected = "payable or nonpayable" + raise StructureException(f"State mutability should be set to {expected}", body[0]) else: - # FunctionDef with decorators (normal functions) - for decorator in node.decorator_list: - if isinstance(decorator, vy_ast.Call): - if "nonreentrant" in kwargs: - raise StructureException( - "nonreentrant decorator is already set with key: " - f"{kwargs['nonreentrant']}", - node, - ) + raise StructureException("Body must only contain state mutability label", body[0]) - if decorator.get("func.id") != "nonreentrant": - raise StructureException("Decorator is not callable", decorator) - if len(decorator.args) != 1 or not isinstance(decorator.args[0], vy_ast.Str): - raise StructureException( - "@nonreentrant name must be given as a single string literal", decorator - ) + if funcdef.name == "__init__": + raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef) - if node.name == "__init__": - msg = "Nonreentrant decorator disallowed on `__init__`" - raise FunctionDeclarationException(msg, decorator) - - nonreentrant_key = decorator.args[0].value - validate_identifier(nonreentrant_key, decorator.args[0]) - - kwargs["nonreentrant"] = nonreentrant_key - - elif isinstance(decorator, vy_ast.Name): - if FunctionVisibility.is_valid_value(decorator.id): - if "function_visibility" in kwargs: - raise FunctionDeclarationException( - f"Visibility is already set to: {kwargs['function_visibility']}", - node, - ) - kwargs["function_visibility"] = FunctionVisibility(decorator.id) - - elif StateMutability.is_valid_value(decorator.id): - if "state_mutability" in kwargs: - raise FunctionDeclarationException( - f"Mutability is already set to: {kwargs['state_mutability']}", node - ) - kwargs["state_mutability"] = StateMutability(decorator.id) - - else: - if decorator.id == "constant": - warnings.warn( - "'@constant' decorator has been removed (see VIP2040). " - "Use `@view` instead.", - DeprecationWarning, - ) - raise FunctionDeclarationException( - f"Unknown decorator: {decorator.id}", decorator - ) + if funcdef.name == "__default__": + raise FunctionDeclarationException( + "Default functions cannot appear in interfaces", funcdef + ) + + positional_args, keyword_args = _parse_args(funcdef) + + return_type = _parse_return_type(funcdef) + + return cls( + funcdef.name, + positional_args, + keyword_args, + return_type, + function_visibility, + state_mutability, + nonreentrant=None, + ast_def=funcdef, + ) + + @classmethod + def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": + """ + Generate a `ContractFunctionT` object from a `FunctionDef` inside + of an interface (`.vyi`) file + + Arguments + --------- + funcdef: FunctionDef + Vyper ast node to generate the function definition from. + + Returns + ------- + ContractFunctionT + """ + function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + + if nonreentrant_key is not None: + raise FunctionDeclarationException( + "nonreentrant key not allowed in interfaces", funcdef + ) + + if funcdef.name == "__init__": + raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef) - else: - raise StructureException("Bad decorator syntax", decorator) + if funcdef.name == "__default__": + raise FunctionDeclarationException( + "Default functions cannot appear in interfaces", funcdef + ) + + positional_args, keyword_args = _parse_args(funcdef) + + return_type = _parse_return_type(funcdef) - if "function_visibility" not in kwargs: + if len(funcdef.body) != 1 or not isinstance(funcdef.body[0].get("value"), vy_ast.Ellipsis): raise FunctionDeclarationException( - f"Visibility must be set to one of: {', '.join(FunctionVisibility.values())}", node + "function body in an interface can only be ...!", funcdef ) - if node.name == "__default__": - if kwargs["function_visibility"] != FunctionVisibility.EXTERNAL: + return cls( + funcdef.name, + positional_args, + keyword_args, + return_type, + function_visibility, + state_mutability, + nonreentrant=nonreentrant_key, + ast_def=funcdef, + ) + + @classmethod + def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": + """ + Generate a `ContractFunctionT` object from a `FunctionDef` node. + + Arguments + --------- + funcdef: FunctionDef + Vyper ast node to generate the function definition from. + + Returns + ------- + ContractFunctionT + """ + function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + + positional_args, keyword_args = _parse_args(funcdef) + + return_type = _parse_return_type(funcdef) + + # validate default and init functions + if funcdef.name == "__default__": + if function_visibility != FunctionVisibility.EXTERNAL: raise FunctionDeclarationException( - "Default function must be marked as `@external`", node + "Default function must be marked as `@external`", funcdef ) - if node.args.args: + if funcdef.args.args: raise FunctionDeclarationException( - "Default function may not receive any arguments", node.args.args[0] + "Default function may not receive any arguments", funcdef.args.args[0] ) - if "state_mutability" not in kwargs: - # Assume nonpayable if not set at all (cannot accept Ether, but can modify state) - kwargs["state_mutability"] = StateMutability.NONPAYABLE - - if kwargs["state_mutability"] == StateMutability.PURE and "nonreentrant" in kwargs: - raise StructureException("Cannot use reentrancy guard on pure functions", node) - - if node.name == "__init__": + if funcdef.name == "__init__": if ( - kwargs["state_mutability"] in (StateMutability.PURE, StateMutability.VIEW) - or kwargs["function_visibility"] == FunctionVisibility.INTERNAL + state_mutability in (StateMutability.PURE, StateMutability.VIEW) + or function_visibility == FunctionVisibility.INTERNAL ): raise FunctionDeclarationException( - "Constructor cannot be marked as `@pure`, `@view` or `@internal`", node + "Constructor cannot be marked as `@pure`, `@view` or `@internal`", funcdef ) - - # call arguments - if node.args.defaults: + if return_type is not None: raise FunctionDeclarationException( - "Constructor may not use default arguments", node.args.defaults[0] + "Constructor may not have a return type", funcdef.returns ) - argnames = set() # for checking uniqueness - n_total_args = len(node.args.args) - n_positional_args = n_total_args - len(node.args.defaults) - - positional_args: list[PositionalArg] = [] - keyword_args: list[KeywordArg] = [] - - for i, arg in enumerate(node.args.args): - argname = arg.arg - if argname in ("gas", "value", "skip_contract_check", "default_return_value"): - raise ArgumentException( - f"Cannot use '{argname}' as a variable name in a function input", arg + # call arguments + if funcdef.args.defaults: + raise FunctionDeclarationException( + "Constructor may not use default arguments", funcdef.args.defaults[0] ) - if argname in argnames: - raise ArgumentException(f"Function contains multiple inputs named {argname}", arg) - - if arg.annotation is None: - raise ArgumentException(f"Function argument '{argname}' is missing a type", arg) - - type_ = type_from_annotation(arg.annotation, DataLocation.CALLDATA) - - if i < n_positional_args: - positional_args.append(PositionalArg(argname, type_, ast_source=arg)) - else: - value = node.args.defaults[i - n_positional_args] - if not check_kwargable(value): - raise StateAccessViolation( - "Value must be literal or environment variable", value - ) - validate_expected_type(value, type_) - keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) - - argnames.add(argname) - # return types - if node.returns is None: - return_type = None - elif node.name == "__init__": - raise FunctionDeclarationException( - "Constructor may not have a return type", node.returns - ) - elif isinstance(node.returns, (vy_ast.Name, vy_ast.Subscript, vy_ast.Tuple)): - # note: consider, for cleanliness, adding DataLocation.RETURN_VALUE - return_type = type_from_annotation(node.returns, DataLocation.MEMORY) - else: - raise InvalidType("Function return value must be a type name or tuple", node.returns) - - return cls(node.name, positional_args, keyword_args, return_type, **kwargs) + return cls( + funcdef.name, + positional_args, + keyword_args, + return_type, + function_visibility, + state_mutability, + nonreentrant=nonreentrant_key, + ast_def=funcdef, + ) def set_reentrancy_key_position(self, position: StorageSlot) -> None: if hasattr(self, "reentrancy_key_position"): @@ -383,6 +380,7 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio return_type, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.VIEW, + ast_def=node, ) @property @@ -489,8 +487,12 @@ def method_ids(self) -> Dict[str, int]: return method_ids def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: - if node.get("func.value.id") == "self" and self.visibility == FunctionVisibility.EXTERNAL: - raise CallViolation("Cannot call external functions via 'self'", node) + # mypy hint - right now, the only way a ContractFunctionT can be + # called is via `Attribute`, e.x. self.foo() or library.bar() + assert isinstance(node.func, vy_ast.Attribute) + parent_t = get_exact_type_from_node(node.func.value) + if not parent_t._supports_external_calls and self.visibility == FunctionVisibility.EXTERNAL: + raise CallViolation("Cannot call external functions via 'self' or via library", node) kwarg_keys = [] # for external calls, include gas and value as optional kwargs @@ -584,6 +586,125 @@ def abi_signature_for_kwargs(self, kwargs: list[KeywordArg]) -> str: return self.name + "(" + ",".join([arg.typ.abi_type.selector_name() for arg in args]) + ")" +def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: + # return types + if funcdef.returns is None: + return None + # note: consider, for cleanliness, adding DataLocation.RETURN_VALUE + return type_from_annotation(funcdef.returns, DataLocation.MEMORY) + + +def _parse_decorators( + funcdef: vy_ast.FunctionDef, +) -> tuple[FunctionVisibility, StateMutability, Optional[str]]: + function_visibility = None + state_mutability = None + nonreentrant_key = None + + for decorator in funcdef.decorator_list: + if isinstance(decorator, vy_ast.Call): + if nonreentrant_key is not None: + raise StructureException( + "nonreentrant decorator is already set with key: " f"{nonreentrant_key}", + funcdef, + ) + + if decorator.get("func.id") != "nonreentrant": + raise StructureException("Decorator is not callable", decorator) + if len(decorator.args) != 1 or not isinstance(decorator.args[0], vy_ast.Str): + raise StructureException( + "@nonreentrant name must be given as a single string literal", decorator + ) + + if funcdef.name == "__init__": + msg = "Nonreentrant decorator disallowed on `__init__`" + raise FunctionDeclarationException(msg, decorator) + + nonreentrant_key = decorator.args[0].value + validate_identifier(nonreentrant_key, decorator.args[0]) + + elif isinstance(decorator, vy_ast.Name): + if FunctionVisibility.is_valid_value(decorator.id): + if function_visibility is not None: + raise FunctionDeclarationException( + f"Visibility is already set to: {function_visibility}", funcdef + ) + function_visibility = FunctionVisibility(decorator.id) + + elif StateMutability.is_valid_value(decorator.id): + if state_mutability is not None: + raise FunctionDeclarationException( + f"Mutability is already set to: {state_mutability}", funcdef + ) + state_mutability = StateMutability(decorator.id) + + else: + if decorator.id == "constant": + warnings.warn( + "'@constant' decorator has been removed (see VIP2040). " + "Use `@view` instead.", + DeprecationWarning, + ) + raise FunctionDeclarationException(f"Unknown decorator: {decorator.id}", decorator) + + else: + raise StructureException("Bad decorator syntax", decorator) + + if function_visibility is None: + raise FunctionDeclarationException( + f"Visibility must be set to one of: {', '.join(FunctionVisibility.values())}", funcdef + ) + + if state_mutability is None: + # default to nonpayable + state_mutability = StateMutability.NONPAYABLE + + if state_mutability == StateMutability.PURE and nonreentrant_key is not None: + raise StructureException("Cannot use reentrancy guard on pure functions", funcdef) + + # assert function_visibility is not None # mypy + # assert state_mutability is not None # mypy + return function_visibility, state_mutability, nonreentrant_key + + +def _parse_args( + funcdef: vy_ast.FunctionDef, is_interface: bool = False +) -> tuple[list[PositionalArg], list[KeywordArg]]: + argnames = set() # for checking uniqueness + n_total_args = len(funcdef.args.args) + n_positional_args = n_total_args - len(funcdef.args.defaults) + + positional_args = [] + keyword_args = [] + + for i, arg in enumerate(funcdef.args.args): + argname = arg.arg + if argname in ("gas", "value", "skip_contract_check", "default_return_value"): + raise ArgumentException( + f"Cannot use '{argname}' as a variable name in a function input", arg + ) + if argname in argnames: + raise ArgumentException(f"Function contains multiple inputs named {argname}", arg) + + if arg.annotation is None: + raise ArgumentException(f"Function argument '{argname}' is missing a type", arg) + + type_ = type_from_annotation(arg.annotation, DataLocation.CALLDATA) + + if i < n_positional_args: + positional_args.append(PositionalArg(argname, type_, ast_source=arg)) + else: + value = funcdef.args.defaults[i - n_positional_args] + if not check_kwargable(value): + raise StateAccessViolation("Value must be literal or environment variable", value) + validate_expected_type(value, type_) + keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) + + argnames.add(argname) + + return positional_args, keyword_args + + class MemberFunctionT(VyperType): """ Member function type definition. diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py new file mode 100644 index 0000000000..4622482951 --- /dev/null +++ b/vyper/semantics/types/module.py @@ -0,0 +1,332 @@ +from functools import cached_property +from typing import Optional + +from vyper import ast as vy_ast +from vyper.abi_types import ABI_Address, ABIType +from vyper.ast.validation import validate_call_args +from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException +from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids +from vyper.semantics.namespace import get_namespace +from vyper.semantics.types.base import TYPE_T, VyperType +from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.primitives import AddressT +from vyper.semantics.types.user import EventT, StructT, _UserType + + +class InterfaceT(_UserType): + _type_members = {"address": AddressT()} + _is_prim_word = True + _as_array = True + _as_hashmap_key = True + _supports_external_calls = True + _attribute_in_annotation = True + + def __init__(self, _id: str, functions: dict, events: dict, structs: dict) -> None: + validate_unique_method_ids(list(functions.values())) + + members = functions | events | structs + + # sanity check: by construction, there should be no duplicates. + assert len(members) == len(functions) + len(events) + len(structs) + + super().__init__(functions) + + self._helper = VyperType(events | structs) + self._id = _id + self.functions = functions + self.events = events + self.structs = structs + + def get_type_member(self, attr, node): + # get an event or struct from this interface + return TYPE_T(self._helper.get_member(attr, node)) + + @property + def getter_signature(self): + return (), AddressT() + + @property + def abi_type(self) -> ABIType: + return ABI_Address() + + def __repr__(self): + return f"interface {self._id}" + + # when using the type itself (not an instance) in the call position + def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": + self._ctor_arg_types(node) + return self + + def _ctor_arg_types(self, node): + validate_call_args(node, 1) + validate_expected_type(node.args[0], AddressT()) + return [AddressT()] + + def _ctor_kwarg_types(self, node): + return {} + + # TODO x.validate_implements(other) + def validate_implements(self, node: vy_ast.ImplementsDecl) -> None: + namespace = get_namespace() + unimplemented = [] + + def _is_function_implemented(fn_name, fn_type): + vyper_self = namespace["self"].typ + if fn_name not in vyper_self.members: + return False + s = vyper_self.members[fn_name] + if isinstance(s, ContractFunctionT): + to_compare = vyper_self.members[fn_name] + # this is kludgy, rework order of passes in ModuleNodeVisitor + elif isinstance(s, VarInfo) and s.is_public: + to_compare = s.decl_node._metadata["getter_type"] + else: + return False + + return to_compare.implements(fn_type) + + # check for missing functions + for name, type_ in self.functions.items(): + if not isinstance(type_, ContractFunctionT): + # ex. address + continue + + if not _is_function_implemented(name, type_): + unimplemented.append(name) + + # check for missing events + for name, event in self.events.items(): + if name not in namespace: + unimplemented.append(name) + continue + + if not isinstance(namespace[name], EventT): + unimplemented.append(f"{name} is not an event!") + if ( + namespace[name].event_id != event.event_id + or namespace[name].indexed != event.indexed + ): + unimplemented.append(f"{name} is not implemented! (should be {event})") + + if len(unimplemented) > 0: + # TODO: improve the error message for cases where the + # mismatch is small (like mutability, or just one argument + # is off, etc). + missing_str = ", ".join(sorted(unimplemented)) + raise InterfaceViolation( + f"Contract does not implement all interface functions or events: {missing_str}", + node, + ) + + def to_toplevel_abi_dict(self) -> list[dict]: + abi = [] + for event in self.events.values(): + abi += event.to_toplevel_abi_dict() + for func in self.functions.values(): + abi += func.to_toplevel_abi_dict() + return abi + + # helper function which performs namespace collision checking + @classmethod + def _from_lists( + cls, + name: str, + function_list: list[tuple[str, ContractFunctionT]], + event_list: list[tuple[str, EventT]], + struct_list: list[tuple[str, StructT]], + ) -> "InterfaceT": + functions = {} + events = {} + structs = {} + + seen_items: dict = {} + + for name, function in function_list: + if name in seen_items: + raise NamespaceCollision(f"multiple functions named '{name}'!", function.ast_def) + functions[name] = function + seen_items[name] = function + + for name, event in event_list: + if name in seen_items: + raise NamespaceCollision( + f"multiple functions or events named '{name}'!", event.decl_node + ) + events[name] = event + seen_items[name] = event + + for name, struct in struct_list: + if name in seen_items: + raise NamespaceCollision( + f"multiple functions or events named '{name}'!", event.decl_node + ) + structs[name] = struct + seen_items[name] = struct + + return cls(name, functions, events, structs) + + @classmethod + def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT": + """ + Generate an `InterfaceT` object from an ABI. + + Arguments + --------- + name : str + The name of the interface + abi : dict + Contract ABI + + Returns + ------- + InterfaceT + primitive interface type + """ + functions: list = [] + events: list = [] + + for item in [i for i in abi if i.get("type") == "function"]: + functions.append((item["name"], ContractFunctionT.from_abi(item))) + for item in [i for i in abi if i.get("type") == "event"]: + events.append((item["name"], EventT.from_abi(item))) + + structs: list = [] # no structs in json ABI (as of yet) + return cls._from_lists(name, functions, events, structs) + + @classmethod + def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": + """ + Generate an `InterfaceT` object from a Vyper ast node. + + Arguments + --------- + module_t: ModuleT + Vyper module type + Returns + ------- + InterfaceT + primitive interface type + """ + funcs = [] + + for node in module_t.function_defs: + func_t = node._metadata["func_type"] + if not func_t.is_external: + continue + funcs.append((node.name, func_t)) + + # add getters for public variables since they aren't yet in the AST + for node in module_t._module.get_children(vy_ast.VariableDecl): + if not node.is_public: + continue + getter = node._metadata["getter_type"] + funcs.append((node.target.id, getter)) + + events = [(node.name, node._metadata["event_type"]) for node in module_t.event_defs] + + structs = [(node.name, node._metadata["struct_type"]) for node in module_t.struct_defs] + + return cls._from_lists(module_t._id, funcs, events, structs) + + @classmethod + def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": + functions = [] + for node in node.body: + if not isinstance(node, vy_ast.FunctionDef): + raise StructureException("Interfaces can only contain function definitions", node) + if len(node.decorator_list) > 0: + raise StructureException( + "Function definition in interface cannot be decorated", node.decorator_list[0] + ) + functions.append((node.name, ContractFunctionT.from_InterfaceDef(node))) + + # no structs or events in InterfaceDefs + events: list = [] + structs: list = [] + + return cls._from_lists(node.name, functions, events, structs) + + +# Datatype to store all module information. +class ModuleT(VyperType): + def __init__(self, module: vy_ast.Module, name: Optional[str] = None): + super().__init__() + + self._module = module + + self._id = name or module.path + + # compute the interface, note this has the side effect of checking + # for function collisions + self._helper = self.interface + + for f in self.function_defs: + # note: this checks for collisions + self.add_member(f.name, f._metadata["func_type"]) + + for e in self.event_defs: + # add the type of the event so it can be used in call position + self.add_member(e.name, TYPE_T(e._metadata["event_type"])) # type: ignore + + for s in self.struct_defs: + # add the type of the struct so it can be used in call position + self.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore + + for v in self.variable_decls: + self.add_member(v.target.id, v.target._metadata["varinfo"]) + + for i in self.import_stmts: + import_info = i._metadata["import_info"] + self.add_member(import_info.alias, import_info.typ) + + # __eq__ is very strict on ModuleT - object equality! this is because we + # don't want to reason about where a module came from (i.e. input bundle, + # search path, symlinked vs normalized path, etc.) + def __eq__(self, other): + return self is other + + def __hash__(self): + return hash(id(self)) + + def get_type_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": + return self._helper.get_member(key, node) + + # this is a property, because the function set changes after AST expansion + @property + def function_defs(self): + return self._module.get_children(vy_ast.FunctionDef) + + @property + def event_defs(self): + return self._module.get_children(vy_ast.EventDef) + + @property + def struct_defs(self): + return self._module.get_children(vy_ast.StructDef) + + @property + def import_stmts(self): + return self._module.get_children((vy_ast.Import, vy_ast.ImportFrom)) + + @property + def variable_decls(self): + return self._module.get_children(vy_ast.VariableDecl) + + @cached_property + def variables(self): + # variables that this module defines, ex. + # `x: uint256` is a private storage variable named x + return {s.target.id: s.target._metadata["varinfo"] for s in self.variable_decls} + + @cached_property + def immutables(self): + return [t for t in self.variables.values() if t.is_immutable] + + @cached_property + def immutable_section_bytes(self): + return sum([imm.typ.memory_bytes_required for imm in self.immutables]) + + @cached_property + def interface(self): + return InterfaceT.from_ModuleT(self) diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 6a2d3aae73..46dffbdec4 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple from vyper import ast as vy_ast from vyper.abi_types import ABI_DynamicArray, ABI_StaticArray, ABI_Tuple, ABIType @@ -68,7 +68,7 @@ def get_subscripted_type(self, node): return self.value_type @classmethod - def from_annotation(cls, node: Union[vy_ast.Name, vy_ast.Call, vy_ast.Subscript]) -> "HashMapT": + def from_annotation(cls, node: vy_ast.Subscript) -> "HashMapT": if ( not isinstance(node, vy_ast.Subscript) or not isinstance(node.slice, vy_ast.Index) @@ -274,24 +274,32 @@ def compare_type(self, other): @classmethod def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": + # common error message, different ast locations + err_msg = "DynArray must be defined with base type and max length, e.g. DynArray[bool, 5]" + + if not isinstance(node, vy_ast.Subscript): + raise StructureException(err_msg, node) + if ( - not isinstance(node, vy_ast.Subscript) - or not isinstance(node.slice, vy_ast.Index) + not isinstance(node.slice, vy_ast.Index) or not isinstance(node.slice.value, vy_ast.Tuple) - or not isinstance(node.slice.value.elements[1], vy_ast.Int) or len(node.slice.value.elements) != 2 ): - raise StructureException( - "DynArray must be defined with base type and max length, e.g. DynArray[bool, 5]", - node, - ) + raise StructureException(err_msg, node.slice) + + length_node = node.slice.value.elements[1] + + if not isinstance(length_node, vy_ast.Int): + raise StructureException(err_msg, length_node) - value_type = type_from_annotation(node.slice.value.elements[0]) + length = length_node.value + + value_node = node.slice.value.elements[0] + value_type = type_from_annotation(value_node) if not value_type._as_darray: - raise StructureException(f"Arrays of {value_type} are not allowed", node) + raise StructureException(f"Arrays of {value_type} are not allowed", value_node) - max_length = node.slice.value.elements[1].value - return cls(value_type, max_length) + return cls(value_type, length) class TupleT(VyperType): @@ -333,7 +341,7 @@ def tuple_items(self): return list(enumerate(self.member_types)) @classmethod - def from_annotation(cls, node: vy_ast.Tuple) -> VyperType: + def from_annotation(cls, node: vy_ast.Tuple) -> "TupleT": values = node.elements types = tuple(type_from_annotation(v) for v in values) return cls(types) diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index ce82731c34..ef7e1d0eb4 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -1,27 +1,22 @@ from functools import cached_property -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional from vyper import ast as vy_ast -from vyper.abi_types import ABI_Address, ABI_GIntM, ABI_Tuple, ABIType +from vyper.abi_types import ABI_GIntM, ABI_Tuple, ABIType from vyper.ast.validation import validate_call_args from vyper.exceptions import ( EnumDeclarationException, EventDeclarationException, - InterfaceViolation, InvalidAttribute, NamespaceCollision, StructureException, UnknownAttribute, VariableDeclarationException, ) -from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions -from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids +from vyper.semantics.analysis.utils import validate_expected_type from vyper.semantics.data_locations import DataLocation -from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import VyperType -from vyper.semantics.types.function import ContractFunctionT -from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.subscriptable import HashMapT from vyper.semantics.types.utils import type_from_abi, type_from_annotation from vyper.utils import keccak256 @@ -29,12 +24,19 @@ # user defined type class _UserType(VyperType): + def __init__(self, members=None): + super().__init__(members=members) + def __eq__(self, other): return self is other - # TODO: revisit this once user types can be imported via modules def compare_type(self, other): - return super().compare_type(other) and self._id == other._id + # object exact comparison is a bit tricky here since we have + # to be careful to construct any given user type exactly + # only one time. however, the alternative requires reasoning + # about both the name and source (module or json abi) of + # the type. + return self is other def __hash__(self): return hash(id(self)) @@ -52,7 +54,8 @@ def __init__(self, name: str, members: dict) -> None: if len(members.keys()) > 256: raise EnumDeclarationException("Enums are limited to 256 members!") - super().__init__() + super().__init__(members=None) + self._id = name self._enum_members = members @@ -112,7 +115,7 @@ def from_EnumDef(cls, base_node: vy_ast.EnumDef) -> "EnumT": ------- Enum """ - members: Dict = {} + members: dict = {} if len(base_node.body) == 1 and isinstance(base_node.body[0], vy_ast.Pass): raise EnumDeclarationException("Enum must have members", base_node) @@ -135,7 +138,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: # TODO return None - def to_toplevel_abi_dict(self) -> List[Dict]: + def to_toplevel_abi_dict(self) -> list[dict]: # TODO return [] @@ -160,13 +163,21 @@ class EventT(_UserType): _invalid_locations = tuple(iter(DataLocation)) # not instantiable in any location - def __init__(self, name: str, arguments: dict, indexed: list) -> None: + def __init__( + self, + name: str, + arguments: dict, + indexed: list, + decl_node: Optional[vy_ast.VyperNode] = None, + ) -> None: super().__init__(members=arguments) self.name = name self.indexed = indexed assert len(self.indexed) == len(self.arguments) self.event_id = int(keccak256(self.signature.encode()).hex(), 16) + self.decl_node = decl_node + # backward compatible @property def arguments(self): @@ -187,7 +198,7 @@ def signature(self): return f"{self.name}({','.join(v.canonical_abi_type for v in self.arguments.values())})" @classmethod - def from_abi(cls, abi: Dict) -> "EventT": + def from_abi(cls, abi: dict) -> "EventT": """ Generate an `Event` object from an ABI interface. @@ -201,7 +212,7 @@ def from_abi(cls, abi: Dict) -> "EventT": Event object. """ members: dict = {} - indexed: List = [i["indexed"] for i in abi["inputs"]] + indexed: list = [i["indexed"] for i in abi["inputs"]] for item in abi["inputs"]: members[item["name"]] = type_from_abi(item) return cls(abi["name"], members, indexed) @@ -219,11 +230,11 @@ def from_EventDef(cls, base_node: vy_ast.EventDef) -> "EventT": ------- Event """ - members: Dict = {} - indexed: List = [] + members: dict = {} + indexed: list = [] if len(base_node.body) == 1 and isinstance(base_node.body[0], vy_ast.Pass): - return EventT(base_node.name, members, indexed) + return cls(base_node.name, members, indexed, base_node) for node in base_node.body: if not isinstance(node, vy_ast.AnnAssign): @@ -252,14 +263,14 @@ def from_EventDef(cls, base_node: vy_ast.EventDef) -> "EventT": members[member_name] = type_from_annotation(annotation) - return cls(base_node.name, members, indexed) + return cls(base_node.name, members, indexed, base_node) def _ctor_call_return(self, node: vy_ast.Call) -> None: validate_call_args(node, len(self.arguments)) for arg, expected in zip(node.args, self.arguments.values()): validate_expected_type(arg, expected) - def to_toplevel_abi_dict(self) -> List[Dict]: + def to_toplevel_abi_dict(self) -> list[dict]: return [ { "name": self.name, @@ -273,215 +284,6 @@ def to_toplevel_abi_dict(self) -> List[Dict]: ] -class InterfaceT(_UserType): - _type_members = {"address": AddressT()} - _is_prim_word = True - _as_array = True - _as_hashmap_key = True - - def __init__(self, _id: str, members: dict, events: dict) -> None: - validate_unique_method_ids(list(members.values())) # explicit list cast for mypy - super().__init__(members) - - self._id = _id - self.events = events - - @property - def getter_signature(self): - return (), AddressT() - - @property - def abi_type(self) -> ABIType: - return ABI_Address() - - def __repr__(self): - return f"{self._id}" - - # when using the type itself (not an instance) in the call position - # maybe rename to _ctor_call_return - def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": - self._ctor_arg_types(node) - - return self - - def _ctor_arg_types(self, node): - validate_call_args(node, 1) - validate_expected_type(node.args[0], AddressT()) - return [AddressT()] - - def _ctor_kwarg_types(self, node): - return {} - - # TODO x.validate_implements(other) - def validate_implements(self, node: vy_ast.ImplementsDecl) -> None: - namespace = get_namespace() - unimplemented = [] - - def _is_function_implemented(fn_name, fn_type): - vyper_self = namespace["self"].typ - if fn_name not in vyper_self.members: - return False - s = vyper_self.members[fn_name] - if isinstance(s, ContractFunctionT): - to_compare = vyper_self.members[fn_name] - # this is kludgy, rework order of passes in ModuleNodeVisitor - elif isinstance(s, VarInfo) and s.is_public: - to_compare = s.decl_node._metadata["func_type"] - else: - return False - - return to_compare.implements(fn_type) - - # check for missing functions - for name, type_ in self.members.items(): - if not isinstance(type_, ContractFunctionT): - # ex. address - continue - - if not _is_function_implemented(name, type_): - unimplemented.append(name) - - # check for missing events - for name, event in self.events.items(): - if name not in namespace: - unimplemented.append(name) - continue - - if not isinstance(namespace[name], EventT): - unimplemented.append(f"{name} is not an event!") - if ( - namespace[name].event_id != event.event_id - or namespace[name].indexed != event.indexed - ): - unimplemented.append(f"{name} is not implemented! (should be {event})") - - if len(unimplemented) > 0: - # TODO: improve the error message for cases where the - # mismatch is small (like mutability, or just one argument - # is off, etc). - missing_str = ", ".join(sorted(unimplemented)) - raise InterfaceViolation( - f"Contract does not implement all interface functions or events: {missing_str}", - node, - ) - - def to_toplevel_abi_dict(self) -> List[Dict]: - abi = [] - for event in self.events.values(): - abi += event.to_toplevel_abi_dict() - for func in self.functions.values(): - abi += func.to_toplevel_abi_dict() - return abi - - @property - def functions(self): - return {k: v for (k, v) in self.members.items() if isinstance(v, ContractFunctionT)} - - @classmethod - def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT": - """ - Generate an `InterfaceT` object from an ABI. - - Arguments - --------- - name : str - The name of the interface - abi : dict - Contract ABI - - Returns - ------- - InterfaceT - primitive interface type - """ - members: Dict = {} - events: Dict = {} - - names = [i["name"] for i in abi if i.get("type") in ("event", "function")] - collisions = set(i for i in names if names.count(i) > 1) - if collisions: - collision_list = ", ".join(sorted(collisions)) - raise NamespaceCollision( - f"ABI '{name}' has multiple functions or events " - f"with the same name: {collision_list}" - ) - - for item in [i for i in abi if i.get("type") == "function"]: - members[item["name"]] = ContractFunctionT.from_abi(item) - for item in [i for i in abi if i.get("type") == "event"]: - events[item["name"]] = EventT.from_abi(item) - - return cls(name, members, events) - - # TODO: split me into from_InterfaceDef and from_Module - @classmethod - def from_ast(cls, node: Union[vy_ast.InterfaceDef, vy_ast.Module]) -> "InterfaceT": - """ - Generate an `InterfaceT` object from a Vyper ast node. - - Arguments - --------- - node : InterfaceDef | Module - Vyper ast node defining the interface - Returns - ------- - InterfaceT - primitive interface type - """ - if isinstance(node, vy_ast.Module): - members, events = _get_module_definitions(node) - elif isinstance(node, vy_ast.InterfaceDef): - members = _get_class_functions(node) - events = {} - else: - raise StructureException("Invalid syntax for interface definition", node) - - return cls(node.name, members, events) - - -def _get_module_definitions(base_node: vy_ast.Module) -> Tuple[Dict, Dict]: - functions: Dict = {} - events: Dict = {} - for node in base_node.get_children(vy_ast.FunctionDef): - if "external" in [i.id for i in node.decorator_list if isinstance(i, vy_ast.Name)]: - func = ContractFunctionT.from_FunctionDef(node) - functions[node.name] = func - for node in base_node.get_children(vy_ast.VariableDecl, {"is_public": True}): - name = node.target.id - if name in functions: - raise NamespaceCollision( - f"Interface contains multiple functions named '{name}'", base_node - ) - functions[name] = ContractFunctionT.getter_from_VariableDecl(node) - for node in base_node.get_children(vy_ast.EventDef): - name = node.name - if name in functions or name in events: - raise NamespaceCollision( - f"Interface contains multiple objects named '{name}'", base_node - ) - events[name] = EventT.from_EventDef(node) - - return functions, events - - -def _get_class_functions(base_node: vy_ast.InterfaceDef) -> Dict[str, ContractFunctionT]: - functions = {} - for node in base_node.body: - if not isinstance(node, vy_ast.FunctionDef): - raise StructureException("Interfaces can only contain function definitions", node) - if node.name in functions: - raise NamespaceCollision( - f"Interface contains multiple functions named '{node.name}'", node - ) - if len(node.decorator_list) > 0: - raise StructureException( - "Function definition in interface cannot be decorated", node.decorator_list[0] - ) - functions[node.name] = ContractFunctionT.from_FunctionDef(node, is_interface=True) - - return functions - - class StructT(_UserType): _as_array = True @@ -516,7 +318,7 @@ def member_types(self): return self.members @classmethod - def from_ast_def(cls, base_node: vy_ast.StructDef) -> "StructT": + def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT": """ Generate a `StructT` object from a Vyper ast node. @@ -531,7 +333,7 @@ def from_ast_def(cls, base_node: vy_ast.StructDef) -> "StructT": """ struct_name = base_node.name - members: Dict[str, VyperType] = {} + members: dict[str, VyperType] = {} for node in base_node.body: if not isinstance(node, vy_ast.AnnAssign): raise StructureException( @@ -605,4 +407,4 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT": f"Struct declaration does not define all fields: {', '.join(list(members))}", node ) - return StructT(self._id, self.member_types) + return self diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 1187080ca9..8d68a9fa01 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -6,12 +6,13 @@ InstantiationException, InvalidType, StructureException, + UndeclaredDefinition, UnknownType, ) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import get_namespace -from vyper.semantics.types.base import VyperType +from vyper.semantics.types.base import TYPE_T, VyperType # TODO maybe this should be merged with .types/base.py @@ -75,7 +76,7 @@ def type_from_annotation( Arguments --------- - node : VyperNode + node: VyperNode Vyper ast node from the `annotation` member of a `VariableDecl` or `AnnAssign` node. Returns @@ -95,12 +96,6 @@ def type_from_annotation( def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: namespace = get_namespace() - def _failwith(type_name): - suggestions_str = get_levenshtein_error_suggestions(type_name, namespace, 0.3) - raise UnknownType( - f"No builtin or user-defined type named '{type_name}'. {suggestions_str}", node - ) from None - if isinstance(node, vy_ast.Tuple): tuple_t = namespace["$TupleT"] return tuple_t.from_annotation(node) @@ -116,11 +111,43 @@ def _failwith(type_name): return type_ctor.from_annotation(node) + # prepare a common error message + err_msg = f"'{node.node_source_code}' is not a type!" + + if isinstance(node, vy_ast.Attribute): + # ex. SomeModule.SomeStruct + + # sanity check - we only allow modules/interfaces to be + # imported as `Name`s currently. + if not isinstance(node.value, vy_ast.Name): + raise InvalidType(err_msg, node) + + try: + module_or_interface = namespace[node.value.id] # type: ignore + except UndeclaredDefinition: + raise InvalidType(err_msg, node) from None + + interface = module_or_interface + if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo + interface = module_or_interface.module_t.interface + + if not interface._attribute_in_annotation: + raise InvalidType(err_msg, node) + + type_t = interface.get_type_member(node.attr, node) + assert isinstance(type_t, TYPE_T) # sanity check + return type_t.typedef + if not isinstance(node, vy_ast.Name): # maybe handle this somewhere upstream in ast validation - raise InvalidType(f"'{node.node_source_code}' is not a type", node) - if node.id not in namespace: - _failwith(node.node_source_code) + raise InvalidType(err_msg, node) + + if node.id not in namespace: # type: ignore + suggestions_str = get_levenshtein_error_suggestions(node.node_source_code, namespace, 0.3) + raise UnknownType( + f"No builtin or user-defined type named '{node.node_source_code}'. {suggestions_str}", + node, + ) from None typ_ = namespace[node.id] if hasattr(typ_, "from_annotation"): @@ -138,7 +165,7 @@ def get_index_value(node: vy_ast.Index) -> int: Arguments --------- - node : vy_ast.Index + node: vy_ast.Index Vyper ast node from the `slice` member of a Subscript node. Must be an `Index` object (Vyper does not support `Slice` or `ExtSlice`). @@ -146,6 +173,7 @@ def get_index_value(node: vy_ast.Index) -> int: ------- int Literal integer value. + In the future, will return `None` if the subscript is an Ellipsis """ # this is imported to improve error messages # TODO: revisit this! diff --git a/vyper/utils.py b/vyper/utils.py index 0a2e1f831f..6816db9bae 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -51,6 +51,10 @@ def difference(self, other): def union(self, other): return self | other + def update(self, other): + for item in other: + self.add(item) + def __or__(self, other): return self.__class__(super().__or__(other)) @@ -162,11 +166,6 @@ def method_id(method_str: str) -> bytes: return keccak256(bytes(method_str, "utf-8"))[:4] -# map a string to only-alphanumeric chars -def mkalphanum(s): - return "".join([c if c.isalnum() else "_" for c in s]) - - def round_towards_zero(d: decimal.Decimal) -> int: # TODO double check if this can just be int(d) # (but either way keep this util function bc it's easier at a glance