Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow constant interfaces #3718

Merged
merged 11 commits into from
Jan 15, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,38 @@ def fooBar(a: Bytes[100], b: uint256[2], c: Bytes[6] = b"hello", d: int128[3] =
assert c.fooBar(b"booo", [55, 66]) == [b"booo", 66, c_default, d_default]


def test_default_param_interface(get_contract):
code = """
interface Foo:
def bar(): payable

FOO: constant(Foo) = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)

@external
def bar(a: uint256, b: Foo = Foo(0xF5D4020dCA6a62bB1efFcC9212AAF3c9819E30D7)) -> Foo:
return b

@external
def baz(a: uint256, b: Foo = Foo(empty(address))) -> Foo:
return b

@external
def faz(a: uint256, b: Foo = FOO) -> Foo:
return b
"""
c = get_contract(code)

addr1 = "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"
addr2 = "0xF5D4020dCA6a62bB1efFcC9212AAF3c9819E30D7"

assert c.bar(1) == addr2
assert c.bar(1, addr1) == addr1
assert c.baz(1) is None
assert c.baz(1, "0x0000000000000000000000000000000000000000") is None
assert c.faz(1) == addr1
assert c.faz(1, addr1) == addr1


def test_default_param_internal_function(get_contract):
code = """
@internal
Expand Down
5 changes: 5 additions & 0 deletions tests/functional/codegen/storage_variables/test_getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def foo() -> int128:

def test_getter_code(get_contract_with_gas_estimation_for_constants):
getter_code = """
interface V:
def foo(): nonpayable

struct W:
a: uint256
b: int128[7]
Expand All @@ -36,6 +39,7 @@ def test_getter_code(get_contract_with_gas_estimation_for_constants):
d: public(immutable(uint256))
e: public(immutable(uint256[2]))
f: public(constant(uint256[2])) = [3, 7]
g: public(constant(V)) = V(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)

@external
def __init__():
Expand Down Expand Up @@ -70,6 +74,7 @@ def __init__():
assert c.d() == 1729
assert c.e(0) == 2
assert [c.f(i) for i in range(2)] == [3, 7]
assert c.g() == "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"


def test_getter_mutability(get_contract):
Expand Down
13 changes: 13 additions & 0 deletions tests/functional/syntax/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,19 @@ def deposit(deposit_input: Bytes[2048]):

CONST_BAR: constant(Bar) = Bar({c: C, d: D})
""",
"""
interface Foo:
def foo(): nonpayable

FOO: constant(Foo) = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)
""",
"""
interface Foo:
def foo(): nonpayable

FOO: constant(Foo) = Foo(BAR)
BAR: constant(address) = 0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF
""",
]


Expand Down
3 changes: 3 additions & 0 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None:
# ensures the type can be inferred exactly.
get_exact_type_from_node(arg)

def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool:
return self._modifiability >= modifiability

def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
self._validate_arg_types(node)

Expand Down
6 changes: 3 additions & 3 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ def parse_Name(self):
varinfo = self.context.globals[self.expr.id]

if varinfo.is_constant:
# non-struct constants should have already gotten propagated
# during constant folding
assert isinstance(varinfo.typ, StructT)
# constants other than structs and interfaces should have already gotten
# propagated during constant folding
assert isinstance(varinfo.typ, (InterfaceT, StructT))
return Expr.parse_value_expr(varinfo.decl_node.value, self.context)

assert varinfo.is_immutable, "not an immutable!"
Expand Down
13 changes: 6 additions & 7 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,15 +645,14 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) ->
return all(check_modifiability(item, modifiability) for item in node.elements)

if isinstance(node, vy_ast.Call):
args = node.args
if len(args) == 1 and isinstance(args[0], vy_ast.Dict):
return all(check_modifiability(v, modifiability) for v in args[0].values)

call_type = get_exact_type_from_node(node.func)

# builtins
call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE)
return call_type_modifiability >= modifiability
# structs and interfaces
if isinstance(call_type, TYPE_T):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's more complicated but i think this actually should use TYPE_T.check_modifiability_for_call (which can dispatch to something like StructT.ctor_modifiability_for_call). the way you have it written, a struct instance can pass the modifiability check.

call_type = call_type.typedef

if hasattr(call_type, "check_modifiability_for_call"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as a general comment, i'd like to move these hasattr checks to be more like

if call_type._is_callable:
    call_type.check_modifiability_for_call(...)

that way we impose more structure on the APIs of callable types. but i think that can be pushed to "future work"

return call_type.check_modifiability_for_call(node, modifiability)

value_type = get_expr_info(node)
return value_type.modifiability >= modifiability
11 changes: 9 additions & 2 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
StructureException,
UnfoldableNode,
)
from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids
from vyper.semantics.analysis.base import Modifiability, VarInfo
from vyper.semantics.analysis.utils import (
check_modifiability,
validate_expected_type,
validate_unique_method_ids,
)
Comment on lines +14 to +18

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.utils
begins an import cycle.
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import TYPE_T, VyperType
from vyper.semantics.types.function import ContractFunctionT
Expand Down Expand Up @@ -80,6 +84,9 @@
def _ctor_kwarg_types(self, node):
return {}

def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool:
return check_modifiability(node.args[0], modifiability)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will panic if len(node.args) < 1 (because in some cases, check_modifiability runs before validate_call_args).


# TODO x.validate_implements(other)
def validate_implements(self, node: vy_ast.ImplementsDecl) -> None:
namespace = get_namespace()
Expand Down
6 changes: 5 additions & 1 deletion vyper/semantics/types/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
UnknownAttribute,
VariableDeclarationException,
)
from vyper.semantics.analysis.base import Modifiability
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.analysis.utils import validate_expected_type
from vyper.semantics.analysis.utils import check_modifiability, validate_expected_type

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.utils
begins an import cycle.
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import VyperType
from vyper.semantics.types.subscriptable import HashMapT
Expand Down Expand Up @@ -419,3 +420,6 @@
)

return self

def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool:
return all(check_modifiability(v, modifiability) for v in node.args[0].values)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same -- will panic if len(node.args) < 1

Loading