Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow constant interfaces #3718

Merged
merged 11 commits into from
Jan 15, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,41 @@ def fooBar(a: Bytes[100], b: uint256[2], c: Bytes[6] = b"hello", d: int128[3] =
assert c.fooBar(b"booo", [55, 66]) == [b"booo", 66, c_default, d_default]


def test_default_param_interface(get_contract):
code = """
interface Foo:
def bar(): payable

FOO: constant(Foo) = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)

@external
def bar(a: uint256, b: Foo = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)) -> Foo:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be good for the tests if bar() and faz() returned something different

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean like any constant interface?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, like 0xaaaa vs 0xbbbbb

return b

@external
def baz(a: uint256, b: Foo = Foo(empty(address))) -> Foo:
return b

@external
def faz(a: uint256, b: Foo = FOO) -> Foo:
return b
"""
c = get_contract(code)

assert c.bar(1) == "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"
assert (
c.bar(1, "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF")
== "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"
)
assert c.baz(1) is None
assert c.baz(1, "0x0000000000000000000000000000000000000000") is None
assert c.faz(1) == "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"
assert (
c.faz(1, "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF")
== "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"
)


def test_default_param_internal_function(get_contract):
code = """
@internal
Expand Down
5 changes: 5 additions & 0 deletions tests/functional/codegen/storage_variables/test_getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def foo() -> int128:

def test_getter_code(get_contract_with_gas_estimation_for_constants):
getter_code = """
interface V:
def foo(): nonpayable

struct W:
a: uint256
b: int128[7]
Expand All @@ -36,6 +39,7 @@ def test_getter_code(get_contract_with_gas_estimation_for_constants):
d: public(immutable(uint256))
e: public(immutable(uint256[2]))
f: public(constant(uint256[2])) = [3, 7]
g: public(constant(V)) = V(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)

@external
def __init__():
Expand Down Expand Up @@ -70,6 +74,7 @@ def __init__():
assert c.d() == 1729
assert c.e(0) == 2
assert [c.f(i) for i in range(2)] == [3, 7]
assert c.g() == "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"


def test_getter_mutability(get_contract):
Expand Down
13 changes: 13 additions & 0 deletions tests/functional/syntax/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,19 @@ def deposit(deposit_input: Bytes[2048]):

CONST_BAR: constant(Bar) = Bar({c: C, d: D})
""",
"""
interface Foo:
def foo(): nonpayable

FOO: constant(Foo) = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)
""",
"""
interface Foo:
def foo(): nonpayable

FOO: constant(Foo) = Foo(BAR)
BAR: constant(address) = 0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF
""",
]


Expand Down
6 changes: 3 additions & 3 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ def parse_Name(self):
varinfo = self.context.globals[self.expr.id]

if varinfo.is_constant:
# non-struct constants should have already gotten propagated
# during constant folding
assert isinstance(varinfo.typ, StructT)
# constants other than structs and interfaces should have already gotten
# propagated during constant folding
assert isinstance(varinfo.typ, (InterfaceT, StructT))
return Expr.parse_value_expr(varinfo.decl_node.value, self.context)

assert varinfo.is_immutable, "not an immutable!"
Expand Down
12 changes: 11 additions & 1 deletion vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import TYPE_T, VyperType
from vyper.semantics.types.base import TYPE_T, VyperType, is_type_t
from vyper.semantics.types.bytestrings import BytesT, StringT
from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT
from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT
Expand Down Expand Up @@ -646,10 +646,20 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) ->

if isinstance(node, vy_ast.Call):
args = node.args

# structs
if len(args) == 1 and isinstance(args[0], vy_ast.Dict):
return all(check_modifiability(v, modifiability) for v in args[0].values)

call_type = get_exact_type_from_node(node.func)

# interfaces
from vyper.semantics.types.module import InterfaceT
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved

if is_type_t(call_type, InterfaceT):
return check_modifiability(args[0], modifiability)

# builtins
call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE)
return call_type_modifiability >= modifiability

Expand Down
Loading