diff --git a/tests/functional/codegen/calling_convention/test_default_parameters.py b/tests/functional/codegen/calling_convention/test_default_parameters.py index 462748a9c7..240ccb3bb1 100644 --- a/tests/functional/codegen/calling_convention/test_default_parameters.py +++ b/tests/functional/codegen/calling_convention/test_default_parameters.py @@ -111,6 +111,38 @@ def fooBar(a: Bytes[100], b: uint256[2], c: Bytes[6] = b"hello", d: int128[3] = assert c.fooBar(b"booo", [55, 66]) == [b"booo", 66, c_default, d_default] +def test_default_param_interface(get_contract): + code = """ +interface Foo: + def bar(): payable + +FOO: constant(Foo) = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF) + +@external +def bar(a: uint256, b: Foo = Foo(0xF5D4020dCA6a62bB1efFcC9212AAF3c9819E30D7)) -> Foo: + return b + +@external +def baz(a: uint256, b: Foo = Foo(empty(address))) -> Foo: + return b + +@external +def faz(a: uint256, b: Foo = FOO) -> Foo: + return b + """ + c = get_contract(code) + + addr1 = "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF" + addr2 = "0xF5D4020dCA6a62bB1efFcC9212AAF3c9819E30D7" + + assert c.bar(1) == addr2 + assert c.bar(1, addr1) == addr1 + assert c.baz(1) is None + assert c.baz(1, "0x0000000000000000000000000000000000000000") is None + assert c.faz(1) == addr1 + assert c.faz(1, addr1) == addr1 + + def test_default_param_internal_function(get_contract): code = """ @internal diff --git a/tests/functional/codegen/storage_variables/test_getters.py b/tests/functional/codegen/storage_variables/test_getters.py index 5eac074ef6..a2d9c6d0bb 100644 --- a/tests/functional/codegen/storage_variables/test_getters.py +++ b/tests/functional/codegen/storage_variables/test_getters.py @@ -19,6 +19,9 @@ def foo() -> int128: def test_getter_code(get_contract_with_gas_estimation_for_constants): getter_code = """ +interface V: + def foo(): nonpayable + struct W: a: uint256 b: int128[7] @@ -36,6 +39,7 @@ def test_getter_code(get_contract_with_gas_estimation_for_constants): d: public(immutable(uint256)) e: public(immutable(uint256[2])) f: public(constant(uint256[2])) = [3, 7] +g: public(constant(V)) = V(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF) @external def __init__(): @@ -70,6 +74,7 @@ def __init__(): assert c.d() == 1729 assert c.e(0) == 2 assert [c.f(i) for i in range(2)] == [3, 7] + assert c.g() == "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF" def test_getter_mutability(get_contract): diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index 7089dee3bb..04e778a00e 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -304,6 +304,19 @@ def deposit(deposit_input: Bytes[2048]): CONST_BAR: constant(Bar) = Bar({c: C, d: D}) """, + """ +interface Foo: + def foo(): nonpayable + +FOO: constant(Foo) = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF) + """, + """ +interface Foo: + def foo(): nonpayable + +FOO: constant(Foo) = Foo(BAR) +BAR: constant(address) = 0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF + """, ] diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 1a488f39e0..d2aefb2fd4 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -129,6 +129,9 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: # ensures the type can be inferred exactly. get_exact_type_from_node(arg) + def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool: + return self._modifiability >= modifiability + def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: self._validate_arg_types(node) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 577660b883..6a97e60ce2 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -190,9 +190,9 @@ def parse_Name(self): varinfo = self.context.globals[self.expr.id] if varinfo.is_constant: - # non-struct constants should have already gotten propagated - # during constant folding - assert isinstance(varinfo.typ, StructT) + # constants other than structs and interfaces should have already gotten + # propagated during constant folding + assert isinstance(varinfo.typ, (InterfaceT, StructT)) return Expr.parse_value_expr(varinfo.decl_node.value, self.context) assert varinfo.is_immutable, "not an immutable!" diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 359b51b71e..3e818fa246 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -645,15 +645,11 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> return all(check_modifiability(item, modifiability) for item in node.elements) if isinstance(node, vy_ast.Call): - args = node.args - if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return all(check_modifiability(v, modifiability) for v in args[0].values) - call_type = get_exact_type_from_node(node.func) - # builtins - call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE) - return call_type_modifiability >= modifiability + # structs and interfaces + if hasattr(call_type, "check_modifiability_for_call"): + return call_type.check_modifiability_for_call(node, modifiability) value_type = get_expr_info(node) return value_type.modifiability >= modifiability diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 14949f693f..b15eca8ab2 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -334,6 +334,11 @@ def __init__(self, typedef): def __repr__(self): return f"type({self.typedef})" + def check_modifiability_for_call(self, node, modifiability): + if hasattr(self.typedef, "_ctor_modifiability_for_call"): + return self.typedef._ctor_modifiability_for_call(node, modifiability) + raise StructureException("Value is not callable", node) + # dispatch into ctor if it's called def fetch_call_return(self, node): if hasattr(self.typedef, "_ctor_call_return"): diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index f2c3d74525..ee1da22a87 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -10,8 +10,12 @@ StructureException, UnfoldableNode, ) -from vyper.semantics.analysis.base import VarInfo -from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids +from vyper.semantics.analysis.base import Modifiability, VarInfo +from vyper.semantics.analysis.utils import ( + check_modifiability, + validate_expected_type, + validate_unique_method_ids, +) from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType from vyper.semantics.types.function import ContractFunctionT @@ -81,6 +85,9 @@ def _ctor_arg_types(self, node): def _ctor_kwarg_types(self, node): return {} + def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool: + return check_modifiability(node.args[0], modifiability) + # TODO x.validate_implements(other) def validate_implements(self, node: vy_ast.ImplementsDecl) -> None: namespace = get_namespace() diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index 8ef9aa8d4a..92a455e3d8 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -14,8 +14,9 @@ UnknownAttribute, VariableDeclarationException, ) +from vyper.semantics.analysis.base import Modifiability from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions -from vyper.semantics.analysis.utils import validate_expected_type +from vyper.semantics.analysis.utils import check_modifiability, validate_expected_type from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType from vyper.semantics.types.subscriptable import HashMapT @@ -419,3 +420,6 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT": ) return self + + def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool: + return all(check_modifiability(v, modifiability) for v in node.args[0].values)