From a5de4bddeb05f4eafceee07fa89e1a41f9c0f7b3 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 25 Sep 2023 15:30:04 +0800 Subject: [PATCH] refactor; consolidate derive to types utils --- vyper/ast/nodes.py | 18 ++++++++--- vyper/builtins/_signatures.py | 17 ++-------- vyper/builtins/functions.py | 27 ++-------------- vyper/semantics/analysis/local.py | 14 ++++---- vyper/semantics/analysis/module.py | 6 ++-- vyper/semantics/analysis/utils.py | 44 ++------------------------ vyper/semantics/types/subscriptable.py | 4 +-- vyper/semantics/types/utils.py | 31 +++++++++++++++++- 8 files changed, 65 insertions(+), 96 deletions(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 8a0b211d4c..0a534a11eb 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -904,6 +904,12 @@ class Tuple(ExprNode): __slots__ = ("elements",) _translated_fields = {"elts": "elements"} + def derive(self, constants: dict): + val = [e.derive(constants) for e in self.elements] + if None in val: + return None + return val + def validate(self): if not self.elements: raise InvalidLiteral("Cannot have an empty tuple", self) @@ -912,6 +918,12 @@ def validate(self): class Dict(ExprNode): __slots__ = ("keys", "values") + def derive(self, constants: dict): + values = [v.derive(constants) for v in self.args[0].values] + if any(v is None for v in values): + return None + return {k: v for (k, v) in zip(self.args[0].keys, values)} + class NameConstant(Constant): __slots__ = ("value",) @@ -1305,11 +1317,9 @@ class Call(ExprNode): __slots__ = ("func", "args", "keywords", "keyword") def derive(self, constants: dict): + # only return constant struct values if len(self.args) == 1 and isinstance(self.args[0], Dict): - values = [v.derive(constants) for v in self.args[0].values] - if any(v is None for v in values): - return None - return {k: v for (k, v) in zip(self.args[0].keys, values)} + return self.args[0].derive(constants) return None diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index a6a393c403..b513440d1c 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -1,15 +1,13 @@ import functools from typing import Dict -from vyper.ast import nodes as vy_ast from vyper.ast.validation import validate_call_args from vyper.codegen.expr import Expr from vyper.codegen.ir_node import IRnode -from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode, VyperException +from vyper.exceptions import CompilerPanic, TypeMismatch from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type -from vyper.semantics.namespace import get_namespace from vyper.semantics.types import TYPE_T, KwargSettings, VyperType -from vyper.semantics.types.utils import type_from_annotation +from vyper.semantics.types.utils import derive_folded_value, type_from_annotation def process_arg(arg, expected_arg_type, context): @@ -103,18 +101,9 @@ def _validate_arg_types(self, node): for arg, (_, expected) in zip(node.args, self._inputs): self._validate_single(arg, expected) - ns = get_namespace() for kwarg in node.keywords: kwarg_settings = self._kwargs[kwarg.arg] - is_literal_value = kwarg.value.derive(ns._constants) is not None - if isinstance(kwarg.value, vy_ast.Call): - call_type = get_exact_type_from_node(kwarg.value.func) - if hasattr(call_type, "evaluate"): - try: - call_type.evaluate(kwarg.value) - is_literal_value = True - except (UnfoldableNode, VyperException): - pass + is_literal_value = derive_folded_value(kwarg.value) is not None if kwarg_settings.require_literal and not is_literal_value: raise TypeMismatch("Value for kwarg must be a literal", kwarg.value) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 879d1ea04f..362edc6cfd 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -50,7 +50,6 @@ StructureException, TypeMismatch, UnfoldableNode, - VyperException, ZeroDivisionException, ) from vyper.semantics.analysis.base import VarInfo @@ -60,7 +59,6 @@ get_possible_types_from_node, validate_expected_type, ) -from vyper.semantics.namespace import get_namespace from vyper.semantics.types import ( TYPE_T, AddressT, @@ -85,7 +83,7 @@ UINT8_T, UINT256_T, ) -from vyper.semantics.types.utils import type_from_annotation +from vyper.semantics.types.utils import derive_folded_value, type_from_annotation from vyper.utils import ( DECIMAL_DIVISOR, EIP_170_LIMIT, @@ -1062,25 +1060,6 @@ def build_IR(self, expr, args, kwargs, context): empty_value = IRnode.from_list(0, typ=BYTES32_T) -def derive_kwarg_value(kwarg, call_type): - if kwarg is None: - return None - - ns = get_namespace() - kwarg_val = kwarg.derive(ns._constants) - if kwarg_val is not None: - return kwarg_val - - if isinstance(kwarg, vy_ast.Call): - try: - evaluated = call_type.evaluate(kwarg) - return evaluated.value - except (UnfoldableNode, VyperException): - pass - - return None - - class RawCall(BuiltinFunction): _id = "raw_call" _inputs = [("to", AddressT()), ("data", BytesT.any())] @@ -1099,8 +1078,8 @@ def fetch_call_return(self, node): kwargz = {i.arg: i.value for i in node.keywords} - outsize = derive_kwarg_value(kwargz.get("max_outsize"), self) - revert_on_failure = derive_kwarg_value(kwargz.get("revert_on_failure"), self) + outsize = derive_folded_value(kwargz.get("max_outsize")) + revert_on_failure = derive_folded_value(kwargz.get("revert_on_failure")) revert_on_failure = revert_on_failure if revert_on_failure is not None else True if outsize is None or outsize == 0: diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 9ad5333a91..13261a5aba 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -51,7 +51,7 @@ is_type_t, ) from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability -from vyper.semantics.types.utils import type_from_annotation +from vyper.semantics.types.utils import derive_folded_value, type_from_annotation def validate_functions(vy_module: vy_ast.Module) -> None: @@ -358,7 +358,7 @@ def visit_For(self, node): validate_expected_type(n, IntegerT.any()) if bound is None: - n_val = n.derive(self.namespace._constants) + n_val = derive_folded_value(n) if n_val is None: raise StateAccessViolation("Value must be a literal", n) if n_val <= 0: @@ -366,7 +366,7 @@ def visit_For(self, node): type_list = get_possible_types_from_node(n) else: - bound_val = bound.derive(self.namespace._constants) + bound_val = derive_folded_value(bound) if bound_val is None: raise StateAccessViolation("bound must be a literal", bound) if bound_val <= 0: @@ -383,7 +383,7 @@ def visit_For(self, node): validate_expected_type(args[0], IntegerT.any()) type_list = get_common_types(*args) - arg0_val = args[0].derive(self.namespace._constants) + arg0_val = derive_folded_value(args[0]) if arg0_val is None: # range(x, x + CONSTANT) if not isinstance(args[1], vy_ast.BinOp) or not isinstance( @@ -397,7 +397,7 @@ def visit_For(self, node): "First and second variable must be the same", args[1].left ) - right_val = args[1].right.derive(self.namespace._constants) + right_val = derive_folded_value(args[1].right) if not isinstance(args[1].right, vy_ast.Int) and not ( isinstance(args[1].right, vy_ast.Name) and right_val ): @@ -410,7 +410,7 @@ def visit_For(self, node): ) else: # range(CONSTANT, CONSTANT) - arg1_val = args[1].derive(self.namespace._constants) + arg1_val = derive_folded_value(args[1]) if not arg1_val: raise InvalidType("Value must be a literal integer", args[1]) validate_expected_type(args[1], IntegerT.any()) @@ -422,7 +422,7 @@ def visit_For(self, node): else: # iteration over a variable or literal list - iter_ = node.iter.derive(self.namespace._constants) + iter_ = derive_folded_value(node.iter) if isinstance(iter_, list) and len(iter_) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index bbbc35e6a6..1f785701b0 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -28,7 +28,7 @@ from vyper.semantics.namespace import Namespace, get_namespace from vyper.semantics.types import EnumT, EventT, InterfaceT, StructT from vyper.semantics.types.function import ContractFunctionT -from vyper.semantics.types.utils import type_from_annotation +from vyper.semantics.types.utils import derive_folded_value, type_from_annotation from vyper.typing import InterfaceDict @@ -80,7 +80,7 @@ def __init__( if c.value is None: continue - val = c.value.derive(self.namespace._constants) + val = derive_folded_value(c.value) self.namespace.add_constant(name, val) if val is not None: @@ -269,7 +269,7 @@ def _validate_self_namespace(): if not node.value: raise VariableDeclarationException("Constant must be declared with a value", node) # TODO: move to check_constant - if not node.value.derive(self.namespace._constants) and not check_constant(node.value): + if not check_constant(node.value): raise StateAccessViolation("Value must be a literal", node.value) validate_expected_type(node.value, type_) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 4b76f357e3..143a2f965e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -12,7 +12,6 @@ StructureException, TypeMismatch, UndeclaredDefinition, - UnfoldableNode, UnknownAttribute, VyperException, ZeroDivisionException, @@ -25,6 +24,7 @@ from vyper.semantics.types.bytestrings import BytesT, StringT from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT +from vyper.semantics.types.utils import derive_folded_value from vyper.utils import checksum_encode, int_to_fourbytes @@ -624,65 +624,27 @@ def check_kwargable(node: vy_ast.VyperNode) -> bool: """ Check if the given node can be used as a default arg """ - if _check_literal(node): + if derive_folded_value(node) is not None: return True - if isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(check_kwargable(item) for item in node.elements) if isinstance(node, vy_ast.Call): - args = node.args - if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return all(check_kwargable(v) for v in args[0].values) - call_type = get_exact_type_from_node(node.func) if getattr(call_type, "_kwargable", False): return True - if getattr(call_type, "evaluate", False): - try: - call_type.evaluate(node) - return True - except (UnfoldableNode, VyperException): - return False - value_type = get_expr_info(node) # is_constant here actually means not_assignable, and is to be renamed return value_type.is_constant -def _check_literal(node: vy_ast.VyperNode) -> bool: - """ - Check if the given node is a literal value. - """ - ns = get_namespace() - val = node.derive(ns._constants) - if val is not None: - return True - - return False - - def check_constant(node: vy_ast.VyperNode) -> bool: """ Check if the given node is a literal or constant value. """ - if _check_literal(node): + if derive_folded_value(node) is not None: return True - if isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(check_constant(item) for item in node.elements) if isinstance(node, vy_ast.Call): - args = node.args - if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return all(check_constant(v) for v in args[0].values) - call_type = get_exact_type_from_node(node.func) if getattr(call_type, "_kwargable", False): return True - if getattr(call_type, "evaluate", False): - try: - call_type.evaluate(node) - return True - except (UnfoldableNode, VyperException): - return False - return False diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index bc1bdbfcf0..bf9a88132c 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -8,7 +8,7 @@ from vyper.semantics.types.base import VyperType from vyper.semantics.types.primitives import IntegerT from vyper.semantics.types.shortcuts import UINT256_T -from vyper.semantics.types.utils import get_index_value, type_from_annotation +from vyper.semantics.types.utils import derive_folded_value, get_index_value, type_from_annotation class _SubscriptableT(VyperType): @@ -287,7 +287,7 @@ def from_annotation(cls, node: vy_ast.Subscript, constants: dict) -> "DArrayT": node, ) - max_length = node.slice.value.elements[1].derive(constants) + max_length = derive_folded_value(node.slice.value.elements[1]) if not max_length or not isinstance(max_length, int): raise StructureException( "DynArray must have a max length of integer type, e.g. DynArray[bool, 5]", node diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index e36dea7dcd..d82b3bd0da 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -6,7 +6,9 @@ InstantiationException, InvalidType, StructureException, + UnfoldableNode, UnknownType, + VyperException, ) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.data_locations import DataLocation @@ -132,6 +134,33 @@ def _failwith(type_name): return typ_ +def derive_literal_value(node: vy_ast.VyperNode): + ns = get_namespace() + val = node.derive(ns._constants) + return val + + +def derive_folded_value(node: vy_ast.VyperNode): + if node is None: + return None + + val = derive_literal_value(node) + if val is not None: + return val + + if isinstance(node, vy_ast.Call): + from vyper.semantics.analysis.utils import get_exact_type_from_node + + call_type = get_exact_type_from_node(node.func) + try: + evaluated = call_type.evaluate(node) + return evaluated.value + except (UnfoldableNode, VyperException): + pass + + return None + + def get_index_value(node: vy_ast.Index, constants: dict) -> int: """ Return the literal value for a `Subscript` index. @@ -151,7 +180,7 @@ def get_index_value(node: vy_ast.Index, constants: dict) -> int: # TODO: revisit this! from vyper.semantics.analysis.utils import get_possible_types_from_node - val = node.value.derive(constants) + val = derive_folded_value(node.value) if not isinstance(val, int): if hasattr(node, "value"):