Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove duplicate terminus checking code #3541

Merged
merged 13 commits into from
Jan 14, 2024
28 changes: 24 additions & 4 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,6 @@
Field names that, if present, must be set to None or a `SyntaxException`
is raised. This attribute is used to exclude syntax that is valid in Python
but not in Vyper.
_is_terminus : bool, optional
If `True`, indicates that execution halts upon reaching this node.
_translated_fields : Dict, optional
Field names that are reassigned if encountered. Used to normalize fields
across different Python versions.
Expand Down Expand Up @@ -380,6 +378,13 @@
"""
return False

@property
def is_terminus(self):
"""
Check if execution halts upon reaching this node.
"""
return False

@property
def has_folded_value(self):
"""
Expand Down Expand Up @@ -717,7 +722,10 @@

class Return(Stmt):
__slots__ = ("value",)
_is_terminus = True

@property
def is_terminus(self):
return True


class Expr(Stmt):
Expand Down Expand Up @@ -1302,6 +1310,15 @@
class Call(ExprNode):
__slots__ = ("func", "args", "keywords")

@property
def is_terminus(self):
# cursed import cycle!
from vyper.builtins.functions import DISPATCH_TABLE
Fixed Show fixed Hide fixed

func_name = self.func.id
builtin_t = DISPATCH_TABLE[func_name]
return getattr(builtin_t, "_is_terminus", False)

# try checking if this is a builtin, which is foldable
def _try_fold(self):
if not isinstance(self.func, Name):
Expand Down Expand Up @@ -1483,7 +1500,10 @@
class Raise(Stmt):
__slots__ = ("exc",)
_only_empty_fields = ("cause",)
_is_terminus = True

@property
def is_terminus(self):
return True


class Assert(Stmt):
Expand Down
40 changes: 1 addition & 39 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import contextlib
from typing import Generator

from vyper import ast as vy_ast
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.compiler.settings import OptimizationLevel
from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT
from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch
from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch
from vyper.semantics.types import (
AddressT,
BoolT,
Expand Down Expand Up @@ -1035,43 +1034,6 @@ def eval_seq(ir_node):
return None


def is_return_from_function(node):
if isinstance(node, vy_ast.Expr) and node.get("value.func.id") in (
"raw_revert",
"selfdestruct",
):
return True
if isinstance(node, (vy_ast.Return, vy_ast.Raise)):
return True
return False


# TODO this is almost certainly duplicated with check_terminus_node
# in vyper/semantics/analysis/local.py
def check_single_exit(fn_node):
_check_return_body(fn_node, fn_node.body)
for node in fn_node.get_descendants(vy_ast.If):
_check_return_body(node, node.body)
if node.orelse:
_check_return_body(node, node.orelse)


def _check_return_body(node, node_list):
return_count = len([n for n in node_list if is_return_from_function(n)])
if return_count > 1:
raise StructureException(
"Too too many exit statements (return, raise or selfdestruct).", node
)
# Check for invalid code after returns.
last_node_pos = len(node_list) - 1
for idx, n in enumerate(node_list):
if is_return_from_function(n) and idx < last_node_pos:
# is not last statement in body.
raise StructureException(
"Exit statement with succeeding code (that will not execute).", node_list[idx + 1]
)


def mzero(dst, nbytes):
# calldatacopy from past-the-end gives zero bytes.
# cf. YP H.2 (ops section) with CALLDATACOPY spec.
Expand Down
5 changes: 0 additions & 5 deletions vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import vyper.ast as vy_ast
from vyper.codegen.context import Constancy, Context
from vyper.codegen.core import check_single_exit
from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function
from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function
from vyper.codegen.ir_node import IRnode
Expand Down Expand Up @@ -115,10 +114,6 @@ def generate_ir_for_function(
# generate _FuncIRInfo
func_t._ir_info = _FuncIRInfo(func_t)

# Validate return statements.
# XXX: This should really be in semantics pass.
check_single_exit(code)

callees = func_t.called_functions

# we start our function frame from the largest callee frame
Expand Down
3 changes: 1 addition & 2 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
get_dyn_array_count,
get_element_ptr,
getpos,
is_return_from_function,
make_byte_array_copier,
make_setter,
pop_dyn_array,
Expand Down Expand Up @@ -406,7 +405,7 @@ def parse_stmt(stmt, context):
def _is_terminated(code):
last_stmt = code[-1]

if is_return_from_function(last_stmt):
if last_stmt.is_terminus:
return True

if isinstance(last_stmt, vy_ast.If):
Expand Down
31 changes: 20 additions & 11 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,28 @@ def validate_functions(vy_module: vy_ast.Module) -> None:
err_list.raise_if_not_empty()


def _is_terminus_node(node: vy_ast.VyperNode) -> bool:
if getattr(node, "_is_terminus", None):
return True
if isinstance(node, vy_ast.Expr) and isinstance(node.value, vy_ast.Call):
func = get_exact_type_from_node(node.value.func)
if getattr(func, "_is_terminus", None):
return True
return False


def check_for_terminus(node_list: list) -> bool:
if next((i for i in node_list if _is_terminus_node(i)), None):
terminus_nodes = []

# Check for invalid code after returns
last_node_pos = len(node_list) - 1
for idx, n in enumerate(node_list):
if n.is_terminus:
terminus_nodes.append(n)
if idx < last_node_pos:
# is not last statement in body.
raise StructureException(
"Exit statement with succeeding code (that will not execute).",
node_list[idx + 1],
)

if len(terminus_nodes) > 1:
raise StructureException(
"Too many exit statements (return, raise or selfdestruct).", terminus_nodes[-1]
)
elif len(terminus_nodes) == 1:
return True

for node in [i for i in node_list if isinstance(i, vy_ast.If)][::-1]:
if not node.orelse or not check_for_terminus(node.orelse):
continue
Expand Down
Loading