diff --git a/setup.py b/setup.py index 40efb436c5..431c50b74b 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,6 @@ "pytest-instafail>=0.4,<1.0", "pytest-xdist>=2.5,<3.0", "pytest-split>=0.7.0,<1.0", - "pytest-rerunfailures>=10.2,<11", "eth-tester[py-evm]>=0.9.0b1,<0.10", "py-evm>=0.7.0a1,<0.8", "web3==6.0.0", diff --git a/tests/conftest.py b/tests/conftest.py index 22f8544beb..925a025a4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ from vyper import compiler from vyper.ast.grammar import parse_vyper_source from vyper.codegen.ir_node import IRnode -from vyper.compiler.input_bundle import FilesystemInputBundle +from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle from vyper.compiler.settings import OptimizationLevel, Settings, _set_debug_mode from vyper.ir import compile_ir, optimizer @@ -103,6 +103,12 @@ def fn(sources_dict): return fn +# for tests which just need an input bundle, doesn't matter what it is +@pytest.fixture +def dummy_input_bundle(): + return InputBundle([]) + + # TODO: remove me, this is just string.encode("utf-8").ljust() # only used in test_logging.py. @pytest.fixture @@ -255,9 +261,11 @@ def ir_compiler(ir, *args, **kwargs): ir = IRnode.from_list(ir) if optimize != OptimizationLevel.NONE: ir = optimizer.optimize(ir) + bytecode, _ = compile_ir.assembly_to_evm( compile_ir.compile_to_assembly(ir, optimize=optimize) ) + abi = kwargs.get("abi") or [] c = w3.eth.contract(abi=abi, bytecode=bytecode) deploy_transaction = c.constructor() diff --git a/tests/functional/codegen/test_call_graph_stability.py b/tests/functional/codegen/test_call_graph_stability.py index 4c85c330f3..2d8ad59791 100644 --- a/tests/functional/codegen/test_call_graph_stability.py +++ b/tests/functional/codegen/test_call_graph_stability.py @@ -55,7 +55,7 @@ def foo(): # check the .called_functions data structure on foo() directly foo = t.vyper_module_folded.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0] - foo_t = foo._metadata["type"] + foo_t = foo._metadata["func_type"] assert [f.name for f in foo_t.called_functions] == func_names # now for sanity, ensure the order that the function definitions appear diff --git a/tests/functional/builtins/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py similarity index 84% rename from tests/functional/builtins/codegen/test_interfaces.py rename to tests/functional/codegen/test_interfaces.py index 8cb0124f29..3544f4a965 100644 --- a/tests/functional/builtins/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -6,9 +6,9 @@ from vyper.compiler import compile_code from vyper.exceptions import ( ArgumentException, + DuplicateImport, InterfaceViolation, NamespaceCollision, - StructureException, ) @@ -31,7 +31,7 @@ def allowance(_owner: address, _spender: address) -> (uint256, uint256): out = compile_code(code, output_formats=["interface"]) out = out["interface"] - code_pass = "\n".join(code.split("\n")[:-2] + [" pass"]) # replace with a pass statement. + code_pass = "\n".join(code.split("\n")[:-2] + [" ..."]) # replace with a pass statement. assert code_pass.strip() == out.strip() @@ -60,7 +60,7 @@ def allowance(_owner: address, _spender: address) -> (uint256, uint256): view def test(_owner: address): nonpayable """ - out = compile_code(code, contract_name="One.vy", output_formats=["external_interface"])[ + out = compile_code(code, contract_path="One.vy", output_formats=["external_interface"])[ "external_interface" ] @@ -85,14 +85,14 @@ def test_external_interface_parsing(make_input_bundle, assert_compile_failed): interface_code = """ @external def foo() -> uint256: - pass + ... @external def bar() -> uint256: - pass + ... """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) code = """ import a as FooBarInterface @@ -121,9 +121,8 @@ def foo() -> uint256: """ - assert_compile_failed( - lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation - ) + with pytest.raises(InterfaceViolation): + compile_code(not_implemented_code, input_bundle=input_bundle) def test_missing_event(make_input_bundle, assert_compile_failed): @@ -132,7 +131,7 @@ def test_missing_event(make_input_bundle, assert_compile_failed): a: uint256 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -156,7 +155,7 @@ def test_malformed_event(make_input_bundle, assert_compile_failed): a: uint256 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -183,7 +182,7 @@ def test_malformed_events_indexed(make_input_bundle, assert_compile_failed): a: uint256 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -211,7 +210,7 @@ def test_malformed_events_indexed2(make_input_bundle, assert_compile_failed): a: indexed(uint256) """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -234,13 +233,13 @@ def bar() -> uint256: VALID_IMPORT_CODE = [ # import statement, import path without suffix - ("import a as Foo", "a.vy"), - ("import b.a as Foo", "b/a.vy"), - ("import Foo as Foo", "Foo.vy"), - ("from a import Foo", "a/Foo.vy"), - ("from b.a import Foo", "b/a/Foo.vy"), - ("from .a import Foo", "./a/Foo.vy"), - ("from ..a import Foo", "../a/Foo.vy"), + ("import a as Foo", "a.vyi"), + ("import b.a as Foo", "b/a.vyi"), + ("import Foo as Foo", "Foo.vyi"), + ("from a import Foo", "a/Foo.vyi"), + ("from b.a import Foo", "b/a/Foo.vyi"), + ("from .a import Foo", "./a/Foo.vyi"), + ("from ..a import Foo", "../a/Foo.vyi"), ] @@ -252,11 +251,12 @@ def test_extract_file_interface_imports(code, filename, make_input_bundle): BAD_IMPORT_CODE = [ - ("import a", StructureException), # must alias absolute imports - ("import a as A\nimport a as A", NamespaceCollision), + ("import a as A\nimport a as A", DuplicateImport), + ("import a as A\nimport a as a", DuplicateImport), + ("from . import a\nimport a as a", DuplicateImport), + ("import a as a\nfrom . import a", DuplicateImport), ("from b import a\nfrom . import a", NamespaceCollision), - ("from . import a\nimport a as a", NamespaceCollision), - ("import a as a\nfrom . import a", NamespaceCollision), + ("import a\nimport c as a", NamespaceCollision), ] @@ -264,34 +264,50 @@ def test_extract_file_interface_imports(code, filename, make_input_bundle): def test_extract_file_interface_imports_raises( code, exception_type, assert_compile_failed, make_input_bundle ): - input_bundle = make_input_bundle({"a.vy": "", "b/a.vy": ""}) # dummy - assert_compile_failed(lambda: compile_code(code, input_bundle=input_bundle), exception_type) + input_bundle = make_input_bundle({"a.vyi": "", "b/a.vyi": "", "c.vyi": ""}) + with pytest.raises(exception_type): + compile_code(code, input_bundle=input_bundle) def test_external_call_to_interface(w3, get_contract, make_input_bundle): + token_interface = """ +@view +@external +def balanceOf(addr: address) -> uint256: + ... + +@external +def transfer(to: address, amount: uint256): + ... + """ + token_code = """ +import itoken as IToken + +implements: IToken + balanceOf: public(HashMap[address, uint256]) @external -def transfer(to: address, _value: uint256): - self.balanceOf[to] += _value +def transfer(to: address, amount: uint256): + self.balanceOf[to] += amount """ - input_bundle = make_input_bundle({"one.vy": token_code}) + input_bundle = make_input_bundle({"token.vy": token_code, "itoken.vyi": token_interface}) code = """ -import one as TokenCode +import itoken as IToken interface EPI: def test() -> uint256: view -token_address: TokenCode +token_address: IToken @external def __init__(_token_address: address): - self.token_address = TokenCode(_token_address) + self.token_address = IToken(_token_address) @external @@ -299,14 +315,15 @@ def test(): self.token_address.transfer(msg.sender, 1000) """ - erc20 = get_contract(token_code) - test_c = get_contract(code, *[erc20.address], input_bundle=input_bundle) + token = get_contract(token_code, input_bundle=input_bundle) + + test_c = get_contract(code, *[token.address], input_bundle=input_bundle) sender = w3.eth.accounts[0] - assert erc20.balanceOf(sender) == 0 + assert token.balanceOf(sender) == 0 test_c.test(transact={}) - assert erc20.balanceOf(sender) == 1000 + assert token.balanceOf(sender) == 1000 @pytest.mark.parametrize( @@ -320,26 +337,36 @@ def test(): ], ) def test_external_call_to_interface_kwarg(get_contract, kwarg, typ, expected, make_input_bundle): - code_a = f""" + interface_code = f""" +@external +@view +def foo(_max: {typ} = {kwarg}) -> {typ}: + ... + """ + code1 = f""" +import one as IContract + +implements: IContract + @external @view def foo(_max: {typ} = {kwarg}) -> {typ}: return _max """ - input_bundle = make_input_bundle({"one.vy": code_a}) + input_bundle = make_input_bundle({"one.vyi": interface_code}) - code_b = f""" -import one as ContractA + code2 = f""" +import one as IContract @external @view def bar(a_address: address) -> {typ}: - return ContractA(a_address).foo() + return IContract(a_address).foo() """ - contract_a = get_contract(code_a) - contract_b = get_contract(code_b, *[contract_a.address], input_bundle=input_bundle) + contract_a = get_contract(code1, input_bundle=input_bundle) + contract_b = get_contract(code2, *[contract_a.address], input_bundle=input_bundle) assert contract_b.bar(contract_a.address) == expected @@ -349,8 +376,8 @@ def test_external_call_to_builtin_interface(w3, get_contract): balanceOf: public(HashMap[address, uint256]) @external -def transfer(to: address, _value: uint256) -> bool: - self.balanceOf[to] += _value +def transfer(to: address, amount: uint256) -> bool: + self.balanceOf[to] += amount return True """ @@ -510,14 +537,14 @@ def returns_Bytes3() -> Bytes[3]: """ should_not_compile = """ -import BadJSONInterface as BadJSONInterface +import BadJSONInterface @external def foo(x: BadJSONInterface) -> Bytes[2]: return slice(x.returns_Bytes3(), 0, 2) """ code = """ -import BadJSONInterface as BadJSONInterface +import BadJSONInterface foo: BadJSONInterface @@ -578,10 +605,10 @@ def balanceOf(owner: address) -> uint256: @external @view def balanceOf(owner: address) -> uint256: - pass + ... """ - input_bundle = make_input_bundle({"balanceof.vy": interface_code}) + input_bundle = make_input_bundle({"balanceof.vyi": interface_code}) c = get_contract(code, input_bundle=input_bundle) @@ -592,7 +619,7 @@ def test_simple_implements(make_input_bundle): interface_code = """ @external def foo() -> uint256: - pass + ... """ code = """ @@ -605,7 +632,7 @@ def foo() -> uint256: return 1 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) assert compile_code(code, input_bundle=input_bundle) is not None diff --git a/tests/functional/codegen/test_selector_table_stability.py b/tests/functional/codegen/test_selector_table_stability.py index 3302ff5009..27f82416d6 100644 --- a/tests/functional/codegen/test_selector_table_stability.py +++ b/tests/functional/codegen/test_selector_table_stability.py @@ -14,7 +14,7 @@ def test_dense_jumptable_stability(): # test that the selector table data is stable across different runs # (tox should provide different PYTHONHASHSEEDs). - expected_asm = """{ DATA _sym_BUCKET_HEADERS b'\\x0bB' _sym_bucket_0 b'\\n' b'+\\x8d' _sym_bucket_1 b'\\x0c' b'\\x00\\x85' _sym_bucket_2 b'\\x08' } { DATA _sym_bucket_1 b'\\xd8\\xee\\xa1\\xe8' _sym_external_foo6___3639517672 b'\\x05' b'\\xd2\\x9e\\xe0\\xf9' _sym_external_foo0___3533627641 b'\\x05' b'\\x05\\xf1\\xe0_' _sym_external_foo2___99737695 b'\\x05' b'\\x91\\t\\xb4{' _sym_external_foo23___2433332347 b'\\x05' b'np3\\x7f' _sym_external_foo11___1852846975 b'\\x05' b'&\\xf5\\x96\\xf9' _sym_external_foo13___653629177 b'\\x05' b'\\x04ga\\xeb' _sym_external_foo14___73884139 b'\\x05' b'\\x89\\x06\\xad\\xc6' _sym_external_foo17___2298916294 b'\\x05' b'\\xe4%\\xac\\xd1' _sym_external_foo4___3827674321 b'\\x05' b'yj\\x01\\xac' _sym_external_foo7___2036990380 b'\\x05' b'\\xf1\\xe6K\\xe5' _sym_external_foo29___4058401765 b'\\x05' b'\\xd2\\x89X\\xb8' _sym_external_foo3___3532216504 b'\\x05' } { DATA _sym_bucket_2 b'\\x06p\\xffj' _sym_external_foo25___108068714 b'\\x05' b'\\x964\\x99I' _sym_external_foo24___2520029513 b'\\x05' b's\\x81\\xe7\\xc1' _sym_external_foo10___1937893313 b'\\x05' b'\\x85\\xad\\xc11' _sym_external_foo28___2242756913 b'\\x05' b'\\xfa"\\xb1\\xed' _sym_external_foo5___4196577773 b'\\x05' b'A\\xe7[\\x05' _sym_external_foo22___1105681157 b'\\x05' b'\\xd3\\x89U\\xe8' _sym_external_foo1___3548993000 b'\\x05' b'hL\\xf8\\xf3' _sym_external_foo20___1749874931 b'\\x05' } { DATA _sym_bucket_0 b'\\xee\\xd9\\x1d\\xe3' _sym_external_foo9___4007206371 b'\\x05' b'a\\xbc\\x1ch' _sym_external_foo16___1639717992 b'\\x05' b'\\xd3*\\xa7\\x0c' _sym_external_foo21___3542787852 b'\\x05' b'\\x18iG\\xd9' _sym_external_foo19___409552857 b'\\x05' b'\\n\\xf1\\xf9\\x7f' _sym_external_foo18___183630207 b'\\x05' b')\\xda\\xd7`' _sym_external_foo27___702207840 b'\\x05' b'2\\xf6\\xaa\\xda' _sym_external_foo12___855026394 b'\\x05' b'\\xbe\\xb5\\x05\\xf5' _sym_external_foo15___3199534581 b'\\x05' b'\\xfc\\xa7_\\xe6' _sym_external_foo8___4238827494 b'\\x05' b'\\x1b\\x12C8' _sym_external_foo26___454181688 b'\\x05' } }""" # noqa: E501 + expected_asm = """{ DATA _sym_BUCKET_HEADERS b\'\\x0bB\' _sym_bucket_0 b\'\\n\' b\'+\\x8d\' _sym_bucket_1 b\'\\x0c\' b\'\\x00\\x85\' _sym_bucket_2 b\'\\x08\' } { DATA _sym_bucket_1 b\'\\xd8\\xee\\xa1\\xe8\' _sym_external 6 foo6()3639517672 b\'\\x05\' b\'\\xd2\\x9e\\xe0\\xf9\' _sym_external 0 foo0()3533627641 b\'\\x05\' b\'\\x05\\xf1\\xe0_\' _sym_external 2 foo2()99737695 b\'\\x05\' b\'\\x91\\t\\xb4{\' _sym_external 23 foo23()2433332347 b\'\\x05\' b\'np3\\x7f\' _sym_external 11 foo11()1852846975 b\'\\x05\' b\'&\\xf5\\x96\\xf9\' _sym_external 13 foo13()653629177 b\'\\x05\' b\'\\x04ga\\xeb\' _sym_external 14 foo14()73884139 b\'\\x05\' b\'\\x89\\x06\\xad\\xc6\' _sym_external 17 foo17()2298916294 b\'\\x05\' b\'\\xe4%\\xac\\xd1\' _sym_external 4 foo4()3827674321 b\'\\x05\' b\'yj\\x01\\xac\' _sym_external 7 foo7()2036990380 b\'\\x05\' b\'\\xf1\\xe6K\\xe5\' _sym_external 29 foo29()4058401765 b\'\\x05\' b\'\\xd2\\x89X\\xb8\' _sym_external 3 foo3()3532216504 b\'\\x05\' } { DATA _sym_bucket_2 b\'\\x06p\\xffj\' _sym_external 25 foo25()108068714 b\'\\x05\' b\'\\x964\\x99I\' _sym_external 24 foo24()2520029513 b\'\\x05\' b\'s\\x81\\xe7\\xc1\' _sym_external 10 foo10()1937893313 b\'\\x05\' b\'\\x85\\xad\\xc11\' _sym_external 28 foo28()2242756913 b\'\\x05\' b\'\\xfa"\\xb1\\xed\' _sym_external 5 foo5()4196577773 b\'\\x05\' b\'A\\xe7[\\x05\' _sym_external 22 foo22()1105681157 b\'\\x05\' b\'\\xd3\\x89U\\xe8\' _sym_external 1 foo1()3548993000 b\'\\x05\' b\'hL\\xf8\\xf3\' _sym_external 20 foo20()1749874931 b\'\\x05\' } { DATA _sym_bucket_0 b\'\\xee\\xd9\\x1d\\xe3\' _sym_external 9 foo9()4007206371 b\'\\x05\' b\'a\\xbc\\x1ch\' _sym_external 16 foo16()1639717992 b\'\\x05\' b\'\\xd3*\\xa7\\x0c\' _sym_external 21 foo21()3542787852 b\'\\x05\' b\'\\x18iG\\xd9\' _sym_external 19 foo19()409552857 b\'\\x05\' b\'\\n\\xf1\\xf9\\x7f\' _sym_external 18 foo18()183630207 b\'\\x05\' b\')\\xda\\xd7`\' _sym_external 27 foo27()702207840 b\'\\x05\' b\'2\\xf6\\xaa\\xda\' _sym_external 12 foo12()855026394 b\'\\x05\' b\'\\xbe\\xb5\\x05\\xf5\' _sym_external 15 foo15()3199534581 b\'\\x05\' b\'\\xfc\\xa7_\\xe6\' _sym_external 8 foo8()4238827494 b\'\\x05\' b\'\\x1b\\x12C8\' _sym_external 26 foo26()454181688 b\'\\x05\' } }""" # noqa: E501 assert expected_asm in output["asm"] diff --git a/tests/functional/codegen/test_stateless_modules.py b/tests/functional/codegen/test_stateless_modules.py new file mode 100644 index 0000000000..8e634e5868 --- /dev/null +++ b/tests/functional/codegen/test_stateless_modules.py @@ -0,0 +1,335 @@ +import hypothesis.strategies as st +import pytest +from hypothesis import given, settings + +from vyper import compiler +from vyper.exceptions import ( + CallViolation, + DuplicateImport, + ImportCycle, + StructureException, + TypeMismatch, +) + +# test modules which have no variables - "libraries" + + +def test_simple_library(get_contract, make_input_bundle, w3): + library_source = """ +@internal +def foo() -> uint256: + return block.number + 1 + """ + main = """ +import library + +@external +def bar() -> uint256: + return library.foo() - 1 + """ + input_bundle = make_input_bundle({"library.vy": library_source}) + + c = get_contract(main, input_bundle=input_bundle) + + assert c.bar() == w3.eth.block_number + + +# is this the best place for this? +def test_import_cycle(make_input_bundle): + code_a = "import b\n" + code_b = "import a\n" + + input_bundle = make_input_bundle({"a.vy": code_a, "b.vy": code_b}) + + with pytest.raises(ImportCycle): + compiler.compile_code(code_a, input_bundle=input_bundle) + + +# test we can have a function in the library with the same name as +# in the main contract +def test_library_function_same_name(get_contract, make_input_bundle): + library = """ +@internal +def foo() -> uint256: + return 10 + """ + + main = """ +import library + +@internal +def foo() -> uint256: + return 100 + +@external +def self_foo() -> uint256: + return self.foo() + +@external +def library_foo() -> uint256: + return library.foo() + """ + + input_bundle = make_input_bundle({"library.vy": library}) + + c = get_contract(main, input_bundle=input_bundle) + + assert c.self_foo() == 100 + assert c.library_foo() == 10 + + +def test_transitive_import(get_contract, make_input_bundle): + a = """ +@internal +def foo() -> uint256: + return 1 + """ + b = """ +import a + +@internal +def bar() -> uint256: + return a.foo() + 1 + """ + c = """ +import b + +@external +def baz() -> uint256: + return b.bar() + 1 + """ + # more complicated call graph, with `a` imported twice. + d = """ +import b +import a + +@external +def qux() -> uint256: + s: uint256 = a.foo() + return s + b.bar() + 1 + """ + input_bundle = make_input_bundle({"a.vy": a, "b.vy": b, "c.vy": c, "d.vy": d}) + + contract = get_contract(c, input_bundle=input_bundle) + assert contract.baz() == 3 + contract = get_contract(d, input_bundle=input_bundle) + assert contract.qux() == 4 + + +def test_cannot_call_library_external_functions(make_input_bundle): + library_source = """ +@external +def foo(): + pass + """ + contract_source = """ +import library + +@external +def bar(): + library.foo() + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(CallViolation): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_library_external_functions_not_in_abi(get_contract, make_input_bundle): + library_source = """ +@external +def foo(): + pass + """ + contract_source = """ +import library + +@external +def bar(): + pass + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + c = get_contract(contract_source, input_bundle=input_bundle) + assert not hasattr(c, "foo") + + +def test_library_structs(get_contract, make_input_bundle): + library_source = """ +struct SomeStruct: + x: uint256 + +@internal +def foo() -> SomeStruct: + return SomeStruct({x: 1}) + """ + contract_source = """ +import library + +@external +def bar(s: library.SomeStruct): + pass + +@external +def baz() -> library.SomeStruct: + return library.SomeStruct({x: 2}) + +@external +def qux() -> library.SomeStruct: + return library.foo() + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + c = get_contract(contract_source, input_bundle=input_bundle) + + assert c.bar((1,)) == [] + + assert c.baz() == (2,) + assert c.qux() == (1,) + + +# test calls to library functions in statement position +def test_library_statement_calls(get_contract, make_input_bundle, assert_tx_failed): + library_source = """ +from vyper.interfaces import ERC20 +@internal +def check_adds_to_ten(x: uint256, y: uint256): + assert x + y == 10 + """ + contract_source = """ +import library + +counter: public(uint256) + +@external +def foo(x: uint256): + library.check_adds_to_ten(3, x) + self.counter = x + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + + c = get_contract(contract_source, input_bundle=input_bundle) + + c.foo(7, transact={}) + + assert c.counter() == 7 + + assert_tx_failed(lambda: c.foo(8)) + + +def test_library_is_typechecked(make_input_bundle): + library_source = """ +@internal +def foo(): + asdlkfjasdflkajsdf + """ + contract_source = """ +import library + """ + + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(StructureException): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_library_is_typechecked2(make_input_bundle): + # check that we typecheck against imported function signatures + library_source = """ +@internal +def foo() -> uint256: + return 1 + """ + contract_source = """ +import library + +@external +def foo() -> bytes32: + return library.foo() + """ + + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(TypeMismatch): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_reject_duplicate_imports(make_input_bundle): + library_source = """ + """ + + contract_source = """ +import library +import library as library2 + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(DuplicateImport): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_nested_module_access(get_contract, make_input_bundle): + lib1 = """ +import lib2 + +@internal +def lib2_foo() -> uint256: + return lib2.foo() + """ + lib2 = """ +@internal +def foo() -> uint256: + return 1337 + """ + + main = """ +import lib1 +import lib2 + +@external +def lib1_foo() -> uint256: + return lib1.lib2_foo() + +@external +def lib2_foo() -> uint256: + return lib1.lib2.foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.lib1_foo() == c.lib2_foo() == 1337 + + +_int_127 = st.integers(min_value=0, max_value=127) +_bytes_128 = st.binary(min_size=0, max_size=128) + + +def test_slice_builtin(get_contract, make_input_bundle): + lib = """ +@internal +def slice_input(x: Bytes[128], start: uint256, length: uint256) -> Bytes[128]: + return slice(x, start, length) + """ + + main = """ +import lib +@external +def lib_slice_input(x: Bytes[128], start: uint256, length: uint256) -> Bytes[128]: + return lib.slice_input(x, start, length) + +@external +def slice_input(x: Bytes[128], start: uint256, length: uint256) -> Bytes[128]: + return slice(x, start, length) + """ + input_bundle = make_input_bundle({"lib.vy": lib}) + c = get_contract(main, input_bundle=input_bundle) + + # use an inner test so that we can cache the result of get_contract() + @given(start=_int_127, length=_int_127, bytesdata=_bytes_128) + @settings(max_examples=100) + def _test(bytesdata, start, length): + # surjectively map start into allowable range + if start > len(bytesdata): + start = start % (len(bytesdata) or 1) + # surjectively map length into allowable range + if length > (len(bytesdata) - start): + length = length % ((len(bytesdata) - start) or 1) + main_result = c.slice_input(bytesdata, start, length) + library_result = c.lib_slice_input(bytesdata, start, length) + assert main_result == library_result == bytesdata[start : start + length] + + _test() diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index aa0286cfa5..7dd8c35929 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -92,7 +92,7 @@ def from_grammar() -> st.SearchStrategy[str]: # Avoid examples with *only* single or double quote docstrings -# because they trigger a trivial compiler bug +# because they trigger a trivial parser bug SINGLE_QUOTE_DOCSTRING = re.compile(r"^'''.*'''$") DOUBLE_QUOTE_DOCSTRING = re.compile(r'^""".*"""$') diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index 9100389dbd..a672ed7b88 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -376,17 +376,12 @@ def test_interfaces_success(good_code): def test_imports_and_implements_within_interface(make_input_bundle): interface_code = """ -from vyper.interfaces import ERC20 -import foo.bar as Baz - -implements: Baz - @external def foobar(): - pass + ... """ - input_bundle = make_input_bundle({"foo.vy": interface_code}) + input_bundle = make_input_bundle({"foo.vyi": interface_code}) code = """ import foo as Foo diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index 47483c493c..d413340083 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -37,9 +37,9 @@ def foo(): @pytest.mark.parametrize("code", code_invalid_checksum) -def test_invalid_checksum(code): +def test_invalid_checksum(code, dummy_input_bundle): vyper_module = vy_ast.parse_to_ast(code) with pytest.raises(InvalidLiteral): vy_ast.validation.validate_literal_nodes(vyper_module) - semantics.validate_semantics(vyper_module, {}) + semantics.validate_semantics(vyper_module, dummy_input_bundle) diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 68a07178bb..16ce6fe631 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -1,7 +1,6 @@ import ast as python_ast -from vyper.ast.annotation import annotate_python_ast -from vyper.ast.pre_parser import pre_parse +from vyper.ast.parse import annotate_python_ast, pre_parse class AssertionVisitor(python_ast.NodeVisitor): diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index 1f60c9ac8b..dc49f72561 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -1,7 +1,8 @@ import json from vyper import compiler -from vyper.ast.utils import ast_to_dict, dict_to_ast, parse_to_ast +from vyper.ast.parse import parse_to_ast +from vyper.ast.utils import ast_to_dict, dict_to_ast def get_node_ids(ast_struct, ids=None): @@ -40,7 +41,7 @@ def test_basic_ast(): code = """ a: int128 """ - dict_out = compiler.compile_code(code, output_formats=["ast_dict"]) + dict_out = compiler.compile_code(code, output_formats=["ast_dict"], source_id=0) assert dict_out["ast_dict"]["ast"]["body"][0] == { "annotation": { "ast_type": "Name", @@ -89,7 +90,7 @@ def foo() -> uint256: view def foo() -> uint256: return 1 """ - dict_out = compiler.compile_code(code, output_formats=["ast_dict"]) + dict_out = compiler.compile_code(code, output_formats=["ast_dict"], source_id=0) assert dict_out["ast_dict"]["ast"]["body"][1] == { "col_offset": 0, "annotation": { diff --git a/tests/unit/ast/test_parser.py b/tests/unit/ast/test_parser.py index c47bf40bfa..e0bfcbc2ef 100644 --- a/tests/unit/ast/test_parser.py +++ b/tests/unit/ast/test_parser.py @@ -1,4 +1,4 @@ -from vyper.ast.utils import parse_to_ast +from vyper.ast.parse import parse_to_ast def test_ast_equal(): diff --git a/tests/unit/cli/outputs/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py similarity index 100% rename from tests/unit/cli/outputs/test_storage_layout.py rename to tests/unit/cli/storage_layout/test_storage_layout.py diff --git a/tests/unit/cli/outputs/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py similarity index 98% rename from tests/unit/cli/outputs/test_storage_layout_overrides.py rename to tests/unit/cli/storage_layout/test_storage_layout_overrides.py index 94e0faeb37..f4c11b7ae6 100644 --- a/tests/unit/cli/outputs/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py @@ -103,7 +103,7 @@ def test_overflow(): storage_layout_override = {"x": {"slot": 2**256 - 1, "type": "uint256[2]"}} with pytest.raises( - StorageLayoutException, match=f"Invalid storage slot for var x, out of bounds: {2**256}\n" + StorageLayoutException, match=f"Invalid storage slot for var x, out of bounds: {2**256}" ): compile_code( code, output_formats=["layout"], storage_layout_override=storage_layout_override diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index 2a16efa777..f6e3a51a4b 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -30,93 +30,100 @@ def test_invalid_root_path(): compile_files([], [], root_folder="path/that/does/not/exist") -FOO_CODE = """ -{} - -struct FooStruct: - foo_: uint256 +CONTRACT_CODE = """ +{import_stmt} @external -def foo() -> FooStruct: - return FooStruct({{foo_: 13}}) +def foo() -> {alias}.FooStruct: + return {alias}.FooStruct({{foo_: 13}}) @external -def bar(a: address) -> FooStruct: - return {}(a).bar() +def bar(a: address) -> {alias}.FooStruct: + return {alias}(a).bar() """ -BAR_CODE = """ +INTERFACE_CODE = """ struct FooStruct: foo_: uint256 + +@external +def foo() -> FooStruct: + ... + @external def bar() -> FooStruct: - return FooStruct({foo_: 13}) + ... """ SAME_FOLDER_IMPORT_STMT = [ - ("import Bar as Bar", "Bar"), - ("import contracts.Bar as Bar", "Bar"), - ("from . import Bar", "Bar"), - ("from contracts import Bar", "Bar"), - ("from ..contracts import Bar", "Bar"), - ("from . import Bar as FooBar", "FooBar"), - ("from contracts import Bar as FooBar", "FooBar"), - ("from ..contracts import Bar as FooBar", "FooBar"), + ("import IFoo as IFoo", "IFoo"), + ("import contracts.IFoo as IFoo", "IFoo"), + ("from . import IFoo", "IFoo"), + ("from contracts import IFoo", "IFoo"), + ("from ..contracts import IFoo", "IFoo"), + ("from . import IFoo as FooBar", "FooBar"), + ("from contracts import IFoo as FooBar", "FooBar"), + ("from ..contracts import IFoo as FooBar", "FooBar"), ] @pytest.mark.parametrize("import_stmt,alias", SAME_FOLDER_IMPORT_STMT) def test_import_same_folder(import_stmt, alias, tmp_path, make_file): foo = "contracts/foo.vy" - make_file("contracts/foo.vy", FOO_CODE.format(import_stmt, alias)) - make_file("contracts/Bar.vy", BAR_CODE) + make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) + make_file("contracts/IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) SUBFOLDER_IMPORT_STMT = [ - ("import other.Bar as Bar", "Bar"), - ("import contracts.other.Bar as Bar", "Bar"), - ("from other import Bar", "Bar"), - ("from contracts.other import Bar", "Bar"), - ("from .other import Bar", "Bar"), - ("from ..contracts.other import Bar", "Bar"), - ("from other import Bar as FooBar", "FooBar"), - ("from contracts.other import Bar as FooBar", "FooBar"), - ("from .other import Bar as FooBar", "FooBar"), - ("from ..contracts.other import Bar as FooBar", "FooBar"), + ("import other.IFoo as IFoo", "IFoo"), + ("import contracts.other.IFoo as IFoo", "IFoo"), + ("from other import IFoo", "IFoo"), + ("from contracts.other import IFoo", "IFoo"), + ("from .other import IFoo", "IFoo"), + ("from ..contracts.other import IFoo", "IFoo"), + ("from other import IFoo as FooBar", "FooBar"), + ("from contracts.other import IFoo as FooBar", "FooBar"), + ("from .other import IFoo as FooBar", "FooBar"), + ("from ..contracts.other import IFoo as FooBar", "FooBar"), ] @pytest.mark.parametrize("import_stmt, alias", SUBFOLDER_IMPORT_STMT) def test_import_subfolder(import_stmt, alias, tmp_path, make_file): - foo = make_file("contracts/foo.vy", (FOO_CODE.format(import_stmt, alias))) - make_file("contracts/other/Bar.vy", BAR_CODE) + foo = make_file( + "contracts/foo.vy", (CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) + ) + make_file("contracts/other/IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) OTHER_FOLDER_IMPORT_STMT = [ - ("import interfaces.Bar as Bar", "Bar"), - ("from interfaces import Bar", "Bar"), - ("from ..interfaces import Bar", "Bar"), - ("from interfaces import Bar as FooBar", "FooBar"), - ("from ..interfaces import Bar as FooBar", "FooBar"), + ("import interfaces.IFoo as IFoo", "IFoo"), + ("from interfaces import IFoo", "IFoo"), + ("from ..interfaces import IFoo", "IFoo"), + ("from interfaces import IFoo as FooBar", "FooBar"), + ("from ..interfaces import IFoo as FooBar", "FooBar"), ] @pytest.mark.parametrize("import_stmt, alias", OTHER_FOLDER_IMPORT_STMT) def test_import_other_folder(import_stmt, alias, tmp_path, make_file): - foo = make_file("contracts/foo.vy", FOO_CODE.format(import_stmt, alias)) - make_file("interfaces/Bar.vy", BAR_CODE) + foo = make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) + make_file("interfaces/IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) def test_import_parent_folder(tmp_path, make_file): - foo = make_file("contracts/baz/foo.vy", FOO_CODE.format("from ... import Bar", "Bar")) - make_file("Bar.vy", BAR_CODE) + foo = make_file( + "contracts/baz/foo.vy", + CONTRACT_CODE.format(import_stmt="from ... import IFoo", alias="IFoo"), + ) + make_file("IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) @@ -125,62 +132,60 @@ def test_import_parent_folder(tmp_path, make_file): META_IMPORT_STMT = [ - "import Meta as Meta", - "import contracts.Meta as Meta", - "from . import Meta", - "from contracts import Meta", + "import ISelf as ISelf", + "import contracts.ISelf as ISelf", + "from . import ISelf", + "from contracts import ISelf", ] @pytest.mark.parametrize("import_stmt", META_IMPORT_STMT) def test_import_self_interface(import_stmt, tmp_path, make_file): - # a contract can access its derived interface by importing itself - code = f""" -{import_stmt} - + interface_code = """ struct FooStruct: foo_: uint256 @external def know_thyself(a: address) -> FooStruct: - return Meta(a).be_known() + ... @external def be_known() -> FooStruct: - return FooStruct({{foo_: 42}}) + ... """ - meta = make_file("contracts/Meta.vy", code) - - assert compile_files([meta], ["combined_json"], root_folder=tmp_path) + code = f""" +{import_stmt} +@external +def know_thyself(a: address) -> ISelf.FooStruct: + return ISelf(a).be_known() -DERIVED_IMPORT_STMT_BAZ = ["import Foo as Foo", "from . import Foo"] +@external +def be_known() -> ISelf.FooStruct: + return ISelf.FooStruct({{foo_: 42}}) + """ + make_file("contracts/ISelf.vyi", interface_code) + meta = make_file("contracts/Self.vy", code) -DERIVED_IMPORT_STMT_FOO = ["import Bar as Bar", "from . import Bar"] + assert compile_files([meta], ["combined_json"], root_folder=tmp_path) -@pytest.mark.parametrize("import_stmt_baz", DERIVED_IMPORT_STMT_BAZ) -@pytest.mark.parametrize("import_stmt_foo", DERIVED_IMPORT_STMT_FOO) -def test_derived_interface_imports(import_stmt_baz, import_stmt_foo, tmp_path, make_file): - # contracts-as-interfaces should be able to contain import statements +# implement IFoo in another contract for fun +@pytest.mark.parametrize("import_stmt_foo,alias", SAME_FOLDER_IMPORT_STMT) +def test_another_interface_implementation(import_stmt_foo, alias, tmp_path, make_file): baz_code = f""" -{import_stmt_baz} - -struct FooStruct: - foo_: uint256 +{import_stmt_foo} @external -def foo(a: address) -> FooStruct: - return Foo(a).foo() +def foo(a: address) -> {alias}.FooStruct: + return {alias}(a).foo() @external -def bar(_foo: address, _bar: address) -> FooStruct: - return Foo(_foo).bar(_bar) +def bar(_foo: address) -> {alias}.FooStruct: + return {alias}(_foo).bar() """ - - make_file("Foo.vy", FOO_CODE.format(import_stmt_foo, "Bar")) - make_file("Bar.vy", BAR_CODE) - baz = make_file("Baz.vy", baz_code) + make_file("contracts/IFoo.vyi", INTERFACE_CODE) + baz = make_file("contracts/Baz.vy", baz_code) assert compile_files([baz], ["combined_json"], root_folder=tmp_path) @@ -207,15 +212,36 @@ def test_local_namespace(make_file, tmp_path): make_file(filename, code) paths.append(filename) - for file_name in ("foo.vy", "bar.vy"): - make_file(file_name, BAR_CODE) + for file_name in ("foo.vyi", "bar.vyi"): + make_file(file_name, INTERFACE_CODE) assert compile_files(paths, ["combined_json"], root_folder=tmp_path) def test_compile_outside_root_path(tmp_path, make_file): # absolute paths relative to "." - foo = make_file("foo.vy", FOO_CODE.format("import bar as Bar", "Bar")) - bar = make_file("bar.vy", BAR_CODE) + make_file("ifoo.vyi", INTERFACE_CODE) + foo = make_file("foo.vy", CONTRACT_CODE.format(import_stmt="import ifoo as IFoo", alias="IFoo")) + + assert compile_files([foo], ["combined_json"], root_folder=".") + + +def test_import_library(tmp_path, make_file): + library_source = """ +@internal +def foo() -> uint256: + return block.number + 1 + """ + + contract_source = """ +import lib + +@external +def foo() -> uint256: + return lib.foo() + """ + + make_file("lib.vy", library_source) + contract_file = make_file("contract.vy", contract_source) - assert compile_files([foo, bar], ["combined_json"], root_folder=".") + assert compile_files([contract_file], ["combined_json"], root_folder=tmp_path) is not None diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index 732762d72b..a50946ba21 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -1,30 +1,55 @@ import json +from pathlib import PurePath import pytest import vyper -from vyper.cli.vyper_json import compile_from_input_dict, compile_json, exc_handler_to_dict -from vyper.compiler import OUTPUT_FORMATS, compile_code +from vyper.cli.vyper_json import ( + compile_from_input_dict, + compile_json, + exc_handler_to_dict, + get_inputs, +) +from vyper.compiler import OUTPUT_FORMATS, compile_code, compile_from_file_input +from vyper.compiler.input_bundle import JSONInputBundle from vyper.exceptions import InvalidType, JSONError, SyntaxException FOO_CODE = """ -import contracts.bar as Bar +import contracts.ibar as IBar + +import contracts.library as library @external def foo(a: address) -> bool: - return Bar(a).bar(1) + return IBar(a).bar(1) @external def baz() -> uint256: - return self.balance + return self.balance + library.foo() """ BAR_CODE = """ +import contracts.ibar as IBar + +implements: IBar + @external def bar(a: uint256) -> bool: return True """ +BAR_VYI = """ +@external +def bar(a: uint256) -> bool: + ... +""" + +LIBRARY_CODE = """ +@internal +def foo() -> uint256: + return block.number + 1 +""" + BAD_SYNTAX_CODE = """ def bar()>: """ @@ -52,6 +77,7 @@ def input_json(): "language": "Vyper", "sources": { "contracts/foo.vy": {"content": FOO_CODE}, + "contracts/library.vy": {"content": LIBRARY_CODE}, "contracts/bar.vy": {"content": BAR_CODE}, }, "interfaces": {"contracts/ibar.json": {"abi": BAR_ABI}}, @@ -59,6 +85,14 @@ def input_json(): } +@pytest.fixture(scope="function") +def input_bundle(input_json): + # CMC 2023-12-11 maybe input_json -> JSONInputBundle should be a helper + # function in `vyper_json.py`. + sources = get_inputs(input_json) + return JSONInputBundle(sources, search_paths=[PurePath(".")]) + + # test string and dict inputs both work def test_string_input(input_json): assert compile_json(input_json) == compile_json(json.dumps(input_json)) @@ -77,29 +111,39 @@ def test_keyerror_becomes_jsonerror(input_json): compile_json(input_json) -def test_compile_json(input_json, make_input_bundle): - input_bundle = make_input_bundle({"contracts/bar.vy": BAR_CODE}) +def test_compile_json(input_json, input_bundle): + foo_input = input_bundle.load_file("contracts/foo.vy") + foo = compile_from_file_input( + foo_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + ) - foo = compile_code( - FOO_CODE, - source_id=0, - contract_name="contracts/foo.vy", - output_formats=OUTPUT_FORMATS, - input_bundle=input_bundle, + library_input = input_bundle.load_file("contracts/library.vy") + library = compile_from_file_input( + library_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle ) - bar = compile_code( - BAR_CODE, source_id=1, contract_name="contracts/bar.vy", output_formats=OUTPUT_FORMATS + + bar_input = input_bundle.load_file("contracts/bar.vy") + bar = compile_from_file_input( + bar_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle ) - compile_code_results = {"contracts/bar.vy": bar, "contracts/foo.vy": foo} + compile_code_results = { + "contracts/bar.vy": bar, + "contracts/library.vy": library, + "contracts/foo.vy": foo, + } output_json = compile_json(input_json) - assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + assert list(output_json["contracts"].keys()) == [ + "contracts/foo.vy", + "contracts/library.vy", + "contracts/bar.vy", + ] assert sorted(output_json.keys()) == ["compiler", "contracts", "sources"] assert output_json["compiler"] == f"vyper-{vyper.__version__}" - for source_id, contract_name in enumerate(["foo", "bar"]): + for source_id, contract_name in [(0, "foo"), (2, "library"), (3, "bar")]: path = f"contracts/{contract_name}.vy" data = compile_code_results[path] assert output_json["sources"][path] == {"id": source_id, "ast": data["ast_dict"]["ast"]} @@ -123,13 +167,28 @@ def test_compile_json(input_json, make_input_bundle): } -def test_different_outputs(make_input_bundle, input_json): +def test_compilation_targets(input_json): + output_json = compile_json(input_json) + assert list(output_json["contracts"].keys()) == [ + "contracts/foo.vy", + "contracts/library.vy", + "contracts/bar.vy", + ] + + # omit library.vy + input_json["settings"]["outputSelection"] = {"contracts/foo.vy": "*", "contracts/bar.vy": "*"} + output_json = compile_json(input_json) + + assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + + +def test_different_outputs(input_bundle, input_json): input_json["settings"]["outputSelection"] = { "contracts/bar.vy": "*", "contracts/foo.vy": ["evm.methodIdentifiers"], } output_json = compile_json(input_json) - assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + assert list(output_json["contracts"].keys()) == ["contracts/bar.vy", "contracts/foo.vy"] assert sorted(output_json.keys()) == ["compiler", "contracts", "sources"] assert output_json["compiler"] == f"vyper-{vyper.__version__}" @@ -143,10 +202,9 @@ def test_different_outputs(make_input_bundle, input_json): assert sorted(foo.keys()) == ["evm"] # check method_identifiers - input_bundle = make_input_bundle({"contracts/bar.vy": BAR_CODE}) method_identifiers = compile_code( FOO_CODE, - contract_name="contracts/foo.vy", + contract_path="contracts/foo.vy", output_formats=["method_identifiers"], input_bundle=input_bundle, )["method_identifiers"] @@ -204,11 +262,12 @@ def get(filename, contractname): return result["contracts"][filename][contractname]["evm"]["deployedBytecode"]["sourceMap"] assert get("contracts/foo.vy", "foo").startswith("-1:-1:0") - assert get("contracts/bar.vy", "bar").startswith("-1:-1:1") + assert get("contracts/library.vy", "library").startswith("-1:-1:2") + assert get("contracts/bar.vy", "bar").startswith("-1:-1:3") def test_relative_import_paths(input_json): - input_json["sources"]["contracts/potato/baz/baz.vy"] = {"content": """from ... import foo"""} - input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": """from . import baz"""} - input_json["sources"]["contracts/potato/footato.vy"] = {"content": """from baz import baz"""} + input_json["sources"]["contracts/potato/baz/baz.vy"] = {"content": "from ... import foo"} + input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": "from . import baz"} + input_json["sources"]["contracts/potato/footato.vy"] = {"content": "from baz import baz"} compile_from_input_dict(input_json) diff --git a/tests/unit/cli/vyper_json/test_get_inputs.py b/tests/unit/cli/vyper_json/test_get_inputs.py index 6e323a91bd..c91cc750f2 100644 --- a/tests/unit/cli/vyper_json/test_get_inputs.py +++ b/tests/unit/cli/vyper_json/test_get_inputs.py @@ -2,7 +2,7 @@ import pytest -from vyper.cli.vyper_json import get_compilation_targets, get_inputs +from vyper.cli.vyper_json import get_inputs from vyper.exceptions import JSONError from vyper.utils import keccak256 @@ -122,9 +122,6 @@ def test_interfaces_output(): "interface.folder/bar2.vy": {"content": BAR_CODE}, }, } - targets = get_compilation_targets(input_json) - assert targets == [PurePath("foo.vy")] - result = get_inputs(input_json) assert result == { PurePath("foo.vy"): {"content": FOO_CODE}, diff --git a/tests/unit/cli/vyper_json/test_output_selection.py b/tests/unit/cli/vyper_json/test_output_selection.py index 78ad7404f2..5383190a66 100644 --- a/tests/unit/cli/vyper_json/test_output_selection.py +++ b/tests/unit/cli/vyper_json/test_output_selection.py @@ -8,53 +8,61 @@ def test_no_outputs(): with pytest.raises(KeyError): - get_output_formats({}, {}) + get_output_formats({}) def test_invalid_output(): - input_json = {"settings": {"outputSelection": {"foo.vy": ["abi", "foobar"]}}} - targets = [PurePath("foo.vy")] + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": ["abi", "foobar"]}}, + } with pytest.raises(JSONError): - get_output_formats(input_json, targets) + get_output_formats(input_json) def test_unknown_contract(): - input_json = {"settings": {"outputSelection": {"bar.vy": ["abi"]}}} - targets = [PurePath("foo.vy")] + input_json = {"sources": {}, "settings": {"outputSelection": {"bar.vy": ["abi"]}}} with pytest.raises(JSONError): - get_output_formats(input_json, targets) + get_output_formats(input_json) @pytest.mark.parametrize("output", TRANSLATE_MAP.items()) def test_translate_map(output): - input_json = {"settings": {"outputSelection": {"foo.vy": [output[0]]}}} - targets = [PurePath("foo.vy")] - assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): [output[1]]} + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": [output[0]]}}, + } + assert get_output_formats(input_json) == {PurePath("foo.vy"): [output[1]]} def test_star(): - input_json = {"settings": {"outputSelection": {"*": ["*"]}}} - targets = [PurePath("foo.vy"), PurePath("bar.vy")] + input_json = { + "sources": {"foo.vy": "", "bar.vy": ""}, + "settings": {"outputSelection": {"*": ["*"]}}, + } expected = sorted(set(TRANSLATE_MAP.values())) - result = get_output_formats(input_json, targets) + result = get_output_formats(input_json) assert result == {PurePath("foo.vy"): expected, PurePath("bar.vy"): expected} def test_evm(): - input_json = {"settings": {"outputSelection": {"foo.vy": ["abi", "evm"]}}} - targets = [PurePath("foo.vy")] + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": ["abi", "evm"]}}, + } expected = ["abi"] + sorted(v for k, v in TRANSLATE_MAP.items() if k.startswith("evm")) - result = get_output_formats(input_json, targets) + result = get_output_formats(input_json) assert result == {PurePath("foo.vy"): expected} def test_solc_style(): - input_json = {"settings": {"outputSelection": {"foo.vy": {"": ["abi"], "foo.vy": ["ir"]}}}} - targets = [PurePath("foo.vy")] - assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): ["abi", "ir_dict"]} + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": {"": ["abi"], "foo.vy": ["ir"]}}}, + } + assert get_output_formats(input_json) == {PurePath("foo.vy"): ["abi", "ir_dict"]} def test_metadata(): - input_json = {"settings": {"outputSelection": {"*": ["metadata"]}}} - targets = [PurePath("foo.vy")] - assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): ["metadata"]} + input_json = {"sources": {"foo.vy": ""}, "settings": {"outputSelection": {"*": ["metadata"]}}} + assert get_output_formats(input_json) == {PurePath("foo.vy"): ["metadata"]} diff --git a/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py b/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py index 3b0f700c7e..6b509dd3ef 100644 --- a/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py +++ b/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py @@ -9,11 +9,11 @@ from vyper.exceptions import JSONError FOO_CODE = """ -import contracts.bar as Bar +import contracts.ibar as IBar @external def foo(a: address) -> bool: - return Bar(a).bar(1) + return IBar(a).bar(1) """ BAR_CODE = """ diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index 47b70a8c70..44b823757c 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -1,5 +1,6 @@ import pytest +from vyper.compiler import compile_code from vyper.compiler.phases import CompilerData from vyper.compiler.settings import OptimizationLevel, Settings @@ -71,33 +72,61 @@ def __init__(): ] +# check dead code eliminator works on unreachable functions @pytest.mark.parametrize("code", codes) def test_dead_code_eliminator(code): c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.NONE)) - initcode_asm = [i for i in c.assembly if not isinstance(i, list)] - runtime_asm = c.assembly_runtime - ctor_only_label = "_sym_internal_ctor_only___" - runtime_only_label = "_sym_internal_runtime_only___" + # get the labels + initcode_asm = [i for i in c.assembly if isinstance(i, str)] + runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] + + ctor_only = "ctor_only()" + runtime_only = "runtime_only()" # qux reachable from unoptimized initcode, foo not reachable. - assert ctor_only_label + "_deploy" in initcode_asm - assert runtime_only_label + "_deploy" not in initcode_asm + assert any(ctor_only in instr for instr in initcode_asm) + assert all(runtime_only not in instr for instr in initcode_asm) # all labels should be in unoptimized runtime asm - for s in (ctor_only_label, runtime_only_label): - assert s + "_runtime" in runtime_asm + for s in (ctor_only, runtime_only): + assert any(s in instr for instr in runtime_asm) c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.GAS)) - initcode_asm = [i for i in c.assembly if not isinstance(i, list)] - runtime_asm = c.assembly_runtime + initcode_asm = [i for i in c.assembly if isinstance(i, str)] + runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] # ctor only label should not be in runtime code - for instr in runtime_asm: - if isinstance(instr, str): - assert not instr.startswith(ctor_only_label), instr + assert all(ctor_only not in instr for instr in runtime_asm) # runtime only label should not be in initcode asm - for instr in initcode_asm: - if isinstance(instr, str): - assert not instr.startswith(runtime_only_label), instr + assert all(runtime_only not in instr for instr in initcode_asm) + + +def test_library_code_eliminator(make_input_bundle): + library = """ +@internal +def unused1(): + pass + +@internal +def unused2(): + self.unused1() + +@internal +def some_function(): + pass + """ + code = """ +import library + +@external +def foo(): + library.some_function() + """ + input_bundle = make_input_bundle({"library.vy": library}) + res = compile_code(code, input_bundle=input_bundle, output_formats=["asm"]) + asm = res["asm"] + assert "some_function()" in asm + assert "unused1()" not in asm + assert "unused2()" not in asm diff --git a/tests/unit/compiler/test_input_bundle.py b/tests/unit/compiler/test_input_bundle.py index c49c81219b..e26555b169 100644 --- a/tests/unit/compiler/test_input_bundle.py +++ b/tests/unit/compiler/test_input_bundle.py @@ -1,4 +1,6 @@ +import contextlib import json +import os from pathlib import Path, PurePath import pytest @@ -12,19 +14,19 @@ def input_bundle(tmp_path): return FilesystemInputBundle([tmp_path]) -def test_load_file(make_file, input_bundle, tmp_path): - make_file("foo.vy", "contents") +def test_load_file(make_file, input_bundle): + filepath = make_file("foo.vy", "contents") file = input_bundle.load_file(Path("foo.vy")) assert isinstance(file, FileInput) - assert file == FileInput(0, tmp_path / Path("foo.vy"), "contents") + assert file == FileInput(0, Path("foo.vy"), filepath, "contents") def test_search_path_context_manager(make_file, tmp_path): ib = FilesystemInputBundle([]) - make_file("foo.vy", "contents") + filepath = make_file("foo.vy", "contents") with pytest.raises(FileNotFoundError): # no search path given @@ -34,7 +36,7 @@ def test_search_path_context_manager(make_file, tmp_path): file = ib.load_file(Path("foo.vy")) assert isinstance(file, FileInput) - assert file == FileInput(0, tmp_path / Path("foo.vy"), "contents") + assert file == FileInput(0, Path("foo.vy"), filepath, "contents") def test_search_path_precedence(make_file, tmp_path, tmp_path_factory, input_bundle): @@ -43,59 +45,85 @@ def test_search_path_precedence(make_file, tmp_path, tmp_path_factory, input_bun tmpdir = tmp_path_factory.mktemp("some_directory") tmpdir2 = tmp_path_factory.mktemp("some_other_directory") + filepaths = [] for i, directory in enumerate([tmp_path, tmpdir, tmpdir2]): - with (directory / "foo.vy").open("w") as f: + path = directory / "foo.vy" + with path.open("w") as f: f.write(f"contents {i}") + filepaths.append(path) ib = FilesystemInputBundle([tmp_path, tmpdir, tmpdir2]) file = ib.load_file("foo.vy") assert isinstance(file, FileInput) - assert file == FileInput(0, tmpdir2 / "foo.vy", "contents 2") + assert file == FileInput(0, "foo.vy", filepaths[2], "contents 2") with ib.search_path(tmpdir): file = ib.load_file("foo.vy") assert isinstance(file, FileInput) - assert file == FileInput(1, tmpdir / "foo.vy", "contents 1") + assert file == FileInput(1, "foo.vy", filepaths[1], "contents 1") # special rules for handling json files def test_load_abi(make_file, input_bundle, tmp_path): contents = json.dumps("some string") - make_file("foo.json", contents) + path = make_file("foo.json", contents) file = input_bundle.load_file("foo.json") assert isinstance(file, ABIInput) - assert file == ABIInput(0, tmp_path / "foo.json", "some string") + assert file == ABIInput(0, "foo.json", path, "some string") # suffix doesn't matter - make_file("foo.txt", contents) - + path = make_file("foo.txt", contents) file = input_bundle.load_file("foo.txt") assert isinstance(file, ABIInput) - assert file == ABIInput(1, tmp_path / "foo.txt", "some string") + assert file == ABIInput(1, "foo.txt", path, "some string") + + +@contextlib.contextmanager +def working_directory(directory): + tmp = os.getcwd() + try: + os.chdir(directory) + yield + finally: + os.chdir(tmp) # check that unique paths give unique source ids def test_source_id_file_input(make_file, input_bundle, tmp_path): - make_file("foo.vy", "contents") - make_file("bar.vy", "contents 2") + foopath = make_file("foo.vy", "contents") + barpath = make_file("bar.vy", "contents 2") file = input_bundle.load_file("foo.vy") assert file.source_id == 0 - assert file == FileInput(0, tmp_path / "foo.vy", "contents") + assert file == FileInput(0, "foo.vy", foopath, "contents") file2 = input_bundle.load_file("bar.vy") # source id increments assert file2.source_id == 1 - assert file2 == FileInput(1, tmp_path / "bar.vy", "contents 2") + assert file2 == FileInput(1, "bar.vy", barpath, "contents 2") file3 = input_bundle.load_file("foo.vy") assert file3.source_id == 0 - assert file3 == FileInput(0, tmp_path / "foo.vy", "contents") + assert file3 == FileInput(0, "foo.vy", foopath, "contents") + + # test source id is stable across different search paths + with working_directory(tmp_path): + with input_bundle.search_path(Path(".")): + file4 = input_bundle.load_file("foo.vy") + assert file4.source_id == 0 + assert file4 == FileInput(0, "foo.vy", foopath, "contents") + + # test source id is stable even when requested filename is different + with working_directory(tmp_path.parent): + with input_bundle.search_path(Path(".")): + file5 = input_bundle.load_file(Path(tmp_path.stem) / "foo.vy") + assert file5.source_id == 0 + assert file5 == FileInput(0, Path(tmp_path.stem) / "foo.vy", foopath, "contents") # check that unique paths give unique source ids @@ -103,37 +131,51 @@ def test_source_id_json_input(make_file, input_bundle, tmp_path): contents = json.dumps("some string") contents2 = json.dumps(["some list"]) - make_file("foo.json", contents) + foopath = make_file("foo.json", contents) - make_file("bar.json", contents2) + barpath = make_file("bar.json", contents2) file = input_bundle.load_file("foo.json") assert isinstance(file, ABIInput) - assert file == ABIInput(0, tmp_path / "foo.json", "some string") + assert file == ABIInput(0, "foo.json", foopath, "some string") file2 = input_bundle.load_file("bar.json") assert isinstance(file2, ABIInput) - assert file2 == ABIInput(1, tmp_path / "bar.json", ["some list"]) + assert file2 == ABIInput(1, "bar.json", barpath, ["some list"]) file3 = input_bundle.load_file("foo.json") - assert isinstance(file3, ABIInput) - assert file3 == ABIInput(0, tmp_path / "foo.json", "some string") + assert file3.source_id == 0 + assert file3 == ABIInput(0, "foo.json", foopath, "some string") + + # test source id is stable across different search paths + with working_directory(tmp_path): + with input_bundle.search_path(Path(".")): + file4 = input_bundle.load_file("foo.json") + assert file4.source_id == 0 + assert file4 == ABIInput(0, "foo.json", foopath, "some string") + + # test source id is stable even when requested filename is different + with working_directory(tmp_path.parent): + with input_bundle.search_path(Path(".")): + file5 = input_bundle.load_file(Path(tmp_path.stem) / "foo.json") + assert file5.source_id == 0 + assert file5 == ABIInput(0, Path(tmp_path.stem) / "foo.json", foopath, "some string") # test some pathological case where the file changes underneath def test_mutating_file_source_id(make_file, input_bundle, tmp_path): - make_file("foo.vy", "contents") + foopath = make_file("foo.vy", "contents") file = input_bundle.load_file("foo.vy") assert file.source_id == 0 - assert file == FileInput(0, tmp_path / "foo.vy", "contents") + assert file == FileInput(0, "foo.vy", foopath, "contents") - make_file("foo.vy", "new contents") + foopath = make_file("foo.vy", "new contents") file = input_bundle.load_file("foo.vy") # source id hasn't changed, even though contents have assert file.source_id == 0 - assert file == FileInput(0, tmp_path / "foo.vy", "new contents") + assert file == FileInput(0, "foo.vy", foopath, "new contents") # test the os.normpath behavior of symlink @@ -147,10 +189,12 @@ def test_load_file_symlink(make_file, input_bundle, tmp_path, tmp_path_factory): dir2.mkdir() symlink.symlink_to(dir2, target_is_directory=True) - with (tmp_path / "foo.vy").open("w") as f: - f.write("contents of the upper directory") + outer_path = tmp_path / "foo.vy" + with outer_path.open("w") as f: + f.write("contents of the outer directory") - with (dir1 / "foo.vy").open("w") as f: + inner_path = dir1 / "foo.vy" + with inner_path.open("w") as f: f.write("contents of the inner directory") # symlink rules would be: @@ -159,9 +203,10 @@ def test_load_file_symlink(make_file, input_bundle, tmp_path, tmp_path_factory): # base/first/foo.vy # normpath would be base/symlink/../foo.vy => # base/foo.vy - file = input_bundle.load_file(symlink / ".." / "foo.vy") + to_load = symlink / ".." / "foo.vy" + file = input_bundle.load_file(to_load) - assert file == FileInput(0, tmp_path / "foo.vy", "contents of the upper directory") + assert file == FileInput(0, to_load, outer_path.resolve(), "contents of the outer directory") def test_json_input_bundle_basic(): @@ -169,40 +214,42 @@ def test_json_input_bundle_basic(): input_bundle = JSONInputBundle(files, [PurePath(".")]) file = input_bundle.load_file(PurePath("foo.vy")) - assert file == FileInput(0, PurePath("foo.vy"), "some text") + assert file == FileInput(0, PurePath("foo.vy"), PurePath("foo.vy"), "some text") def test_json_input_bundle_normpath(): - files = {PurePath("foo/../bar.vy"): {"content": "some text"}} + contents = "some text" + files = {PurePath("foo/../bar.vy"): {"content": contents}} input_bundle = JSONInputBundle(files, [PurePath(".")]) - expected = FileInput(0, PurePath("bar.vy"), "some text") + barpath = PurePath("bar.vy") + + expected = FileInput(0, barpath, barpath, contents) file = input_bundle.load_file(PurePath("bar.vy")) assert file == expected file = input_bundle.load_file(PurePath("baz/../bar.vy")) - assert file == expected + assert file == FileInput(0, PurePath("baz/../bar.vy"), barpath, contents) file = input_bundle.load_file(PurePath("./bar.vy")) - assert file == expected + assert file == FileInput(0, PurePath("./bar.vy"), barpath, contents) with input_bundle.search_path(PurePath("foo")): file = input_bundle.load_file(PurePath("../bar.vy")) - assert file == expected + assert file == FileInput(0, PurePath("../bar.vy"), barpath, contents) def test_json_input_abi(): some_abi = ["some abi"] some_abi_str = json.dumps(some_abi) - files = { - PurePath("foo.json"): {"abi": some_abi}, - PurePath("bar.txt"): {"content": some_abi_str}, - } + foopath = PurePath("foo.json") + barpath = PurePath("bar.txt") + files = {foopath: {"abi": some_abi}, barpath: {"content": some_abi_str}} input_bundle = JSONInputBundle(files, [PurePath(".")]) - file = input_bundle.load_file(PurePath("foo.json")) - assert file == ABIInput(0, PurePath("foo.json"), some_abi) + file = input_bundle.load_file(foopath) + assert file == ABIInput(0, foopath, foopath, some_abi) - file = input_bundle.load_file(PurePath("bar.txt")) - assert file == ABIInput(1, PurePath("bar.txt"), some_abi) + file = input_bundle.load_file(barpath) + assert file == ABIInput(1, barpath, barpath, some_abi) diff --git a/tests/unit/semantics/analysis/test_array_index.py b/tests/unit/semantics/analysis/test_array_index.py index 27c0634cf8..5ea373fc19 100644 --- a/tests/unit/semantics/analysis/test_array_index.py +++ b/tests/unit/semantics/analysis/test_array_index.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("value", ["address", "Bytes[10]", "decimal", "bool"]) -def test_type_mismatch(namespace, value): +def test_type_mismatch(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -23,11 +23,11 @@ def foo(b: {value}): """ vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["1.0", "0.0", "'foo'", "0x00", "b'\x01'", "False"]) -def test_invalid_literal(namespace, value): +def test_invalid_literal(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -38,11 +38,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(InvalidType): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", [-1, 3, -(2**127), 2**127 - 1, 2**256 - 1]) -def test_out_of_bounds(namespace, value): +def test_out_of_bounds(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -53,11 +53,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ArrayIndexException): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["b", "self.b"]) -def test_undeclared_definition(namespace, value): +def test_undeclared_definition(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -68,11 +68,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(UndeclaredDefinition): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["a", "foo", "int128"]) -def test_invalid_reference(namespace, value): +def test_invalid_reference(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -83,4 +83,4 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(InvalidReference): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) diff --git a/tests/unit/semantics/analysis/test_cyclic_function_calls.py b/tests/unit/semantics/analysis/test_cyclic_function_calls.py index 2a09bd5ed5..c31146b16f 100644 --- a/tests/unit/semantics/analysis/test_cyclic_function_calls.py +++ b/tests/unit/semantics/analysis/test_cyclic_function_calls.py @@ -3,22 +3,20 @@ from vyper.ast import parse_to_ast from vyper.exceptions import CallViolation, StructureException from vyper.semantics.analysis import validate_semantics -from vyper.semantics.analysis.module import ModuleAnalyzer -def test_self_function_call(namespace): +def test_self_function_call(dummy_input_bundle): code = """ @internal def foo(): self.foo() """ vyper_module = parse_to_ast(code) - with namespace.enter_scope(): - with pytest.raises(CallViolation): - ModuleAnalyzer(vyper_module, {}, namespace) + with pytest.raises(CallViolation): + validate_semantics(vyper_module, dummy_input_bundle) -def test_cyclic_function_call(namespace): +def test_cyclic_function_call(dummy_input_bundle): code = """ @internal def foo(): @@ -29,12 +27,11 @@ def bar(): self.foo() """ vyper_module = parse_to_ast(code) - with namespace.enter_scope(): - with pytest.raises(CallViolation): - ModuleAnalyzer(vyper_module, {}, namespace) + with pytest.raises(CallViolation): + validate_semantics(vyper_module, dummy_input_bundle) -def test_multi_cyclic_function_call(namespace): +def test_multi_cyclic_function_call(dummy_input_bundle): code = """ @internal def foo(): @@ -53,12 +50,11 @@ def potato(): self.foo() """ vyper_module = parse_to_ast(code) - with namespace.enter_scope(): - with pytest.raises(CallViolation): - ModuleAnalyzer(vyper_module, {}, namespace) + with pytest.raises(CallViolation): + validate_semantics(vyper_module, dummy_input_bundle) -def test_global_ann_assign_callable_no_crash(): +def test_global_ann_assign_callable_no_crash(dummy_input_bundle): code = """ balanceOf: public(HashMap[address, uint256]) @@ -68,5 +64,5 @@ def foo(to : address): """ vyper_module = parse_to_ast(code) with pytest.raises(StructureException) as excinfo: - validate_semantics(vyper_module, {}) - assert excinfo.value.message == "Value is not callable" + validate_semantics(vyper_module, dummy_input_bundle) + assert excinfo.value.message == "HashMap[address, uint256] is not callable" diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index 0d61a8f8f8..e2c0f555af 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -10,7 +10,7 @@ from vyper.semantics.analysis import validate_semantics -def test_modify_iterator_function_outside_loop(namespace): +def test_modify_iterator_function_outside_loop(dummy_input_bundle): code = """ a: uint256[3] @@ -26,10 +26,10 @@ def bar(): pass """ vyper_module = parse_to_ast(code) - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_pass_memory_var_to_other_function(namespace): +def test_pass_memory_var_to_other_function(dummy_input_bundle): code = """ @internal @@ -46,10 +46,10 @@ def bar(): self.foo(a) """ vyper_module = parse_to_ast(code) - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_modify_iterator(namespace): +def test_modify_iterator(dummy_input_bundle): code = """ a: uint256[3] @@ -61,10 +61,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_bad_keywords(namespace): +def test_bad_keywords(dummy_input_bundle): code = """ @internal @@ -75,10 +75,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(ArgumentException): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_bad_bound(namespace): +def test_bad_bound(dummy_input_bundle): code = """ @internal @@ -89,10 +89,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(StateAccessViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_modify_iterator_function_call(namespace): +def test_modify_iterator_function_call(dummy_input_bundle): code = """ a: uint256[3] @@ -108,10 +108,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_modify_iterator_recursive_function_call(namespace): +def test_modify_iterator_recursive_function_call(dummy_input_bundle): code = """ a: uint256[3] @@ -131,7 +131,7 @@ def baz(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) iterator_inference_codes = [ @@ -169,7 +169,7 @@ def foo(): @pytest.mark.parametrize("code", iterator_inference_codes) -def test_iterator_type_inference_checker(namespace, code): +def test_iterator_type_inference_checker(code, dummy_input_bundle): vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index d390fe9a39..002ee38cd2 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -110,6 +110,6 @@ def test_allocator_overflow(get_contract): """ with pytest.raises( StorageLayoutException, - match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}\n", + match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}", ): get_contract(code) diff --git a/tox.ini b/tox.ini index c949354dfe..f9d4c3b60b 100644 --- a/tox.ini +++ b/tox.ini @@ -53,4 +53,4 @@ commands = basepython = python3 extras = lint commands = - mypy --install-types --non-interactive --follow-imports=silent --ignore-missing-imports --disallow-incomplete-defs -p vyper + mypy --install-types --non-interactive --follow-imports=silent --ignore-missing-imports --implicit-optional -p vyper diff --git a/vyper/__init__.py b/vyper/__init__.py index 482d5c3a60..5bb6469757 100644 --- a/vyper/__init__.py +++ b/vyper/__init__.py @@ -1,6 +1,6 @@ from pathlib import Path as _Path -from vyper.compiler import compile_code # noqa: F401 +from vyper.compiler import compile_code, compile_from_file_input try: from importlib.metadata import PackageNotFoundError # type: ignore diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py index e5b81f1e7f..4b46801153 100644 --- a/vyper/ast/__init__.py +++ b/vyper/ast/__init__.py @@ -6,7 +6,8 @@ from . import nodes, validation from .natspec import parse_natspec from .nodes import compare_nodes -from .utils import ast_to_dict, parse_to_ast, parse_to_ast_with_settings +from .utils import ast_to_dict +from .parse import parse_to_ast, parse_to_ast_with_settings # adds vyper.ast.nodes classes into the local namespace for name, obj in ( diff --git a/vyper/ast/__init__.pyi b/vyper/ast/__init__.pyi index d349e804d6..eac8ffdef5 100644 --- a/vyper/ast/__init__.pyi +++ b/vyper/ast/__init__.pyi @@ -4,5 +4,5 @@ from typing import Any, Optional, Union from . import expansion, folding, nodes, validation from .natspec import parse_natspec as parse_natspec from .nodes import * +from .parse import parse_to_ast as parse_to_ast from .utils import ast_to_dict as ast_to_dict -from .utils import parse_to_ast as parse_to_ast diff --git a/vyper/ast/expansion.py b/vyper/ast/expansion.py index 5471b971a4..1536f39165 100644 --- a/vyper/ast/expansion.py +++ b/vyper/ast/expansion.py @@ -5,22 +5,9 @@ from vyper.semantics.types.function import ContractFunctionT -def expand_annotated_ast(vyper_module: vy_ast.Module) -> None: - """ - Perform expansion / simplification operations on an annotated Vyper AST. - - This pass uses annotated type information to modify the AST, simplifying - logic and expanding subtrees to reduce the compexity during codegen. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node that has been type-checked and annotated. - """ - generate_public_variable_getters(vyper_module) - remove_unused_statements(vyper_module) - - +# TODO: remove this function. it causes correctness/performance problems +# because of copying and mutating the AST - getter generation should be handled +# during code generation. def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: """ Create getter functions for public variables. @@ -32,7 +19,7 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: """ for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_public": True}): - func_type = node._metadata["func_type"] + func_type = node._metadata["getter_type"] input_types, return_type = node._metadata["type"].getter_signature input_nodes = [] @@ -86,31 +73,11 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: returns=return_node, ) - with vyper_module.namespace(): - func_type = ContractFunctionT.from_FunctionDef(expanded) - - expanded._metadata["type"] = func_type - return_node.set_parent(expanded) + # update pointers vyper_module.add_to_body(expanded) + return_node.set_parent(expanded) + with vyper_module.namespace(): + func_type = ContractFunctionT.from_FunctionDef(expanded) -def remove_unused_statements(vyper_module: vy_ast.Module) -> None: - """ - Remove statement nodes that are unused after type checking. - - Once type checking is complete, we can remove now-meaningless statements to - simplify the AST prior to IR generation. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - """ - - # constant declarations - values were substituted within the AST during folding - for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_constant": True}): - vyper_module.remove_from_body(node) - - # `implements: interface` statements - validated during type checking - for node in vyper_module.get_children(vy_ast.ImplementsDecl): - vyper_module.remove_from_body(node) + expanded._metadata["func_type"] = func_type diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index ca9979b2a3..15367ce94a 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -89,7 +89,8 @@ tuple_def: "(" ( NAME | array_def | dyn_array_def | tuple_def ) ( "," ( NAME | a // NOTE: Map takes a basic type and maps to another type (can be non-basic, including maps) _MAP: "HashMap" map_def: _MAP "[" ( NAME | array_def ) "," type "]" -type: ( NAME | array_def | tuple_def | map_def | dyn_array_def ) +imported_type: NAME "." NAME +type: ( NAME | imported_type | array_def | tuple_def | map_def | dyn_array_def ) // Structs can be composed of 1+ basic types or other custom_types _STRUCT_DECL: "struct" @@ -291,7 +292,7 @@ special_builtins: empty | abi_decode // Adapted from: https://docs.python.org/3/reference/grammar.html // Adapted by: Erez Shinan NAME: /[a-zA-Z_]\w*/ -COMMENT: /#[^\n]*/ +COMMENT: /#[^\n\r]*/ _NEWLINE: ( /\r?\n[\t ]*/ | COMMENT )+ @@ -312,8 +313,10 @@ _number: DEC_NUMBER BOOL.2: "True" | "False" +ELLIPSIS: "..." + // TODO: Remove Docstring from here, and add to first part of body -?literal: ( _number | STRING | DOCSTRING | BOOL ) +?literal: ( _number | STRING | DOCSTRING | BOOL | ELLIPSIS) %ignore /[\t \f]+/ // WS %ignore /\\[\t \f]*\r?\n/ // LINE_CONT diff --git a/vyper/ast/natspec.py b/vyper/ast/natspec.py index c25fc423f8..41905b178a 100644 --- a/vyper/ast/natspec.py +++ b/vyper/ast/natspec.py @@ -43,7 +43,7 @@ def parse_natspec(vyper_module_folded: vy_ast.Module) -> Tuple[dict, dict]: for node in [i for i in vyper_module_folded.body if i.get("doc_string.value")]: docstring = node.doc_string.value - func_type = node._metadata["type"] + func_type = node._metadata["func_type"] if func_type.visibility != FunctionVisibility.EXTERNAL: continue diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 69bd1fed53..3bccc5f141 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -589,7 +589,8 @@ def __contains__(self, obj): class Module(TopLevel): - __slots__ = () + # metadata + __slots__ = ("path", "resolved_path", "source_id") def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: """ @@ -897,12 +898,16 @@ def validate(self): raise InvalidLiteral("Cannot have an empty tuple", self) -class Dict(ExprNode): - __slots__ = ("keys", "values") +class NameConstant(Constant): + __slots__ = () -class NameConstant(Constant): - __slots__ = ("value",) +class Ellipsis(Constant): + __slots__ = () + + +class Dict(ExprNode): + __slots__ = ("keys", "values") class Name(ExprNode): @@ -1407,7 +1412,7 @@ class Pass(Stmt): __slots__ = () -class _Import(Stmt): +class _ImportStmt(Stmt): __slots__ = ("name", "alias") def __init__(self, *args, **kwargs): @@ -1419,11 +1424,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) -class Import(_Import): +class Import(_ImportStmt): __slots__ = () -class ImportFrom(_Import): +class ImportFrom(_ImportStmt): __slots__ = ("level", "module") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 47c9af8526..05784aed0f 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -2,9 +2,9 @@ import ast as python_ast from typing import Any, Optional, Sequence, Type, Union from .natspec import parse_natspec as parse_natspec +from .parse import parse_to_ast as parse_to_ast +from .parse import parse_to_ast_with_settings as parse_to_ast_with_settings from .utils import ast_to_dict as ast_to_dict -from .utils import parse_to_ast as parse_to_ast -from .utils import parse_to_ast_with_settings as parse_to_ast_with_settings NODE_BASE_ATTRIBUTES: Any NODE_SRC_ATTRIBUTES: Any @@ -59,6 +59,8 @@ class TopLevel(VyperNode): def __contains__(self, obj: Any) -> bool: ... class Module(TopLevel): + path: str = ... + resolved_path: str = ... def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: ... def add_to_body(self, node: VyperNode) -> None: ... def remove_from_body(self, node: VyperNode) -> None: ... @@ -121,6 +123,9 @@ class Bytes(Constant): @property def s(self): ... +class NameConstant(Constant): ... +class Ellipsis(Constant): ... + class List(VyperNode): elements: list = ... @@ -131,8 +136,6 @@ class Dict(VyperNode): keys: list = ... values: list = ... -class NameConstant(Constant): ... - class Name(VyperNode): id: str = ... _type: str = ... @@ -188,7 +191,7 @@ class NotIn(VyperNode): ... class Call(ExprNode): args: list = ... keywords: list = ... - func: Name = ... + func: VyperNode = ... class keyword(VyperNode): ... diff --git a/vyper/ast/annotation.py b/vyper/ast/parse.py similarity index 68% rename from vyper/ast/annotation.py rename to vyper/ast/parse.py index 9c7b1e063f..a2f2542179 100644 --- a/vyper/ast/annotation.py +++ b/vyper/ast/parse.py @@ -1,14 +1,114 @@ import ast as python_ast import tokenize from decimal import Decimal -from typing import Optional, cast +from typing import Any, Dict, List, Optional, Union, cast import asttokens -from vyper.exceptions import CompilerPanic, SyntaxException +from vyper.ast import nodes as vy_ast +from vyper.ast.pre_parser import pre_parse +from vyper.compiler.settings import Settings +from vyper.exceptions import CompilerPanic, ParserException, SyntaxException from vyper.typing import ModificationOffsets +def parse_to_ast(*args: Any, **kwargs: Any) -> vy_ast.Module: + _settings, ast = parse_to_ast_with_settings(*args, **kwargs) + return ast + + +def parse_to_ast_with_settings( + source_code: str, + source_id: int = 0, + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, + add_fn_node: Optional[str] = None, +) -> tuple[Settings, vy_ast.Module]: + """ + Parses a Vyper source string and generates basic Vyper AST nodes. + + Parameters + ---------- + source_code : str + The Vyper source code to parse. + source_id : int, optional + Source id to use in the `src` member of each node. + contract_name: str, optional + Name of contract. + add_fn_node: str, optional + If not None, adds a dummy Python AST FunctionDef wrapper node. + source_id: int, optional + The source ID generated for this source code. + Corresponds to FileInput.source_id + module_path: str, optional + The path of the source code + Corresponds to FileInput.path + resolved_path: str, optional + The resolved path of the source code + Corresponds to FileInput.resolved_path + + Returns + ------- + list + Untyped, unoptimized Vyper AST nodes. + """ + if "\x00" in source_code: + raise ParserException("No null bytes (\\x00) allowed in the source code.") + settings, class_types, reformatted_code = pre_parse(source_code) + try: + py_ast = python_ast.parse(reformatted_code) + except SyntaxError as e: + # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors + raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e + + # Add dummy function node to ensure local variables are treated as `AnnAssign` + # instead of state variables (`VariableDecl`) + if add_fn_node: + fn_node = python_ast.FunctionDef(add_fn_node, py_ast.body, [], []) + fn_node.body = py_ast.body + fn_node.args = python_ast.arguments(defaults=[]) + py_ast.body = [fn_node] + + annotate_python_ast( + py_ast, + source_code, + class_types, + source_id, + module_path=module_path, + resolved_path=resolved_path, + ) + + # Convert to Vyper AST. + module = vy_ast.get_node(py_ast) + assert isinstance(module, vy_ast.Module) # mypy hint + return settings, module + + +def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: + """ + Converts a Vyper AST node, or list of nodes, into a dictionary suitable for + output to the user. + """ + if isinstance(ast_struct, vy_ast.VyperNode): + return ast_struct.to_dict() + + if isinstance(ast_struct, list): + return [i.to_dict() for i in ast_struct] + + raise CompilerPanic(f'Unknown Vyper AST node provided: "{type(ast_struct)}".') + + +def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: + """ + Converts an AST dict, or list of dicts, into Vyper AST node objects. + """ + if isinstance(ast_struct, dict): + return vy_ast.get_node(ast_struct) + if isinstance(ast_struct, list): + return [vy_ast.get_node(i) for i in ast_struct] + raise CompilerPanic(f'Unknown ast_struct provided: "{type(ast_struct)}".') + + class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _modification_offsets: ModificationOffsets @@ -19,11 +119,13 @@ def __init__( modification_offsets: Optional[ModificationOffsets], tokens: asttokens.ASTTokens, source_id: int, - contract_name: Optional[str], + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, ): self._tokens = tokens self._source_id = source_id - self._contract_name = contract_name + self._module_path = module_path + self._resolved_path = resolved_path self._source_code: str = source_code self.counter: int = 0 self._modification_offsets = {} @@ -83,7 +185,9 @@ def _visit_docstring(self, node): return node def visit_Module(self, node): - node.name = self._contract_name + node.path = self._module_path + node.resolved_path = self._resolved_path + node.source_id = self._source_id return self._visit_docstring(node) def visit_FunctionDef(self, node): @@ -163,6 +267,8 @@ def visit_Constant(self, node): node.ast_type = "Str" elif isinstance(node.value, bytes): node.ast_type = "Bytes" + elif isinstance(node.value, Ellipsis.__class__): + node.ast_type = "Ellipsis" else: raise SyntaxException( "Invalid syntax (unsupported Python Constant AST node).", @@ -250,7 +356,8 @@ def annotate_python_ast( source_code: str, modification_offsets: Optional[ModificationOffsets] = None, source_id: int = 0, - contract_name: Optional[str] = None, + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, ) -> python_ast.AST: """ Annotate and optimize a Python AST in preparation conversion to a Vyper AST. @@ -270,7 +377,14 @@ def annotate_python_ast( """ tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) - visitor = AnnotatingVisitor(source_code, modification_offsets, tokens, source_id, contract_name) + visitor = AnnotatingVisitor( + source_code, + modification_offsets, + tokens, + source_id, + module_path=module_path, + resolved_path=resolved_path, + ) visitor.visit(parsed_ast) return parsed_ast diff --git a/vyper/ast/utils.py b/vyper/ast/utils.py index 4e669385ab..4c2e5394c9 100644 --- a/vyper/ast/utils.py +++ b/vyper/ast/utils.py @@ -1,64 +1,7 @@ -import ast as python_ast -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Union from vyper.ast import nodes as vy_ast -from vyper.ast.annotation import annotate_python_ast -from vyper.ast.pre_parser import pre_parse -from vyper.compiler.settings import Settings -from vyper.exceptions import CompilerPanic, ParserException, SyntaxException - - -def parse_to_ast(*args: Any, **kwargs: Any) -> vy_ast.Module: - return parse_to_ast_with_settings(*args, **kwargs)[1] - - -def parse_to_ast_with_settings( - source_code: str, - source_id: int = 0, - contract_name: Optional[str] = None, - add_fn_node: Optional[str] = None, -) -> tuple[Settings, vy_ast.Module]: - """ - Parses a Vyper source string and generates basic Vyper AST nodes. - - Parameters - ---------- - source_code : str - The Vyper source code to parse. - source_id : int, optional - Source id to use in the `src` member of each node. - contract_name: str, optional - Name of contract. - add_fn_node: str, optional - If not None, adds a dummy Python AST FunctionDef wrapper node. - - Returns - ------- - list - Untyped, unoptimized Vyper AST nodes. - """ - if "\x00" in source_code: - raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, class_types, reformatted_code = pre_parse(source_code) - try: - py_ast = python_ast.parse(reformatted_code) - except SyntaxError as e: - # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors - raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e - - # Add dummy function node to ensure local variables are treated as `AnnAssign` - # instead of state variables (`VariableDecl`) - if add_fn_node: - fn_node = python_ast.FunctionDef(add_fn_node, py_ast.body, [], []) - fn_node.body = py_ast.body - fn_node.args = python_ast.arguments(defaults=[]) - py_ast.body = [fn_node] - annotate_python_ast(py_ast, source_code, class_types, source_id, contract_name) - - # Convert to Vyper AST. - module = vy_ast.get_node(py_ast) - assert isinstance(module, vy_ast.Module) # mypy hint - return settings, module +from vyper.exceptions import CompilerPanic def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: diff --git a/vyper/builtins/_utils.py b/vyper/builtins/_utils.py index afc0987b6d..72b05f15e3 100644 --- a/vyper/builtins/_utils.py +++ b/vyper/builtins/_utils.py @@ -1,10 +1,10 @@ from vyper.ast import parse_to_ast from vyper.codegen.context import Context -from vyper.codegen.global_context import GlobalContext from vyper.codegen.stmt import parse_body from vyper.semantics.analysis.local import FunctionNodeVisitor from vyper.semantics.namespace import Namespace, override_global_namespace from vyper.semantics.types.function import ContractFunctionT, FunctionVisibility, StateMutability +from vyper.semantics.types.module import ModuleT def _strip_source_pos(ir_node): @@ -22,15 +22,16 @@ def generate_inline_function(code, variables, variables_2, memory_allocator): # Initialise a placeholder `FunctionDef` AST node and corresponding # `ContractFunctionT` type to rely on the annotation visitors in semantics # module. - ast_code.body[0]._metadata["type"] = ContractFunctionT( + ast_code.body[0]._metadata["func_type"] = ContractFunctionT( "sqrt_builtin", [], [], None, FunctionVisibility.INTERNAL, StateMutability.NONPAYABLE ) # The FunctionNodeVisitor's constructor performs semantic checks # annotate the AST as side effects. - FunctionNodeVisitor(ast_code, ast_code.body[0], namespace) + analyzer = FunctionNodeVisitor(ast_code, ast_code.body[0], namespace) + analyzer.analyze() new_context = Context( - vars_=variables, global_ctx=GlobalContext(), memory_allocator=memory_allocator + vars_=variables, module_ctx=ModuleT(ast_code), memory_allocator=memory_allocator ) generated_ir = parse_body(ast_code.body[0].body, new_context) # strip source position info from the generated_ir since diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 22931508a6..d50a31767d 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2499,9 +2499,9 @@ def infer_arg_types(self, node): validate_call_args(node, 2, ["unwrap_tuple"]) data_type = get_exact_type_from_node(node.args[0]) - output_typedef = TYPE_T(type_from_annotation(node.args[1])) + output_type = type_from_annotation(node.args[1]) - return [data_type, output_typedef] + return [data_type, TYPE_T(output_type)] @process_inputs def build_IR(self, expr, args, kwargs, context): diff --git a/vyper/builtins/interfaces/ERC165.vy b/vyper/builtins/interfaces/ERC165.vyi similarity index 88% rename from vyper/builtins/interfaces/ERC165.vy rename to vyper/builtins/interfaces/ERC165.vyi index a4ca451abd..441130f77c 100644 --- a/vyper/builtins/interfaces/ERC165.vy +++ b/vyper/builtins/interfaces/ERC165.vyi @@ -1,4 +1,4 @@ @view @external def supportsInterface(interface_id: bytes4) -> bool: - pass + ... diff --git a/vyper/builtins/interfaces/ERC20.vy b/vyper/builtins/interfaces/ERC20.vyi similarity index 68% rename from vyper/builtins/interfaces/ERC20.vy rename to vyper/builtins/interfaces/ERC20.vyi index 065ca97a9b..ee533ab326 100644 --- a/vyper/builtins/interfaces/ERC20.vy +++ b/vyper/builtins/interfaces/ERC20.vyi @@ -1,38 +1,38 @@ # Events event Transfer: - _from: indexed(address) - _to: indexed(address) - _value: uint256 + sender: indexed(address) + recipient: indexed(address) + value: uint256 event Approval: - _owner: indexed(address) - _spender: indexed(address) - _value: uint256 + owner: indexed(address) + spender: indexed(address) + value: uint256 # Functions @view @external def totalSupply() -> uint256: - pass + ... @view @external def balanceOf(_owner: address) -> uint256: - pass + ... @view @external def allowance(_owner: address, _spender: address) -> uint256: - pass + ... @external def transfer(_to: address, _value: uint256) -> bool: - pass + ... @external def transferFrom(_from: address, _to: address, _value: uint256) -> bool: - pass + ... @external def approve(_spender: address, _value: uint256) -> bool: - pass + ... diff --git a/vyper/builtins/interfaces/ERC20Detailed.vy b/vyper/builtins/interfaces/ERC20Detailed.vyi similarity index 93% rename from vyper/builtins/interfaces/ERC20Detailed.vy rename to vyper/builtins/interfaces/ERC20Detailed.vyi index 7c4f546d45..0be1c6f153 100644 --- a/vyper/builtins/interfaces/ERC20Detailed.vy +++ b/vyper/builtins/interfaces/ERC20Detailed.vyi @@ -5,14 +5,14 @@ @view @external def name() -> String[1]: - pass + ... @view @external def symbol() -> String[1]: - pass + ... @view @external def decimals() -> uint8: - pass + ... diff --git a/vyper/builtins/interfaces/ERC4626.vy b/vyper/builtins/interfaces/ERC4626.vyi similarity index 90% rename from vyper/builtins/interfaces/ERC4626.vy rename to vyper/builtins/interfaces/ERC4626.vyi index 05865406cf..6d9e4c6ef7 100644 --- a/vyper/builtins/interfaces/ERC4626.vy +++ b/vyper/builtins/interfaces/ERC4626.vyi @@ -16,75 +16,75 @@ event Withdraw: @view @external def asset() -> address: - pass + ... @view @external def totalAssets() -> uint256: - pass + ... @view @external def convertToShares(assetAmount: uint256) -> uint256: - pass + ... @view @external def convertToAssets(shareAmount: uint256) -> uint256: - pass + ... @view @external def maxDeposit(owner: address) -> uint256: - pass + ... @view @external def previewDeposit(assets: uint256) -> uint256: - pass + ... @external def deposit(assets: uint256, receiver: address=msg.sender) -> uint256: - pass + ... @view @external def maxMint(owner: address) -> uint256: - pass + ... @view @external def previewMint(shares: uint256) -> uint256: - pass + ... @external def mint(shares: uint256, receiver: address=msg.sender) -> uint256: - pass + ... @view @external def maxWithdraw(owner: address) -> uint256: - pass + ... @view @external def previewWithdraw(assets: uint256) -> uint256: - pass + ... @external def withdraw(assets: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: - pass + ... @view @external def maxRedeem(owner: address) -> uint256: - pass + ... @view @external def previewRedeem(shares: uint256) -> uint256: - pass + ... @external def redeem(shares: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: - pass + ... diff --git a/vyper/builtins/interfaces/ERC721.vy b/vyper/builtins/interfaces/ERC721.vyi similarity index 61% rename from vyper/builtins/interfaces/ERC721.vy rename to vyper/builtins/interfaces/ERC721.vyi index 464c0e255b..b8dcfd3c5f 100644 --- a/vyper/builtins/interfaces/ERC721.vy +++ b/vyper/builtins/interfaces/ERC721.vyi @@ -1,67 +1,62 @@ # Events event Transfer: - _from: indexed(address) - _to: indexed(address) - _tokenId: indexed(uint256) + sender: indexed(address) + recipient: indexed(address) + token_id: indexed(uint256) event Approval: - _owner: indexed(address) - _approved: indexed(address) - _tokenId: indexed(uint256) + owner: indexed(address) + approved: indexed(address) + token_id: indexed(uint256) event ApprovalForAll: - _owner: indexed(address) - _operator: indexed(address) - _approved: bool + owner: indexed(address) + operator: indexed(address) + approved: bool # Functions @view @external def supportsInterface(interface_id: bytes4) -> bool: - pass + ... @view @external def balanceOf(_owner: address) -> uint256: - pass + ... @view @external def ownerOf(_tokenId: uint256) -> address: - pass + ... @view @external def getApproved(_tokenId: uint256) -> address: - pass + ... @view @external def isApprovedForAll(_owner: address, _operator: address) -> bool: - pass + ... @external @payable def transferFrom(_from: address, _to: address, _tokenId: uint256): - pass + ... @external @payable -def safeTransferFrom(_from: address, _to: address, _tokenId: uint256): - pass - -@external -@payable -def safeTransferFrom(_from: address, _to: address, _tokenId: uint256, _data: Bytes[1024]): - pass +def safeTransferFrom(_from: address, _to: address, _tokenId: uint256, _data: Bytes[1024] = b""): + ... @external @payable def approve(_approved: address, _tokenId: uint256): - pass + ... @external def setApprovalForAll(_operator: address, _approved: bool): - pass + ... diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index ca1792384e..4f88812fa0 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -271,10 +271,8 @@ def compile_files( with open(storage_file_path) as sfh: storage_layout_override = json.load(sfh) - output = vyper.compile_code( - file.source_code, - contract_name=str(file.path), - source_id=file.source_id, + output = vyper.compile_from_file_input( + file, input_bundle=input_bundle, output_formats=final_formats, exc_handler=exc_handler, diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index 2720f20d23..63da2e0643 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -12,7 +12,7 @@ from vyper.compiler.settings import OptimizationLevel, Settings from vyper.evm.opcodes import EVM_VERSIONS from vyper.exceptions import JSONError -from vyper.utils import keccak256 +from vyper.utils import OrderedSet, keccak256 TRANSLATE_MAP = { "abi": "abi", @@ -151,13 +151,6 @@ def get_evm_version(input_dict: dict) -> Optional[str]: return evm_version -def get_compilation_targets(input_dict: dict) -> list[PurePath]: - # TODO: once we have modules, add optional "compilation_targets" key - # which specifies which sources we actually want to compile. - - return [PurePath(p) for p in input_dict["sources"].keys()] - - def get_inputs(input_dict: dict) -> dict[PurePath, Any]: ret = {} seen = {} @@ -218,14 +211,14 @@ def get_inputs(input_dict: dict) -> dict[PurePath, Any]: # get unique output formats for each contract, given the input_dict # NOTE: would maybe be nice to raise on duplicated output formats -def get_output_formats(input_dict: dict, targets: list[PurePath]) -> dict[PurePath, list[str]]: +def get_output_formats(input_dict: dict) -> dict[PurePath, list[str]]: output_formats: dict[PurePath, list[str]] = {} for path, outputs in input_dict["settings"]["outputSelection"].items(): if isinstance(outputs, dict): # if outputs are given in solc json format, collapse them into a single list - outputs = set(x for i in outputs.values() for x in i) + outputs = OrderedSet(x for i in outputs.values() for x in i) else: - outputs = set(outputs) + outputs = OrderedSet(outputs) for key in [i for i in ("evm", "evm.bytecode", "evm.deployedBytecode") if i in outputs]: outputs.remove(key) @@ -239,13 +232,13 @@ def get_output_formats(input_dict: dict, targets: list[PurePath]) -> dict[PurePa except KeyError as e: raise JSONError(f"Invalid outputSelection - {e}") - outputs = sorted(set(outputs)) + outputs = sorted(list(outputs)) if path == "*": - output_paths = targets + output_paths = [PurePath(path) for path in input_dict["sources"].keys()] else: output_paths = [PurePath(path)] - if output_paths[0] not in targets: + if str(output_paths[0]) not in input_dict["sources"]: raise JSONError(f"outputSelection references unknown contract '{output_paths[0]}'") for output_path in output_paths: @@ -281,9 +274,9 @@ def compile_from_input_dict( no_bytecode_metadata = not input_dict["settings"].get("bytecodeMetadata", True) - compilation_targets = get_compilation_targets(input_dict) sources = get_inputs(input_dict) - output_formats = get_output_formats(input_dict, compilation_targets) + output_formats = get_output_formats(input_dict) + compilation_targets = list(output_formats.keys()) input_bundle = JSONInputBundle(sources, search_paths=[Path(root_folder)]) @@ -295,12 +288,10 @@ def compile_from_input_dict( # use load_file to get a unique source_id file = input_bundle.load_file(contract_path) assert isinstance(file, FileInput) # mypy hint - data = vyper.compile_code( - file.source_code, - contract_name=str(file.path), + data = vyper.compile_from_file_input( + file, input_bundle=input_bundle, output_formats=output_formats[contract_path], - source_id=file.source_id, settings=settings, no_bytecode_metadata=no_bytecode_metadata, ) diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 5b79f293bd..dea30faabc 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -48,7 +48,7 @@ def __repr__(self): class Context: def __init__( self, - global_ctx, + module_ctx, memory_allocator, vars_=None, forvars=None, @@ -60,7 +60,7 @@ def __init__( self.vars = vars_ or {} # Global variables, in the form (name, storage location, type) - self.globals = global_ctx.variables + self.globals = module_ctx.variables # Variables defined in for loops, e.g. for i in range(6): ... self.forvars = forvars or {} @@ -75,8 +75,8 @@ def __init__( # Whether we are currently parsing a range expression self.in_range_expr = False - # store global context - self.global_ctx = global_ctx + # store module context + self.module_ctx = module_ctx # full function type self.func_t = func_t diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index dc0e98786f..5870e64e98 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -47,8 +47,10 @@ StringT, StructT, TupleT, + is_type_t, ) from vyper.semantics.types.bytestrings import _BytestringT +from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T from vyper.utils import ( DECIMAL_DIVISOR, @@ -79,7 +81,7 @@ def __init__(self, node, context): self.ir_node = fn() if self.ir_node is None: - raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.", node) + raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.\n", node) self.ir_node.annotation = self.expr.get("node_source_code") self.ir_node.source_pos = getpos(self.expr) @@ -662,39 +664,38 @@ def parse_Call(self): if function_name in DISPATCH_TABLE: return DISPATCH_TABLE[function_name].build_IR(self.expr, self.context) - # Struct constructors do not need `self` prefix. - elif isinstance(self.expr._metadata["type"], StructT): - args = self.expr.args - if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return Expr.struct_literals(args[0], self.context, self.expr._metadata["type"]) + func_type = self.expr.func._metadata["type"] - # Interface assignment. Bar(
). - elif isinstance(self.expr._metadata["type"], InterfaceT): - (arg0,) = self.expr.args - arg_ir = Expr(arg0, self.context).ir_node + # Struct constructor + if is_type_t(func_type, StructT): + args = self.expr.args + if len(args) == 1 and isinstance(args[0], vy_ast.Dict): + return Expr.struct_literals(args[0], self.context, self.expr._metadata["type"]) - assert arg_ir.typ == AddressT() - arg_ir.typ = self.expr._metadata["type"] + # Interface constructor. Bar(). + if is_type_t(func_type, InterfaceT): + (arg0,) = self.expr.args + arg_ir = Expr(arg0, self.context).ir_node - return arg_ir + assert arg_ir.typ == AddressT() + arg_ir.typ = self.expr._metadata["type"] - elif isinstance(self.expr.func, vy_ast.Attribute) and self.expr.func.attr == "pop": + return arg_ir + + if isinstance(func_type, MemberFunctionT) and self.expr.func.attr == "pop": # TODO consider moving this to builtins darray = Expr(self.expr.func.value, self.context).ir_node assert len(self.expr.args) == 0 assert isinstance(darray.typ, DArrayT) return pop_dyn_array(darray, return_popped_item=True) - elif ( - # TODO use expr.func.type.is_internal once - # type annotations are consistently available - isinstance(self.expr.func, vy_ast.Attribute) - and isinstance(self.expr.func.value, vy_ast.Name) - and self.expr.func.value.id == "self" - ): - return self_call.ir_for_self_call(self.expr, self.context) - else: - return external_call.ir_for_external_call(self.expr, self.context) + if isinstance(func_type, ContractFunctionT): + if func_type.is_internal: + return self_call.ir_for_self_call(self.expr, self.context) + else: + return external_call.ir_for_external_call(self.expr, self.context) + + raise CompilerPanic("unreachable", self.expr) def parse_List(self): typ = self.expr._metadata["type"] diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index c48f1256c3..454ba9c8cd 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -7,13 +7,13 @@ from vyper.codegen.core import check_single_exit from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function -from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode from vyper.codegen.memory_allocator import MemoryAllocator from vyper.exceptions import CompilerPanic from vyper.semantics.types import VyperType from vyper.semantics.types.function import ContractFunctionT -from vyper.utils import MemoryPositions, calc_mem_gas, mkalphanum +from vyper.semantics.types.module import ModuleT +from vyper.utils import MemoryPositions, calc_mem_gas @dataclass @@ -44,7 +44,14 @@ def exit_sequence_label(self) -> str: @cached_property def ir_identifier(self) -> str: argz = ",".join([str(argtyp) for argtyp in self.func_t.argument_types]) - return mkalphanum(f"{self.visibility} {self.func_t.name} ({argz})") + + name = self.func_t.name + function_id = self.func_t._function_id + assert function_id is not None + + # include module id in the ir identifier to disambiguate functions + # with the same name but which come from different modules + return f"{self.visibility} {function_id} {name}({argz})" def set_frame_info(self, frame_info: FrameInfo) -> None: if self.frame_info is not None: @@ -94,7 +101,7 @@ class InternalFuncIR(FuncIR): # TODO: should split this into external and internal ir generation? def generate_ir_for_function( - code: vy_ast.FunctionDef, global_ctx: GlobalContext, is_ctor_context: bool = False + code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False ) -> FuncIR: """ Parse a function and produce IR code for the function, includes: @@ -103,7 +110,7 @@ def generate_ir_for_function( - Clamping and copying of arguments - Function body """ - func_t = code._metadata["type"] + func_t = code._metadata["func_type"] # generate _FuncIRInfo func_t._ir_info = _FuncIRInfo(func_t) @@ -126,7 +133,7 @@ def generate_ir_for_function( context = Context( vars_=None, - global_ctx=global_ctx, + module_ctx=module_ctx, memory_allocator=memory_allocator, constancy=Constancy.Mutable if func_t.is_mutable else Constancy.Constant, func_t=func_t, diff --git a/vyper/codegen/global_context.py b/vyper/codegen/global_context.py deleted file mode 100644 index 1f6783f6f8..0000000000 --- a/vyper/codegen/global_context.py +++ /dev/null @@ -1,32 +0,0 @@ -from functools import cached_property -from typing import Optional - -from vyper import ast as vy_ast - - -# Datatype to store all global context information. -# TODO: rename me to ModuleT -class GlobalContext: - def __init__(self, module: Optional[vy_ast.Module] = None): - self._module = module - - @cached_property - def functions(self): - return self._module.get_children(vy_ast.FunctionDef) - - @cached_property - def variables(self): - # variables that this module defines, ex. - # `x: uint256` is a private storage variable named x - if self._module is None: # TODO: make self._module never be None - return None - variable_decls = self._module.get_children(vy_ast.VariableDecl) - return {s.target.id: s.target._metadata["varinfo"] for s in variable_decls} - - @property - def immutables(self): - return [t for t in self.variables.values() if t.is_immutable] - - @cached_property - def immutable_section_bytes(self): - return sum([imm.typ.memory_bytes_required for imm in self.immutables]) diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index bfdafa8ba9..ef861e3953 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -5,49 +5,67 @@ from vyper.codegen import core, jumptable_utils from vyper.codegen.core import shr from vyper.codegen.function_definitions import generate_ir_for_function -from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import _is_debug_mode from vyper.exceptions import CompilerPanic -from vyper.utils import method_id_int +from vyper.semantics.types.module import ModuleT +from vyper.utils import OrderedSet, method_id_int -def _topsort_helper(functions, lookup): - # single pass to get a global topological sort of functions (so that each - # function comes after each of its callees). may have duplicates, which get - # filtered out in _topsort() +def _topsort(functions): + # single pass to get a global topological sort of functions (so that each + # function comes after each of its callees). + ret = OrderedSet() + for func_ast in functions: + fn_t = func_ast._metadata["func_type"] + + for reachable_t in fn_t.reachable_internal_functions: + assert reachable_t.ast_def is not None + ret.add(reachable_t.ast_def) + + ret.add(func_ast) + + # create globally unique IDs for each function + for idx, func in enumerate(ret): + func._metadata["func_type"]._function_id = idx + + return list(ret) + - ret = [] +# calculate globally reachable functions to see which +# ones should make it into the final bytecode. +# TODO: in the future, this should get obsolesced by IR dead code eliminator. +def _globally_reachable_functions(functions): + ret = OrderedSet() for f in functions: - # called_functions is a list of ContractFunctions, need to map - # back to FunctionDefs. - callees = [lookup[t.name] for t in f._metadata["type"].called_functions] - ret.extend(_topsort_helper(callees, lookup)) - ret.append(f) + fn_t = f._metadata["func_type"] - return ret + if not fn_t.is_external: + continue + for reachable_t in fn_t.reachable_internal_functions: + assert reachable_t.ast_def is not None + ret.add(reachable_t) -def _topsort(functions): - lookup = {f.name: f for f in functions} - # strip duplicates - return list(dict.fromkeys(_topsort_helper(functions, lookup))) + ret.add(fn_t) + + return ret def _is_constructor(func_ast): - return func_ast._metadata["type"].is_constructor + return func_ast._metadata["func_type"].is_constructor def _is_fallback(func_ast): - return func_ast._metadata["type"].is_fallback + return func_ast._metadata["func_type"].is_fallback def _is_internal(func_ast): - return func_ast._metadata["type"].is_internal + return func_ast._metadata["func_type"].is_internal def _is_payable(func_ast): - return func_ast._metadata["type"].is_payable + return func_ast._metadata["func_type"].is_payable def _annotated_method_id(abi_sig): @@ -63,7 +81,7 @@ def label_for_entry_point(abi_sig, entry_point): # adapt whatever generate_ir_for_function gives us into an IR node def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): - func_t = func_ast._metadata["type"] + func_t = func_ast._metadata["func_type"] assert func_t.is_fallback or func_t.is_constructor ret = ["seq"] @@ -86,12 +104,12 @@ def _ir_for_internal_function(func_ast, *args, **kwargs): return generate_ir_for_function(func_ast, *args, **kwargs).func_ir -def _generate_external_entry_points(external_functions, global_ctx): +def _generate_external_entry_points(external_functions, module_ctx): entry_points = {} # map from ABI sigs to ir code sig_of = {} # reverse map from method ids to abi sig for code in external_functions: - func_ir = generate_ir_for_function(code, global_ctx) + func_ir = generate_ir_for_function(code, module_ctx) for abi_sig, entry_point in func_ir.entry_points.items(): method_id = method_id_int(abi_sig) assert abi_sig not in entry_points @@ -113,13 +131,13 @@ def _generate_external_entry_points(external_functions, global_ctx): # into a bucket (of about 8-10 items), and then uses perfect hash # to select the final function. # costs about 212 gas for typical function and 8 bytes of code (+ ~87 bytes of global overhead) -def _selector_section_dense(external_functions, global_ctx): +def _selector_section_dense(external_functions, module_ctx): function_irs = [] if len(external_functions) == 0: return IRnode.from_list(["seq"]) - entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) # generate the label so the jumptable works for abi_sig, entry_point in entry_points.items(): @@ -264,13 +282,13 @@ def _selector_section_dense(external_functions, global_ctx): # a bucket, and then descends into linear search from there. # costs about 126 gas for typical (nonpayable, >0 args, avg bucket size 1.5) # function and 24 bytes of code (+ ~23 bytes of global overhead) -def _selector_section_sparse(external_functions, global_ctx): +def _selector_section_sparse(external_functions, module_ctx): ret = ["seq"] if len(external_functions) == 0: return ret - entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) n_buckets, buckets = jumptable_utils.generate_sparse_jumptable_buckets(entry_points.keys()) @@ -367,14 +385,14 @@ def _selector_section_sparse(external_functions, global_ctx): # O(n) linear search for the method id # mainly keep this in for backends which cannot handle the indirect jump # in selector_section_dense and selector_section_sparse -def _selector_section_linear(external_functions, global_ctx): +def _selector_section_linear(external_functions, module_ctx): ret = ["seq"] if len(external_functions) == 0: return ret ret.append(["if", ["lt", "calldatasize", 4], ["goto", "fallback"]]) - entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) dispatcher = ["seq"] @@ -402,10 +420,11 @@ def _selector_section_linear(external_functions, global_ctx): return ret -# take a GlobalContext, and generate the runtime and deploy IR -def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: +# take a ModuleT, and generate the runtime and deploy IR +def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: # order functions so that each function comes after all of its callees - function_defs = _topsort(global_ctx.functions) + function_defs = _topsort(module_ctx.function_defs) + reachable = _globally_reachable_functions(module_ctx.function_defs) runtime_functions = [f for f in function_defs if not _is_constructor(f)] init_function = next((f for f in function_defs if _is_constructor(f)), None) @@ -421,20 +440,26 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: # compile internal functions first so we have the function info for func_ast in internal_functions: - func_ir = _ir_for_internal_function(func_ast, global_ctx, False) - internal_functions_ir.append(IRnode.from_list(func_ir)) + # compile it so that _ir_info is populated (whether or not it makes + # it into the final IR artifact) + func_ir = _ir_for_internal_function(func_ast, module_ctx, False) + + # only include it in the IR if it is reachable from an external + # function. + if func_ast._metadata["func_type"] in reachable: + internal_functions_ir.append(IRnode.from_list(func_ir)) if core._opt_none(): - selector_section = _selector_section_linear(external_functions, global_ctx) + selector_section = _selector_section_linear(external_functions, module_ctx) # dense vs sparse global overhead is amortized after about 4 methods. # (--debug will force dense selector table anyway if _opt_codesize is selected.) elif core._opt_codesize() and (len(external_functions) > 4 or _is_debug_mode()): - selector_section = _selector_section_dense(external_functions, global_ctx) + selector_section = _selector_section_dense(external_functions, module_ctx) else: - selector_section = _selector_section_sparse(external_functions, global_ctx) + selector_section = _selector_section_sparse(external_functions, module_ctx) if default_function: - fallback_ir = _ir_for_fallback_or_ctor(default_function, global_ctx) + fallback_ir = _ir_for_fallback_or_ctor(default_function, module_ctx) else: fallback_ir = IRnode.from_list( ["revert", 0, 0], annotation="Default function", error_msg="fallback function" @@ -447,29 +472,30 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: runtime.extend(internal_functions_ir) deploy_code: List[Any] = ["seq"] - immutables_len = global_ctx.immutable_section_bytes + immutables_len = module_ctx.immutable_section_bytes if init_function: # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` + init_func_t = init_function._metadata["func_type"] ctor_internal_func_irs = [] internal_functions = [f for f in runtime_functions if _is_internal(f)] for f in internal_functions: - init_func_t = init_function._metadata["type"] - if f.name not in init_func_t.recursive_calls: + func_t = f._metadata["func_type"] + if func_t not in init_func_t.reachable_internal_functions: # unreachable code, delete it continue - func_ir = _ir_for_internal_function(f, global_ctx, is_ctor_context=True) + func_ir = _ir_for_internal_function(f, module_ctx, is_ctor_context=True) ctor_internal_func_irs.append(func_ir) # generate init_func_ir after callees to ensure they have analyzed # memory usage. # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_fallback_or_ctor(init_function, global_ctx, is_ctor_context=True) + init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx, is_ctor_context=True) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables # note: (deploy mem_ofst, code, extra_padding) - init_mem_used = init_function._metadata["type"]._ir_info.frame_info.mem_used + init_mem_used = init_function._metadata["func_type"]._ir_info.frame_info.mem_used # force msize to be initialized past the end of immutables section # so that builtins which use `msize` for "dynamic" memory diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index f03f2eb9c8..f53e4a81b4 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -4,15 +4,6 @@ from vyper.exceptions import StateAccessViolation from vyper.semantics.types.subscriptable import TupleT -_label_counter = 0 - - -# TODO a more general way of doing this -def _generate_label(name: str) -> str: - global _label_counter - _label_counter += 1 - return f"label{_label_counter}" - def _align_kwargs(func_t, args_ir): """ @@ -63,7 +54,7 @@ def ir_for_self_call(stmt_expr, context): # note: internal_function_label asserts `func_t.is_internal`. _label = func_t._ir_info.internal_function_label(context.is_ctor_context) - return_label = _generate_label(f"{_label}_call") + return_label = _freshname(f"{_label}_call") # allocate space for the return buffer # TODO allocate in stmt and/or expr.py diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 254cad32e6..cc7a603b7c 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -26,6 +26,7 @@ from vyper.evm.address_space import MEMORY, STORAGE from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure from vyper.semantics.types import DArrayT, MemberFunctionT +from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.shortcuts import INT256_T, UINT256_T @@ -117,44 +118,32 @@ def parse_Log(self): return events.ir_node_for_log(self.stmt, event, topic_ir, data_ir, self.context) def parse_Call(self): - # TODO use expr.func.type.is_internal once type annotations - # are consistently available. - is_self_function = ( - (isinstance(self.stmt.func, vy_ast.Attribute)) - and isinstance(self.stmt.func.value, vy_ast.Name) - and self.stmt.func.value.id == "self" - ) - if isinstance(self.stmt.func, vy_ast.Name): funcname = self.stmt.func.id return STMT_DISPATCH_TABLE[funcname].build_IR(self.stmt, self.context) - elif isinstance(self.stmt.func, vy_ast.Attribute) and self.stmt.func.attr in ( - "append", - "pop", - ): - func_type = self.stmt.func._metadata["type"] - if isinstance(func_type, MemberFunctionT): - darray = Expr(self.stmt.func.value, self.context).ir_node - args = [Expr(x, self.context).ir_node for x in self.stmt.args] - if self.stmt.func.attr == "append": - # sanity checks - assert len(args) == 1 - arg = args[0] - assert isinstance(darray.typ, DArrayT) - check_assign( - dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ) - ) - - return append_dyn_array(darray, arg) - else: - assert len(args) == 0 - return pop_dyn_array(darray, return_popped_item=False) - - if is_self_function: - return self_call.ir_for_self_call(self.stmt, self.context) - else: - return external_call.ir_for_external_call(self.stmt, self.context) + func_type = self.stmt.func._metadata["type"] + + if isinstance(func_type, MemberFunctionT) and self.stmt.func.attr in ("append", "pop"): + darray = Expr(self.stmt.func.value, self.context).ir_node + args = [Expr(x, self.context).ir_node for x in self.stmt.args] + if self.stmt.func.attr == "append": + (arg,) = args + assert isinstance(darray.typ, DArrayT) + check_assign( + dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ) + ) + + return append_dyn_array(darray, arg) + else: + assert len(args) == 0 + return pop_dyn_array(darray, return_popped_item=False) + + if isinstance(func_type, ContractFunctionT): + if func_type.is_internal: + return self_call.ir_for_self_call(self.stmt, self.context) + else: + return external_call.ir_for_external_call(self.stmt, self.context) def _assert_reason(self, test_expr, msg): # from parse_Raise: None passed as the assert condition diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 61d7a7c229..026c8369c5 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -5,7 +5,7 @@ import vyper.ast as vy_ast # break an import cycle import vyper.codegen.core as codegen import vyper.compiler.output as output -from vyper.compiler.input_bundle import InputBundle, PathLike +from vyper.compiler.input_bundle import FileInput, InputBundle, PathLike from vyper.compiler.phases import CompilerData from vyper.compiler.settings import Settings from vyper.evm.opcodes import DEFAULT_EVM_VERSION, anchor_evm_version @@ -44,10 +44,8 @@ UNKNOWN_CONTRACT_NAME = "