Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove duplicate terminus checking code #3541

Merged
merged 13 commits into from
Jan 14, 2024
40 changes: 1 addition & 39 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import contextlib
from typing import Generator

from vyper import ast as vy_ast
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.compiler.settings import OptimizationLevel
from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT
from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch
from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch
from vyper.semantics.types import (
AddressT,
BoolT,
Expand Down Expand Up @@ -1033,43 +1032,6 @@ def eval_seq(ir_node):
return None


def is_return_from_function(node):
if isinstance(node, vy_ast.Expr) and node.get("value.func.id") in (
"raw_revert",
"selfdestruct",
):
return True
if isinstance(node, (vy_ast.Return, vy_ast.Raise)):
return True
return False


# TODO this is almost certainly duplicated with check_terminus_node
# in vyper/semantics/analysis/local.py
def check_single_exit(fn_node):
_check_return_body(fn_node, fn_node.body)
for node in fn_node.get_descendants(vy_ast.If):
_check_return_body(node, node.body)
if node.orelse:
_check_return_body(node, node.orelse)


def _check_return_body(node, node_list):
return_count = len([n for n in node_list if is_return_from_function(n)])
if return_count > 1:
raise StructureException(
"Too too many exit statements (return, raise or selfdestruct).", node
)
# Check for invalid code after returns.
last_node_pos = len(node_list) - 1
for idx, n in enumerate(node_list):
if is_return_from_function(n) and idx < last_node_pos:
# is not last statement in body.
raise StructureException(
"Exit statement with succeeding code (that will not execute).", node_list[idx + 1]
)


def mzero(dst, nbytes):
# calldatacopy from past-the-end gives zero bytes.
# cf. YP H.2 (ops section) with CALLDATACOPY spec.
Expand Down
5 changes: 0 additions & 5 deletions vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import vyper.ast as vy_ast
from vyper.codegen.context import Constancy, Context
from vyper.codegen.core import check_single_exit
from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function
from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function
from vyper.codegen.global_context import GlobalContext
Expand Down Expand Up @@ -108,10 +107,6 @@ def generate_ir_for_function(
# generate _FuncIRInfo
func_t._ir_info = _FuncIRInfo(func_t)

# Validate return statements.
# XXX: This should really be in semantics pass.
check_single_exit(code)

callees = func_t.called_functions

# we start our function frame from the largest callee frame
Expand Down
4 changes: 2 additions & 2 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
get_dyn_array_count,
get_element_ptr,
getpos,
is_return_from_function,
make_byte_array_copier,
make_setter,
pop_dyn_array,
Expand All @@ -25,6 +24,7 @@
from vyper.codegen.return_ import make_return_stmt
from vyper.evm.address_space import MEMORY, STORAGE
from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure
from vyper.semantics.analysis.utils import is_terminus_node
Fixed Show fixed Hide fixed
from vyper.semantics.types import DArrayT, MemberFunctionT
from vyper.semantics.types.shortcuts import INT256_T, UINT256_T

Expand Down Expand Up @@ -425,7 +425,7 @@
def _is_terminated(code):
last_stmt = code[-1]

if is_return_from_function(last_stmt):
if is_terminus_node(last_stmt):
return True

if isinstance(last_stmt, vy_ast.If):
Expand Down
32 changes: 21 additions & 11 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
get_exact_type_from_node,
get_expr_info,
get_possible_types_from_node,
is_terminus_node,
validate_expected_type,
)
from vyper.semantics.data_locations import DataLocation
Expand Down Expand Up @@ -65,19 +66,28 @@ def validate_functions(vy_module: vy_ast.Module) -> None:
err_list.raise_if_not_empty()


def _is_terminus_node(node: vy_ast.VyperNode) -> bool:
if getattr(node, "_is_terminus", None):
return True
if isinstance(node, vy_ast.Expr) and isinstance(node.value, vy_ast.Call):
func = get_exact_type_from_node(node.value.func)
if getattr(func, "_is_terminus", None):
return True
return False


def check_for_terminus(node_list: list) -> bool:
if next((i for i in node_list if _is_terminus_node(i)), None):
terminus_nodes = []

# Check for invalid code after returns
last_node_pos = len(node_list) - 1
for idx, n in enumerate(node_list):
if is_terminus_node(n):
terminus_nodes.append(n)
if idx < last_node_pos:
# is not last statement in body.
raise StructureException(
"Exit statement with succeeding code (that will not execute).",
node_list[idx + 1],
)

if len(terminus_nodes) > 1:
raise StructureException(
"Too many exit statements (return, raise or selfdestruct).", terminus_nodes[-1]
)
elif len(terminus_nodes) == 1:
return True

for node in [i for i in node_list if isinstance(i, vy_ast.If)][::-1]:
if not node.orelse or not check_for_terminus(node.orelse):
continue
Expand Down
10 changes: 10 additions & 0 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,16 @@ def get_common_types(*nodes: vy_ast.VyperNode, filter_fn: Callable = None) -> Li
return common_types


def is_terminus_node(node: vy_ast.VyperNode) -> bool:
if getattr(node, "_is_terminus", None):
return True
if isinstance(node, vy_ast.Expr) and isinstance(node.value, vy_ast.Call):
func = get_exact_type_from_node(node.value.func)
if getattr(func, "_is_terminus", None):
return True
return False


# TODO push this into `ArrayT.validate_literal()`
def _validate_literal_array(node, expected):
# validate that every item within an array has the same type
Expand Down
Loading