Skip to content

Commit

Permalink
feat[lang]: add module.__at__() to cast to interface (#4090)
Browse files Browse the repository at this point in the history
add `module.__at__`, a new `MemberFunctionT`, which allows the user to
cast addresses to a module's interface.

additionally, fix a bug where interfaces defined inline could not
be exported. this is simultaneously fixed as a related bug because
previously, interfaces could come up in export analysis as `InterfaceT`
or `TYPE_T` depending on their provenance. this commit fixes the bug by
making them `TYPE_T` in both imported and inlined provenance.

this also allows `module.__interface__` to be used in export position
by adding it to `ModuleT`'s members. note this has an unwanted side
effect of allowing `module.__interface__` in call position; in other
words, `module.__interface__(<address>)` has the same behavior as
`module.__at__(<address>)` when use as an expression. this can be
addressed in a later refactor.

refactor:
- wrap interfaces in `TYPE_T`
- streamline an `isinstance(t, (VyperType, TYPE_T))` check. TYPE_T` now
  inherits from `VyperType`, so it doesn't need to be listed separately

---------

Co-authored-by: cyberthirst <[email protected]>
  • Loading branch information
charles-cooper and cyberthirst authored Nov 25, 2024
1 parent 8f433f8 commit f249c93
Show file tree
Hide file tree
Showing 16 changed files with 431 additions and 28 deletions.
15 changes: 15 additions & 0 deletions docs/using-modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@ The ``_times_two()`` helper function in the above module can be immediately used
The other functions cannot be used yet, because they touch the ``ownable`` module's state. There are two ways to declare a module so that its state can be used.

Using a module as an interface
==============================

A module can be used as an interface with the ``__at__`` syntax.

.. code-block:: vyper
import ownable
an_ownable: ownable.__interface__
def call_ownable(addr: address):
self.an_ownable = ownable.__at__(addr)
self.an_ownable.transfer_ownership(...)
Initializing a module
=====================

Expand Down
23 changes: 23 additions & 0 deletions tests/functional/codegen/modules/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,26 @@ def __init__():
# call `c.__default__()`
env.message_call(c.address)
assert c.counter() == 6


def test_inline_interface_export(make_input_bundle, get_contract):
lib1 = """
interface IAsset:
def asset() -> address: view
implements: IAsset
@external
@view
def asset() -> address:
return self
"""
main = """
import lib1
exports: lib1.IAsset
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
c = get_contract(main, input_bundle=input_bundle)

assert c.asset() == c.address
36 changes: 35 additions & 1 deletion tests/functional/codegen/modules/test_interface_imports.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest


def test_import_interface_types(make_input_bundle, get_contract):
ifaces = """
interface IFoo:
Expand Down Expand Up @@ -50,16 +53,47 @@ def foo() -> bool:
# check that this typechecks both directions
a: lib1.IERC20 = IERC20(msg.sender)
b: lib2.IERC20 = IERC20(msg.sender)
c: IERC20 = lib1.IERC20(msg.sender) # allowed in call position
# return the equality so we can sanity check it
return a == b
return a == b and b == c
"""
input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2})
c = get_contract(main, input_bundle=input_bundle)

assert c.foo() is True


@pytest.mark.parametrize("interface_syntax", ["__at__", "__interface__"])
def test_intrinsic_interface(get_contract, make_input_bundle, interface_syntax):
lib = """
@external
@view
def foo() -> uint256:
# detect self call
if msg.sender == self:
return 4
else:
return 5
"""

main = f"""
import lib
exports: lib.__interface__
@external
@view
def bar() -> uint256:
return staticcall lib.{interface_syntax}(self).foo()
"""
input_bundle = make_input_bundle({"lib.vy": lib})
c = get_contract(main, input_bundle=input_bundle)

assert c.foo() == 5
assert c.bar() == 4


def test_import_interface_flags(make_input_bundle, get_contract):
ifaces = """
flag Foo:
Expand Down
89 changes: 89 additions & 0 deletions tests/functional/codegen/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,3 +774,92 @@ def foo(s: MyStruct) -> MyStruct:
assert "b: uint256" in out
assert "struct Voter:" in out
assert "voted: bool" in out


def test_intrinsic_interface_instantiation(make_input_bundle, get_contract):
lib1 = """
@external
@view
def foo():
pass
"""
main = """
import lib1
i: lib1.__interface__
@external
def bar() -> lib1.__interface__:
self.i = lib1.__at__(self)
return self.i
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
c = get_contract(main, input_bundle=input_bundle)

assert c.bar() == c.address


def test_intrinsic_interface_converts(make_input_bundle, get_contract):
lib1 = """
@external
@view
def foo():
pass
"""
main = """
import lib1
@external
def bar() -> lib1.__interface__:
return lib1.__at__(self)
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
c = get_contract(main, input_bundle=input_bundle)

assert c.bar() == c.address


def test_intrinsic_interface_kws(env, make_input_bundle, get_contract):
value = 10**5
lib1 = f"""
@external
@payable
def foo(a: address):
send(a, {value})
"""
main = f"""
import lib1
exports: lib1.__interface__
@external
def bar(a: address):
extcall lib1.__at__(self).foo(a, value={value})
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
c = get_contract(main, input_bundle=input_bundle)
env.set_balance(c.address, value)
original_balance = env.get_balance(env.deployer)
c.bar(env.deployer)
assert env.get_balance(env.deployer) == original_balance + value


def test_intrinsic_interface_defaults(env, make_input_bundle, get_contract):
lib1 = """
@external
@payable
def foo(i: uint256=1) -> uint256:
return i
"""
main = """
import lib1
exports: lib1.__interface__
@external
def bar() -> uint256:
return extcall lib1.__at__(self).foo()
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
c = get_contract(main, input_bundle=input_bundle)
assert c.bar() == 1
34 changes: 33 additions & 1 deletion tests/functional/syntax/modules/test_deploy_visibility.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from vyper.compiler import compile_code
from vyper.exceptions import CallViolation
from vyper.exceptions import CallViolation, UnknownAttribute


def test_call_deploy_from_external(make_input_bundle):
Expand All @@ -25,3 +25,35 @@ def foo():
compile_code(main, input_bundle=input_bundle)

assert e.value.message == "Cannot call an @deploy function from an @external function!"


@pytest.mark.parametrize("interface_syntax", ["__interface__", "__at__"])
def test_module_interface_init(make_input_bundle, tmp_path, interface_syntax):
lib1 = """
#lib1.vy
k: uint256
@external
def bar():
pass
@deploy
def __init__():
self.k = 10
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})

code = f"""
import lib1
@deploy
def __init__():
lib1.{interface_syntax}(self).__init__()
"""

with pytest.raises(UnknownAttribute) as e:
compile_code(code, input_bundle=input_bundle)

# as_posix() for windows tests
lib1_path = (tmp_path / "lib1.vy").as_posix()
assert e.value.message == f"interface {lib1_path} has no member '__init__'."
106 changes: 106 additions & 0 deletions tests/functional/syntax/modules/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,28 @@ def do_xyz():
assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!"


def test_no_export_unimplemented_inline_interface(make_input_bundle):
lib1 = """
interface ifoo:
def do_xyz(): nonpayable
# technically implements ifoo, but missing `implements: ifoo`
@external
def do_xyz():
pass
"""
main = """
import lib1
exports: lib1.ifoo
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(InterfaceViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!"


def test_export_selector_conflict(make_input_bundle):
ifoo = """
@external
Expand Down Expand Up @@ -444,3 +466,87 @@ def __init__():
with pytest.raises(InterfaceViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!"


def test_export_empty_interface(make_input_bundle, tmp_path):
lib1 = """
def an_internal_function():
pass
"""
main = """
import lib1
exports: lib1.__interface__
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(StructureException) as e:
compile_code(main, input_bundle=input_bundle)

# as_posix() for windows
lib1_path = (tmp_path / "lib1.vy").as_posix()
assert e.value._message == f"lib1 (located at `{lib1_path}`) has no external functions!"


def test_invalid_export(make_input_bundle):
lib1 = """
@external
def foo():
pass
"""
main = """
import lib1
a: address
exports: lib1.__interface__(self.a).foo
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})

with pytest.raises(StructureException) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "invalid export of a value"
assert e.value._hint == "exports should look like <module>.<function | interface>"

main = """
interface Foo:
def foo(): nonpayable
exports: Foo
"""
with pytest.raises(StructureException) as e:
compile_code(main)

assert e.value._message == "invalid export"
assert e.value._hint == "exports should look like <module>.<function | interface>"


@pytest.mark.parametrize("exports_item", ["__at__", "__at__(self)", "__at__(self).__interface__"])
def test_invalid_at_exports(get_contract, make_input_bundle, exports_item):
lib = """
@external
@view
def foo() -> uint256:
return 5
"""

main = f"""
import lib
exports: lib.{exports_item}
@external
@view
def bar() -> uint256:
return staticcall lib.__at__(self).foo()
"""
input_bundle = make_input_bundle({"lib.vy": lib})

with pytest.raises(Exception) as e:
compile_code(main, input_bundle=input_bundle)

if exports_item == "__at__":
assert "not a function or interface" in str(e.value)
if exports_item == "__at__(self)":
assert "invalid exports" in str(e.value)
if exports_item == "__at__(self).__interface__":
assert "has no member '__interface__'" in str(e.value)
Loading

0 comments on commit f249c93

Please sign in to comment.