Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow constant interfaces #3718

Merged
merged 11 commits into from
Jan 15, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,41 @@ def fooBar(a: Bytes[100], b: uint256[2], c: Bytes[6] = b"hello", d: int128[3] =
assert c.fooBar(b"booo", [55, 66]) == [b"booo", 66, c_default, d_default]


def test_default_param_interface(get_contract):
code = """
interface Foo:
def bar(): payable

FOO: constant(Foo) = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)

@external
def bar(a: uint256, b: Foo = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)) -> Foo:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be good for the tests if bar() and faz() returned something different

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean like any constant interface?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, like 0xaaaa vs 0xbbbbb

return b

@external
def baz(a: uint256, b: Foo = Foo(empty(address))) -> Foo:
return b

@external
def faz(a: uint256, b: Foo = FOO) -> Foo:
return b
"""
c = get_contract(code)

assert c.bar(1) == "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"
assert (
c.bar(1, "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF")
== "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"
)
assert c.baz(1) is None
assert c.baz(1, "0x0000000000000000000000000000000000000000") is None
assert c.faz(1) == "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"
assert (
c.faz(1, "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF")
== "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"
)


def test_default_param_internal_function(get_contract):
code = """
@internal
Expand Down
5 changes: 5 additions & 0 deletions tests/functional/codegen/storage_variables/test_getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def foo() -> int128:

def test_getter_code(get_contract_with_gas_estimation_for_constants):
getter_code = """
interface V:
def foo(): nonpayable

struct W:
a: uint256
b: int128[7]
Expand All @@ -36,6 +39,7 @@ def test_getter_code(get_contract_with_gas_estimation_for_constants):
d: public(immutable(uint256))
e: public(immutable(uint256[2]))
f: public(constant(uint256[2])) = [3, 7]
g: public(constant(V)) = V(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)

@external
def __init__():
Expand Down Expand Up @@ -70,6 +74,7 @@ def __init__():
assert c.d() == 1729
assert c.e(0) == 2
assert [c.f(i) for i in range(2)] == [3, 7]
assert c.g() == "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"


def test_getter_mutability(get_contract):
Expand Down
13 changes: 13 additions & 0 deletions tests/functional/syntax/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,19 @@ def deposit(deposit_input: Bytes[2048]):

CONST_BAR: constant(Bar) = Bar({c: C, d: D})
""",
"""
interface Foo:
def foo(): nonpayable

FOO: constant(Foo) = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)
""",
"""
interface Foo:
def foo(): nonpayable

FOO: constant(Foo) = Foo(BAR)
BAR: constant(address) = 0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF
""",
]


Expand Down
6 changes: 3 additions & 3 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ def parse_Name(self):
varinfo = self.context.globals[self.expr.id]

if varinfo.is_constant:
# non-struct constants should have already gotten propagated
# during constant folding
assert isinstance(varinfo.typ, StructT)
# constants other than structs and interfaces should have already gotten
# propagated during constant folding
assert isinstance(varinfo.typ, (InterfaceT, StructT))
return Expr.parse_value_expr(varinfo.decl_node.value, self.context)

assert varinfo.is_immutable, "not an immutable!"
Expand Down
13 changes: 9 additions & 4 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,11 +645,16 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) ->
return all(check_modifiability(item, modifiability) for item in node.elements)

if isinstance(node, vy_ast.Call):
args = node.args
if len(args) == 1 and isinstance(args[0], vy_ast.Dict):
return all(check_modifiability(v, modifiability) for v in args[0].values)

call_type = get_exact_type_from_node(node.func)

# structs and interfaces
if isinstance(call_type, TYPE_T):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's more complicated but i think this actually should use TYPE_T.check_modifiability_for_call (which can dispatch to something like StructT.ctor_modifiability_for_call). the way you have it written, a struct instance can pass the modifiability check.

call_type = call_type.typedef

if hasattr(call_type, "check_modifiability_for_call"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as a general comment, i'd like to move these hasattr checks to be more like

if call_type._is_callable:
    call_type.check_modifiability_for_call(...)

that way we impose more structure on the APIs of callable types. but i think that can be pushed to "future work"

return call_type.check_modifiability_for_call(node, modifiability)

# builtins
call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE)
return call_type_modifiability >= modifiability

Expand Down
11 changes: 9 additions & 2 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
from vyper.abi_types import ABI_Address, ABIType
from vyper.ast.validation import validate_call_args
from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException
from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids
from vyper.semantics.analysis.base import Modifiability, VarInfo
from vyper.semantics.analysis.utils import (
check_modifiability,
validate_expected_type,
validate_unique_method_ids,
)
Comment on lines +14 to +18

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.utils
begins an import cycle.
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import TYPE_T, VyperType
from vyper.semantics.types.function import ContractFunctionT
Expand Down Expand Up @@ -66,6 +70,9 @@
def _ctor_kwarg_types(self, node):
return {}

def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool:
return check_modifiability(node.args[0], modifiability)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will panic if len(node.args) < 1 (because in some cases, check_modifiability runs before validate_call_args).


# TODO x.validate_implements(other)
def validate_implements(self, node: vy_ast.ImplementsDecl) -> None:
namespace = get_namespace()
Expand Down
6 changes: 5 additions & 1 deletion vyper/semantics/types/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
UnknownAttribute,
VariableDeclarationException,
)
from vyper.semantics.analysis.base import Modifiability
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.analysis.utils import validate_expected_type
from vyper.semantics.analysis.utils import check_modifiability, validate_expected_type

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.utils
begins an import cycle.
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import VyperType
from vyper.semantics.types.subscriptable import HashMapT
Expand Down Expand Up @@ -408,3 +409,6 @@
)

return self

def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool:
return all(check_modifiability(v, modifiability) for v in node.args[0].values)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same -- will panic if len(node.args) < 1

Loading