Skip to content

Commit

Permalink
move derive helper to types utils
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Sep 25, 2023
1 parent a5de4bd commit 83237ed
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 84 deletions.
66 changes: 0 additions & 66 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,6 @@ def description(self):
"""
return getattr(self, "_description", type(self).__name__)

def derive(self, constants: dict):
return None

def evaluate(self) -> "VyperNode":
"""
Attempt to evaluate the content of a node and generate a new node from it.
Expand Down Expand Up @@ -754,9 +751,6 @@ class Constant(ExprNode):
# inherited class for all simple constant node types
__slots__ = ("value",)

def derive(self, constants: dict):
return self.value


class Num(Constant):
# inherited class for all numeric constant node types
Expand Down Expand Up @@ -893,23 +887,11 @@ class List(ExprNode):
__slots__ = ("elements",)
_translated_fields = {"elts": "elements"}

def derive(self, constants: dict):
val = [e.derive(constants) for e in self.elements]
if None in val:
return None
return val


class Tuple(ExprNode):
__slots__ = ("elements",)
_translated_fields = {"elts": "elements"}

def derive(self, constants: dict):
val = [e.derive(constants) for e in self.elements]
if None in val:
return None
return val

def validate(self):
if not self.elements:
raise InvalidLiteral("Cannot have an empty tuple", self)
Expand All @@ -918,12 +900,6 @@ def validate(self):
class Dict(ExprNode):
__slots__ = ("keys", "values")

def derive(self, constants: dict):
values = [v.derive(constants) for v in self.args[0].values]
if any(v is None for v in values):
return None
return {k: v for (k, v) in zip(self.args[0].keys, values)}


class NameConstant(Constant):
__slots__ = ("value",)
Expand All @@ -932,19 +908,10 @@ class NameConstant(Constant):
class Name(ExprNode):
__slots__ = ("id",)

def derive(self, constants: dict):
return constants.get(self.id, None)


class UnaryOp(ExprNode):
__slots__ = ("op", "operand")

def derive(self, constants: dict):
operand = self.operand.derive(constants)
if operand is None:
return None
return self.op._op(operand)

def evaluate(self) -> ExprNode:
"""
Attempt to evaluate the unary operation.
Expand Down Expand Up @@ -993,13 +960,6 @@ def _op(self, value):
class BinOp(ExprNode):
__slots__ = ("left", "op", "right")

def derive(self, constants: dict):
left = self.left.derive(constants)
right = self.right.derive(constants)
if left is None or right is None:
return None
return self.op._op(left, right)

def evaluate(self) -> ExprNode:
"""
Attempt to evaluate the arithmetic operation.
Expand Down Expand Up @@ -1150,12 +1110,6 @@ class RShift(Operator):
class BoolOp(ExprNode):
__slots__ = ("op", "values")

def derive(self, constants: dict):
values = [i.derive(constants) for i in self.values]
if any(v is None for v in values):
return None
return self.op._op(values)

def evaluate(self) -> ExprNode:
"""
Attempt to evaluate the boolean operation.
Expand Down Expand Up @@ -1212,20 +1166,6 @@ def __init__(self, *args, **kwargs):
kwargs["right"] = kwargs.pop("comparators")[0]
super().__init__(*args, **kwargs)

def derive(self, constants: dict):
left = self.left.derive(constants)

if isinstance(self.op, (In, NotIn)):
right = [i.derive(constants) for i in self.right.elements]
if left is None or any(v is None for v in right):
return None
return self.op._op(left, right)

right = self.right.derive(constants)
if left is None or right is None:
return None
return self.op._op(left, right)

def evaluate(self) -> ExprNode:
"""
Attempt to evaluate the comparison.
Expand Down Expand Up @@ -1316,12 +1256,6 @@ def _op(self, left, right):
class Call(ExprNode):
__slots__ = ("func", "args", "keywords", "keyword")

def derive(self, constants: dict):
# only return constant struct values
if len(self.args) == 1 and isinstance(self.args[0], Dict):
return self.args[0].derive(constants)
return None


class keyword(VyperNode):
__slots__ = ("arg", "value")
Expand Down
2 changes: 2 additions & 0 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,8 @@ def visit_For(self, node):
else:
# range(CONSTANT, CONSTANT)
arg1_val = derive_folded_value(args[1])
print("arg0 val: ", arg0_val)
print("arg1 val: ", arg1_val)
if not arg1_val:
raise InvalidType("Value must be a literal integer", args[1])
validate_expected_type(args[1], IntegerT.any())
Expand Down
71 changes: 53 additions & 18 deletions vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,29 +134,64 @@ def _failwith(type_name):
return typ_


def derive_literal_value(node: vy_ast.VyperNode):
ns = get_namespace()
val = node.derive(ns._constants)
return val


def derive_folded_value(node: vy_ast.VyperNode):
if node is None:
return None

val = derive_literal_value(node)
if val is not None:
return val
if isinstance(node, vy_ast.BinOp):
left = derive_folded_value(node.left)
right = derive_folded_value(node.right)
if left is None or right is None:
return None
return node.op._op(left, right)
elif isinstance(node, vy_ast.BoolOp):
values = [derive_folded_value(i) for i in node.values]
if any(v is None for v in values):
return None
return node.op._op(values)
elif isinstance(node, vy_ast.Call):
if len(node.args) == 1 and isinstance(node.args[0], Dict):
return derive_folded_value(node.args[0])

if isinstance(node, vy_ast.Call):
from vyper.semantics.analysis.utils import get_exact_type_from_node

call_type = get_exact_type_from_node(node.func)
try:
evaluated = call_type.evaluate(node)
return evaluated.value
except (UnfoldableNode, VyperException):
pass
if hasattr(call_type, "evaluate"):
try:
evaluated = call_type.evaluate(node)
return evaluated.value
except (UnfoldableNode, VyperException):
pass
elif isinstance(node, vy_ast.Compare):
left = derive_folded_value(node.left)

if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)):
right = [derive_folded_value(i) for i in node.right.elements]
if left is None or any(v is None for v in right):
return None
return node.op._op(left, right)

right = derive_folded_value(node.right)
if left is None or right is None:
return None
return node.op._op(left, right)
elif isinstance(node, vy_ast.Constant):
return node.value
elif isinstance(node, vy_ast.Dict):
values = [derive_folded_value(v) for v in node.values]
if any(v is None for v in values):
return None
return {k: v for (k, v) in zip(node.keys, values)}
elif isinstance(node, (vy_ast.List, vy_ast.Tuple)):
val = [derive_folded_value(e) for e in node.elements]
if None in val:
return None
return val
elif isinstance(node, vy_ast.Name):
ns = get_namespace()
return ns._constants.get(node.id, None)
elif isinstance(node, vy_ast.UnaryOp):
operand = derive_folded_value(node.operand)
if operand is None:
return None
return node.op._op(operand)

return None

Expand Down

0 comments on commit 83237ed

Please sign in to comment.