Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove duplicate terminus checking code #3541

Merged
merged 13 commits into from
Jan 14, 2024
8 changes: 0 additions & 8 deletions tests/functional/codegen/features/test_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,6 @@ def test():
assert self.ret1() == 1
""",
"""
@internal
def valid_address(sender: address) -> bool:
selfdestruct(sender)
@external
def test():
assert self.valid_address(msg.sender)
""",
"""
@external
def test():
assert raw_call(msg.sender, b'', max_outsize=1, gas=10, value=1000*1000) == b''
Expand Down
1 change: 0 additions & 1 deletion tests/functional/codegen/features/test_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ def foo(i: bool) -> int128:
else:
assert 2 != 0
return 7
return 11
"""

c = get_contract_with_gas_estimation(conditional_return_code)
Expand Down
43 changes: 33 additions & 10 deletions tests/functional/syntax/test_unbalanced_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
@external
def foo() -> int128:
pass
pass # missing return
""",
FunctionDeclarationException,
),
Expand All @@ -18,6 +18,7 @@ def foo() -> int128:
def foo() -> int128:
if False:
return 123
# missing return
""",
FunctionDeclarationException,
),
Expand All @@ -27,19 +28,19 @@ def foo() -> int128:
def test() -> int128:
if 1 == 1 :
return 1
if True:
if True: # unreachable
return 0
else:
assert msg.sender != msg.sender
""",
FunctionDeclarationException,
StructureException,
),
(
"""
@internal
def valid_address(sender: address) -> bool:
selfdestruct(sender)
return True
return True # unreachable
""",
StructureException,
),
Expand All @@ -48,7 +49,7 @@ def valid_address(sender: address) -> bool:
@internal
def valid_address(sender: address) -> bool:
selfdestruct(sender)
a: address = sender
a: address = sender # unreachable
""",
StructureException,
),
Expand All @@ -58,7 +59,7 @@ def valid_address(sender: address) -> bool:
def valid_address(sender: address) -> bool:
if sender == empty(address):
selfdestruct(sender)
_sender: address = sender
_sender: address = sender # unreachable
else:
return False
""",
Expand All @@ -69,7 +70,7 @@ def valid_address(sender: address) -> bool:
@internal
def foo() -> bool:
raw_revert(b"vyper")
return True
return True # unreachable
""",
StructureException,
),
Expand All @@ -78,7 +79,7 @@ def foo() -> bool:
@internal
def foo() -> bool:
raw_revert(b"vyper")
x: uint256 = 3
x: uint256 = 3 # unreachable
""",
StructureException,
),
Expand All @@ -88,12 +89,35 @@ def foo() -> bool:
def foo(x: uint256) -> bool:
if x == 2:
raw_revert(b"vyper")
a: uint256 = 3
a: uint256 = 3 # unreachable
else:
return False
""",
StructureException,
),
(
"""
@internal
def foo():
return
return # unreachable
""",
StructureException,
),
(
"""
@internal
def foo() -> uint256:
if block.number % 2 == 0:
return 5
elif block.number % 3 == 0:
return 6
else:
return 10
return 0 # unreachable
""",
StructureException,
),
]


Expand Down Expand Up @@ -154,7 +178,6 @@ def test() -> int128:
else:
x = keccak256(x)
return 1
return 1
""",
"""
@external
Expand Down
38 changes: 34 additions & 4 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,6 @@
Field names that, if present, must be set to None or a `SyntaxException`
is raised. This attribute is used to exclude syntax that is valid in Python
but not in Vyper.
_is_terminus : bool, optional
If `True`, indicates that execution halts upon reaching this node.
_translated_fields : Dict, optional
Field names that are reassigned if encountered. Used to normalize fields
across different Python versions.
Expand Down Expand Up @@ -389,6 +387,13 @@
"""
return False

@property
def is_terminus(self):
"""
Check if execution halts upon reaching this node.
"""
return False

@property
def has_folded_value(self):
"""
Expand Down Expand Up @@ -711,12 +716,19 @@

class Return(Stmt):
__slots__ = ("value",)
_is_terminus = True

@property
def is_terminus(self):
return True


class Expr(Stmt):
__slots__ = ("value",)

@property
def is_terminus(self):
return self.value.is_terminus


class Log(Stmt):
__slots__ = ("value",)
Expand Down Expand Up @@ -1187,6 +1199,21 @@
class Call(ExprNode):
__slots__ = ("func", "args", "keywords")

@property
def is_terminus(self):
# cursed import cycle!
from vyper.builtins.functions import get_builtin_functions
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Dismissed Show dismissed Hide dismissed

if not isinstance(self.func, Name):
return False

funcname = self.func.id
builtin_t = get_builtin_functions().get(funcname)
if builtin_t is None:
return False

return builtin_t._is_terminus


class keyword(VyperNode):
__slots__ = ("arg", "value")
Expand Down Expand Up @@ -1322,7 +1349,10 @@
class Raise(Stmt):
__slots__ = ("exc",)
_only_empty_fields = ("cause",)
_is_terminus = True

@property
def is_terminus(self):
return True


class Assert(Stmt):
Expand Down
1 change: 1 addition & 0 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class BuiltinFunctionT(VyperType):
_kwargs: dict[str, KwargSettings] = {}
_modifiability: Modifiability = Modifiability.MODIFIABLE
_return_type: Optional[VyperType] = None
_is_terminus = False

# helper function to deal with TYPE_DEFINITIONs
def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None:
Expand Down
40 changes: 1 addition & 39 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import contextlib
from typing import Generator

from vyper import ast as vy_ast
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.compiler.settings import OptimizationLevel
from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT
from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch
from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch
from vyper.semantics.types import (
AddressT,
BoolT,
Expand Down Expand Up @@ -1035,43 +1034,6 @@ def eval_seq(ir_node):
return None


def is_return_from_function(node):
if isinstance(node, vy_ast.Expr) and node.get("value.func.id") in (
"raw_revert",
"selfdestruct",
):
return True
if isinstance(node, (vy_ast.Return, vy_ast.Raise)):
return True
return False


# TODO this is almost certainly duplicated with check_terminus_node
# in vyper/semantics/analysis/local.py
def check_single_exit(fn_node):
_check_return_body(fn_node, fn_node.body)
for node in fn_node.get_descendants(vy_ast.If):
_check_return_body(node, node.body)
if node.orelse:
_check_return_body(node, node.orelse)


def _check_return_body(node, node_list):
return_count = len([n for n in node_list if is_return_from_function(n)])
if return_count > 1:
raise StructureException(
"Too too many exit statements (return, raise or selfdestruct).", node
)
# Check for invalid code after returns.
last_node_pos = len(node_list) - 1
for idx, n in enumerate(node_list):
if is_return_from_function(n) and idx < last_node_pos:
# is not last statement in body.
raise StructureException(
"Exit statement with succeeding code (that will not execute).", node_list[idx + 1]
)


def mzero(dst, nbytes):
# calldatacopy from past-the-end gives zero bytes.
# cf. YP H.2 (ops section) with CALLDATACOPY spec.
Expand Down
5 changes: 0 additions & 5 deletions vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import vyper.ast as vy_ast
from vyper.codegen.context import Constancy, Context
from vyper.codegen.core import check_single_exit
from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function
from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function
from vyper.codegen.ir_node import IRnode
Expand Down Expand Up @@ -115,10 +114,6 @@ def generate_ir_for_function(
# generate _FuncIRInfo
func_t._ir_info = _FuncIRInfo(func_t)

# Validate return statements.
# XXX: This should really be in semantics pass.
check_single_exit(code)

callees = func_t.called_functions

# we start our function frame from the largest callee frame
Expand Down
3 changes: 1 addition & 2 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
get_dyn_array_count,
get_element_ptr,
getpos,
is_return_from_function,
make_byte_array_copier,
make_setter,
pop_dyn_array,
Expand Down Expand Up @@ -404,7 +403,7 @@ def parse_stmt(stmt, context):
def _is_terminated(code):
last_stmt = code[-1]

if is_return_from_function(last_stmt):
if last_stmt.is_terminus:
return True

if isinstance(last_stmt, vy_ast.If):
Expand Down
52 changes: 28 additions & 24 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,28 @@ def validate_functions(vy_module: vy_ast.Module) -> None:
err_list.raise_if_not_empty()


def _is_terminus_node(node: vy_ast.VyperNode) -> bool:
if getattr(node, "_is_terminus", None):
return True
if isinstance(node, vy_ast.Expr) and isinstance(node.value, vy_ast.Call):
func = get_exact_type_from_node(node.value.func)
if getattr(func, "_is_terminus", None):
return True
return False


def check_for_terminus(node_list: list) -> bool:
if next((i for i in node_list if _is_terminus_node(i)), None):
return True
for node in [i for i in node_list if isinstance(i, vy_ast.If)][::-1]:
if not node.orelse or not check_for_terminus(node.orelse):
continue
if not check_for_terminus(node.body):
continue
return True
return False
# finds the terminus node for a list of nodes.
# raises an exception if any nodes are unreachable
def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]:
ret = None

for node in node_list:
if ret is not None:
raise StructureException("Unreachable code!", node)
if node.is_terminus:
ret = node

if isinstance(node, vy_ast.If):
body_terminates = find_terminating_node(node.body)

else_terminates = None
if node.orelse is not None:
else_terminates = find_terminating_node(node.orelse)

if body_terminates is not None and else_terminates is not None:
ret = else_terminates

return ret


def _check_iterator_modification(
Expand Down Expand Up @@ -201,11 +203,13 @@ def analyze(self):
self.visit(node)

if self.func.return_type:
if not check_for_terminus(self.fn_node.body):
if not find_terminating_node(self.fn_node.body):
raise FunctionDeclarationException(
f"Missing or unmatched return statements in function '{self.fn_node.name}'",
self.fn_node,
f"Missing return statement in function '{self.fn_node.name}'", self.fn_node
)
else:
# call find_terminator for its unreachable code detection side effect
find_terminating_node(self.fn_node.body)

# visit default args
assert self.func.n_keyword_args == len(self.fn_node.args.defaults)
Expand Down Expand Up @@ -468,7 +472,7 @@ def visit_Return(self, node):
raise FunctionDeclarationException("Return statement is missing a value", node)
return
elif self.func.return_type is None:
raise FunctionDeclarationException("Function does not return any values", node)
raise FunctionDeclarationException("Function should not return any values", node)

if isinstance(values, vy_ast.Tuple):
values = values.elements
Expand Down
Loading