From 83237ed6e41f494240b41e8e227984892df1a6ce Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 25 Sep 2023 15:52:22 +0800 Subject: [PATCH] move derive helper to types utils --- vyper/ast/nodes.py | 66 ---------------------------- vyper/semantics/analysis/local.py | 2 + vyper/semantics/types/utils.py | 71 +++++++++++++++++++++++-------- 3 files changed, 55 insertions(+), 84 deletions(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 0a534a11eb..2497928035 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -373,9 +373,6 @@ def description(self): """ return getattr(self, "_description", type(self).__name__) - def derive(self, constants: dict): - return None - def evaluate(self) -> "VyperNode": """ Attempt to evaluate the content of a node and generate a new node from it. @@ -754,9 +751,6 @@ class Constant(ExprNode): # inherited class for all simple constant node types __slots__ = ("value",) - def derive(self, constants: dict): - return self.value - class Num(Constant): # inherited class for all numeric constant node types @@ -893,23 +887,11 @@ class List(ExprNode): __slots__ = ("elements",) _translated_fields = {"elts": "elements"} - def derive(self, constants: dict): - val = [e.derive(constants) for e in self.elements] - if None in val: - return None - return val - class Tuple(ExprNode): __slots__ = ("elements",) _translated_fields = {"elts": "elements"} - def derive(self, constants: dict): - val = [e.derive(constants) for e in self.elements] - if None in val: - return None - return val - def validate(self): if not self.elements: raise InvalidLiteral("Cannot have an empty tuple", self) @@ -918,12 +900,6 @@ def validate(self): class Dict(ExprNode): __slots__ = ("keys", "values") - def derive(self, constants: dict): - values = [v.derive(constants) for v in self.args[0].values] - if any(v is None for v in values): - return None - return {k: v for (k, v) in zip(self.args[0].keys, values)} - class NameConstant(Constant): __slots__ = ("value",) @@ -932,19 +908,10 @@ class NameConstant(Constant): class Name(ExprNode): __slots__ = ("id",) - def derive(self, constants: dict): - return constants.get(self.id, None) - class UnaryOp(ExprNode): __slots__ = ("op", "operand") - def derive(self, constants: dict): - operand = self.operand.derive(constants) - if operand is None: - return None - return self.op._op(operand) - def evaluate(self) -> ExprNode: """ Attempt to evaluate the unary operation. @@ -993,13 +960,6 @@ def _op(self, value): class BinOp(ExprNode): __slots__ = ("left", "op", "right") - def derive(self, constants: dict): - left = self.left.derive(constants) - right = self.right.derive(constants) - if left is None or right is None: - return None - return self.op._op(left, right) - def evaluate(self) -> ExprNode: """ Attempt to evaluate the arithmetic operation. @@ -1150,12 +1110,6 @@ class RShift(Operator): class BoolOp(ExprNode): __slots__ = ("op", "values") - def derive(self, constants: dict): - values = [i.derive(constants) for i in self.values] - if any(v is None for v in values): - return None - return self.op._op(values) - def evaluate(self) -> ExprNode: """ Attempt to evaluate the boolean operation. @@ -1212,20 +1166,6 @@ def __init__(self, *args, **kwargs): kwargs["right"] = kwargs.pop("comparators")[0] super().__init__(*args, **kwargs) - def derive(self, constants: dict): - left = self.left.derive(constants) - - if isinstance(self.op, (In, NotIn)): - right = [i.derive(constants) for i in self.right.elements] - if left is None or any(v is None for v in right): - return None - return self.op._op(left, right) - - right = self.right.derive(constants) - if left is None or right is None: - return None - return self.op._op(left, right) - def evaluate(self) -> ExprNode: """ Attempt to evaluate the comparison. @@ -1316,12 +1256,6 @@ def _op(self, left, right): class Call(ExprNode): __slots__ = ("func", "args", "keywords", "keyword") - def derive(self, constants: dict): - # only return constant struct values - if len(self.args) == 1 and isinstance(self.args[0], Dict): - return self.args[0].derive(constants) - return None - class keyword(VyperNode): __slots__ = ("arg", "value") diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 13261a5aba..90562b6bd2 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -411,6 +411,8 @@ def visit_For(self, node): else: # range(CONSTANT, CONSTANT) arg1_val = derive_folded_value(args[1]) + print("arg0 val: ", arg0_val) + print("arg1 val: ", arg1_val) if not arg1_val: raise InvalidType("Value must be a literal integer", args[1]) validate_expected_type(args[1], IntegerT.any()) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index d82b3bd0da..61faa8495e 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -134,29 +134,64 @@ def _failwith(type_name): return typ_ -def derive_literal_value(node: vy_ast.VyperNode): - ns = get_namespace() - val = node.derive(ns._constants) - return val - - def derive_folded_value(node: vy_ast.VyperNode): - if node is None: - return None - - val = derive_literal_value(node) - if val is not None: - return val + if isinstance(node, vy_ast.BinOp): + left = derive_folded_value(node.left) + right = derive_folded_value(node.right) + if left is None or right is None: + return None + return node.op._op(left, right) + elif isinstance(node, vy_ast.BoolOp): + values = [derive_folded_value(i) for i in node.values] + if any(v is None for v in values): + return None + return node.op._op(values) + elif isinstance(node, vy_ast.Call): + if len(node.args) == 1 and isinstance(node.args[0], Dict): + return derive_folded_value(node.args[0]) - if isinstance(node, vy_ast.Call): from vyper.semantics.analysis.utils import get_exact_type_from_node call_type = get_exact_type_from_node(node.func) - try: - evaluated = call_type.evaluate(node) - return evaluated.value - except (UnfoldableNode, VyperException): - pass + if hasattr(call_type, "evaluate"): + try: + evaluated = call_type.evaluate(node) + return evaluated.value + except (UnfoldableNode, VyperException): + pass + elif isinstance(node, vy_ast.Compare): + left = derive_folded_value(node.left) + + if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): + right = [derive_folded_value(i) for i in node.right.elements] + if left is None or any(v is None for v in right): + return None + return node.op._op(left, right) + + right = derive_folded_value(node.right) + if left is None or right is None: + return None + return node.op._op(left, right) + elif isinstance(node, vy_ast.Constant): + return node.value + elif isinstance(node, vy_ast.Dict): + values = [derive_folded_value(v) for v in node.values] + if any(v is None for v in values): + return None + return {k: v for (k, v) in zip(node.keys, values)} + elif isinstance(node, (vy_ast.List, vy_ast.Tuple)): + val = [derive_folded_value(e) for e in node.elements] + if None in val: + return None + return val + elif isinstance(node, vy_ast.Name): + ns = get_namespace() + return ns._constants.get(node.id, None) + elif isinstance(node, vy_ast.UnaryOp): + operand = derive_folded_value(node.operand) + if operand is None: + return None + return node.op._op(operand) return None