From 27edf0b55334c896e6434bdacb13372cdc3b9a43 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Wed, 1 Nov 2023 16:00:16 +0800 Subject: [PATCH] fixes wip --- vyper/ast/folding.py | 14 +- vyper/builtins/functions.py | 216 +++++++++++----------- vyper/semantics/analysis/local.py | 41 ++-- vyper/semantics/analysis/pre_typecheck.py | 102 ++++++---- vyper/semantics/types/subscriptable.py | 14 +- vyper/semantics/types/utils.py | 5 +- 6 files changed, 210 insertions(+), 182 deletions(-) diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index e02af55ad8..5fed3d87c8 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Union from vyper.ast import nodes as vy_ast from vyper.exceptions import UnfoldableNode @@ -47,7 +47,6 @@ def replace_literal_ops(vyper_module: vy_ast.Module) -> int: except UnfoldableNode: continue - new_node._metadata["folded_value"] = new_node.value typ = node._metadata.get("type") # type metadata may not be present @@ -88,7 +87,6 @@ def replace_subscripts(vyper_module: vy_ast.Module) -> int: except UnfoldableNode: continue - new_node._metadata["folded_value"] = node._metadata["folded_value"] new_node._metadata["type"] = node._metadata["type"] changed_nodes += 1 @@ -127,7 +125,6 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int: except UnfoldableNode: continue - new_node._metadata["folded_value"] = new_node.value new_node._metadata["type"] = node._metadata["type"] changed_nodes += 1 @@ -159,10 +156,7 @@ def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int: continue type_ = node._metadata["type"] - folded_value = node.value._metadata["folded_value"] - changed_nodes += replace_constant( - vyper_module, node.target.id, node.value, type_, folded_value, False - ) + changed_nodes += replace_constant(vyper_module, node.target.id, node.value, type_, False) return changed_nodes @@ -199,7 +193,6 @@ def replace_constant( id_: str, replacement_node: Union[vy_ast.Constant, vy_ast.List, vy_ast.Call], type_: VyperType, - folded_value: Any, raise_on_error: bool, ) -> int: """ @@ -216,8 +209,6 @@ def replace_constant( `Call` nodes are for struct constants. type_ : VyperType Type definition to be propagated to type checker. - folded_value: Any - Folded value of the constant raise_on_error: bool Boolean indicating if `UnfoldableNode` exception should be raised or ignored. @@ -256,7 +247,6 @@ def replace_constant( try: # note: _replace creates a copy of the replacement_node new_node = _replace(node, replacement_node, type_) - new_node._metadata["folded_value"] = folded_value except UnfoldableNode: if raise_on_error: raise diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 1f49936f91..306129d1a0 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -139,10 +139,10 @@ class Floor(BuiltinFunction): # TODO: maybe use int136? _return_type = INT256_T - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) input_val = node.args[0]._metadata.get("folded_value") - if not isinstance(input_val, Decimal): + if not isinstance(input_val, vy_ast.Decimal): raise UnfoldableNode value = math.floor(input_val) @@ -170,10 +170,10 @@ class Ceil(BuiltinFunction): # TODO: maybe use int136? _return_type = INT256_T - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) input_val = node.args[0]._metadata.get("folded_value") - if not isinstance(input_val, Decimal): + if not isinstance(input_val, vy_ast.Decimal): raise UnfoldableNode value = math.ceil(input_val) @@ -465,11 +465,14 @@ class Len(BuiltinFunction): _inputs = [("b", (StringT.any(), BytesT.any(), DArrayT.any()))] _return_type = UINT256_T - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) arg = node.args[0]._metadata.get("folded_value") - if isinstance(arg, (str, bytes)): + if isinstance(arg, (vy_ast.Str, vy_ast.Bytes)): length = len(arg) + elif isinstance(arg, vy_ast.Hex): + # 2 characters represent 1 byte and we subtract 1 to ignore the leading `0x` + length = len(arg.value) // 2 - 1 else: raise UnfoldableNode @@ -599,22 +602,19 @@ class Keccak256(BuiltinFunction): _inputs = [("value", (BytesT.any(), BYTES32_T, StringT.any()))] _return_type = BYTES32_T - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) - value = node.args[0]._metadata.get("folded_value") - if not isinstance(value, (bytes, str)): + arg = node.args[0]._metadata.get("folded_value") + if isinstance(arg, vy_ast.Bytes): + value = arg.value + elif isinstance(arg, vy_ast.Str): + value = arg.value.encode() + elif isinstance(arg, vy_ast.Hex): + length = len(arg.value) // 2 - 1 + value = int(arg.value, 16).to_bytes(length, "big") + else: raise UnfoldableNode - if isinstance(value, str): - # we need the argument type to differentiate between - # strings and hex values - arg_typ = self.infer_arg_types(node).pop() - if isinstance(arg_typ, StringT): - value = value.encode() - elif arg_typ == BYTES32_T: - length = len(value) // 2 - 1 - value = int(value, 16).to_bytes(length, "big") - hash_ = f"0x{keccak256(value).hex()}" return vy_ast.Hex.from_node(node, value=hash_) @@ -650,22 +650,19 @@ class Sha256(BuiltinFunction): _inputs = [("value", (BYTES32_T, BytesT.any(), StringT.any()))] _return_type = BYTES32_T - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) - value = node.args[0]._metadata.get("folded_value") - if not isinstance(value, (bytes, str)): + arg = node.args[0]._metadata.get("folded_value") + if isinstance(arg, vy_ast.Bytes): + value = arg.value + elif isinstance(arg, vy_ast.Str): + value = arg.value.encode() + elif isinstance(arg, vy_ast.Hex): + length = len(arg.value) // 2 - 1 + value = int(arg.value, 16).to_bytes(length, "big") + else: raise UnfoldableNode - if isinstance(value, str): - # we need the argument type to differentiate between - # strings and hex values - arg_typ = self.infer_arg_types(node).pop() - if isinstance(arg_typ, StringT): - value = value.encode() - elif arg_typ == BYTES32_T: - length = len(value) // 2 - 1 - value = int(value, 16).to_bytes(length, "big") - hash_ = f"0x{hashlib.sha256(value).hexdigest()}" return vy_ast.Hex.from_node(node, value=hash_) @@ -724,7 +721,7 @@ def build_IR(self, expr, args, kwargs, context): class MethodID(FoldedFunction): _id = "method_id" - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1, ["output_type"]) args = node.args @@ -983,26 +980,27 @@ class AsWeiValue(BuiltinFunction): } def get_denomination(self, node): - value = node.args[1]._metadata.get("folded_value") - if not isinstance(value, str): + arg = node.args[1]._metadata.get("folded_value") + if not isinstance(arg, vy_ast.Str): raise ArgumentException( "Wei denomination must be given as a literal string", node.args[1] ) try: - denom = next(v for k, v in self.wei_denoms.items() if value in k) + denom = next(v for k, v in self.wei_denoms.items() if arg.value in k) except StopIteration: - raise ArgumentException(f"Unknown denomination: {value}", node.args[1]) from None + raise ArgumentException(f"Unknown denomination: {arg.value}", node.args[1]) from None return denom - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 2) denom = self.get_denomination(node) - value = node.args[0]._metadata.get("folded_value") - if not isinstance(value, (Decimal, int)): + arg = node.args[0]._metadata.get("folded_value") + if not isinstance(arg, (vy_ast.Decimal, vy_ast.Int)): raise UnfoldableNode + value = arg.value if value < 0: raise InvalidLiteral("Negative wei value not allowed", node.args[0]) @@ -1096,12 +1094,16 @@ def fetch_call_return(self, node): return None return BoolT() - return_type = BytesT() - return_type.set_min_length(outsize) + if not isinstance(outsize, vy_ast.Int) or outsize.value < 0: + raise + + if outsize.value: + return_type = BytesT() + return_type.set_min_length(outsize.value) - if revert_on_failure: - return return_type - return TupleT([BoolT(), return_type]) + if revert_on_failure: + return return_type + return TupleT([BoolT(), return_type]) def infer_arg_types(self, node, expected_return_type=None): self._validate_arg_types(node) @@ -1350,20 +1352,20 @@ class BitwiseAnd(BuiltinFunction): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): if not self.__class__._warned: vyper_warn("`bitwise_and()` is deprecated! Please use the & operator instead.") self.__class__._warned = True validate_call_args(node, 2) - values = [i._metadata.get("folded_value") for i in node.args] - for v, arg in zip(values, node.args): - if not isinstance(v, int): + args = [i._metadata.get("folded_value") for i in node.args] + for v, arg in zip(args, node.args): + if not isinstance(v, vy_ast.Int): raise UnfoldableNode - if v < 0 or v >= 2**256: + if v.value < 0 or v.value >= 2**256: raise InvalidLiteral("Value out of range for uint256", arg) - value = values[0] & values[1] + value = args[0].value & args[1].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1377,20 +1379,20 @@ class BitwiseOr(BuiltinFunction): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): if not self.__class__._warned: vyper_warn("`bitwise_or()` is deprecated! Please use the | operator instead.") self.__class__._warned = True validate_call_args(node, 2) - values = [i._metadata.get("folded_value") for i in node.args] - for v, arg in zip(values, node.args): - if not isinstance(arg, int): + args = [i._metadata.get("folded_value") for i in node.args] + for v, arg in zip(args, node.args): + if not isinstance(arg, vy_ast.Int): raise UnfoldableNode - if v < 0 or v >= 2**256: + if v.value < 0 or v.value >= 2**256: raise InvalidLiteral("Value out of range for uint256", arg) - value = values[0] | values[1] + value = args[0].value | args[1].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1404,20 +1406,20 @@ class BitwiseXor(BuiltinFunction): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): if not self.__class__._warned: vyper_warn("`bitwise_xor()` is deprecated! Please use the ^ operator instead.") self.__class__._warned = True validate_call_args(node, 2) - values = [i._metadata.get("folded_value") for i in node.args] - for v, arg in zip(values, node.args): - if not isinstance(arg, int): + args = [i._metadata.get("folded_value") for i in node.args] + for v, arg in zip(args, node.args): + if not isinstance(arg, vy_ast.Int): raise UnfoldableNode - if v < 0 or v >= 2**256: + if v.value < 0 or v.value >= 2**256: raise InvalidLiteral("Value out of range for uint256", arg) - value = values[0] ^ values[1] + value = args[0].value ^ args[1].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1431,20 +1433,20 @@ class BitwiseNot(BuiltinFunction): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): if not self.__class__._warned: vyper_warn("`bitwise_not()` is deprecated! Please use the ~ operator instead.") self.__class__._warned = True validate_call_args(node, 1) - value = node.args[0]._metadata.get("folded_value") - if not isinstance(value, int): + arg = node.args[0]._metadata.get("folded_value") + if not isinstance(arg, vy_ast.Int): raise UnfoldableNode - if value < 0 or value >= 2**256: + if arg.value < 0 or arg.value >= 2**256: raise InvalidLiteral("Value out of range for uint256", node.args[0]) - value = (2**256 - 1) - value + value = (2**256 - 1) - arg.value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1458,7 +1460,7 @@ class Shift(BuiltinFunction): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): if not self.__class__._warned: vyper_warn("`shift()` is deprecated! Please use the << or >> operator instead.") self.__class__._warned = True @@ -1510,18 +1512,18 @@ class _AddMulMod(BuiltinFunction): _inputs = [("a", UINT256_T), ("b", UINT256_T), ("c", UINT256_T)] _return_type = UINT256_T - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 3) - values = [i._metadata.get("folded_value") for i in node.args] - if isinstance(values[2], int) and values[2] == 0: + args = [i._metadata.get("folded_value") for i in node.args] + if isinstance(args[2], vy_ast.Int) and args[2] == 0: raise ZeroDivisionException("Modulo by 0", node.args[2]) - for v, arg in zip(values, node.args): + for v, arg in zip(args, node.args): if not isinstance(v, int): raise UnfoldableNode - if v < 0 or v >= 2**256: + if v.value < 0 or v.value >= 2**256: raise InvalidLiteral("Value out of range for uint256", arg) - value = self._eval_fn(values[0], values[1]) % values[2] + value = self._eval_fn(args[0].value, args[1].value) % args[2].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1553,17 +1555,17 @@ class PowMod256(BuiltinFunction): _inputs = [("a", UINT256_T), ("b", UINT256_T)] _return_type = UINT256_T - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 2) - values = [i._metadata.get("folded_value") for i in node.args] - if any(not isinstance(i, int) for i in values): + args = [i._metadata.get("folded_value") for i in node.args] + if any(not isinstance(i, vy_ast.Int) for i in args): raise UnfoldableNode - left, right = values - if left < 0 or right < 0: + left, right = args + if left.value < 0 or right.value < 0: raise UnfoldableNode - value = pow(left, right, 2**256) + value = pow(left.value, right.value, 2**256) return vy_ast.Int.from_node(node, value=value) def build_IR(self, expr, context): @@ -1577,15 +1579,15 @@ class Abs(BuiltinFunction): _inputs = [("value", INT256_T)] _return_type = INT256_T - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) - value = node.args[0]._metadata.get("folded_value") - if not isinstance(value, int): + arg = node.args[0]._metadata.get("folded_value") + if not isinstance(arg, vy_ast.Int): raise UnfoldableNode - if not SizeLimits.MIN_INT256 <= value <= SizeLimits.MAX_INT256: + if not SizeLimits.MIN_INT256 <= arg.value <= SizeLimits.MAX_INT256: raise OverflowException("Literal is outside of allowable range for int256") - value = abs(value) + value = abs(arg.value) if not SizeLimits.MIN_INT256 <= value <= SizeLimits.MAX_INT256: raise OverflowException("Absolute literal value is outside allowable range for int256") @@ -2021,28 +2023,30 @@ class UnsafeDiv(_UnsafeMath): class _MinMax(BuiltinFunction): _inputs = [("a", (DecimalT(), IntegerT.any())), ("b", (DecimalT(), IntegerT.any()))] - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 2) - values = [i._metadata.get("folded_value") for i in node.args] - if not isinstance(values[0], type(values[1])): + args = [i._metadata.get("folded_value") for i in node.args] + if not isinstance(args[0], type(args[1])): raise UnfoldableNode - if not isinstance(values[0], (Decimal, int)): + if not isinstance(args[0], (vy_ast.Decimal, vy_ast.Int)): raise UnfoldableNode - left, right = values - if isinstance(left, Decimal) and ( - min(left, right) < SizeLimits.MIN_AST_DECIMAL - or max(left, right) > SizeLimits.MAX_AST_DECIMAL + left, right = args + if isinstance(left.value, Decimal) and ( + min(left.value, right.value) < SizeLimits.MIN_AST_DECIMAL + or max(left.value, right.value) > SizeLimits.MAX_AST_DECIMAL ): raise InvalidType("Decimal value is outside of allowable range", node) - types_list = get_common_types( - *node.args, filter_fn=lambda x: isinstance(x, (IntegerT, DecimalT)) - ) - if not types_list: - raise TypeMismatch("Cannot perform action between dislike numeric types", node) + # skip during pre-typecheck + if not skip_typecheck: + types_list = get_common_types( + *args, filter_fn=lambda x: isinstance(x, (IntegerT, DecimalT)) + ) + if not types_list: + raise TypeMismatch("Cannot perform action between dislike numeric types", node) - value = self._eval_fn(left, right) + value = self._eval_fn(left.value, right.value) if isinstance(left, Decimal): node = vy_ast.Decimal.from_node(node, value=value) @@ -2113,13 +2117,13 @@ def fetch_call_return(self, node): len_needed = math.ceil(bits * math.log(2) / math.log(10)) return StringT(len_needed) - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) - value = node.args[0]._metadata.get("folded_value") - if not isinstance(value, int): + arg = node.args[0]._metadata.get("folded_value") + if not isinstance(arg, vy_ast.Int): raise UnfoldableNode - value = str(value) + value = str(arg.value) return vy_ast.Str.from_node(node, value=value) def infer_arg_types(self, node, expected_return_type=None): @@ -2599,7 +2603,7 @@ def build_IR(self, expr, args, kwargs, context): class _MinMaxValue(TypenameFoldedFunction): - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) @@ -2634,7 +2638,7 @@ def _eval(self, type_): class Epsilon(TypenameFoldedFunction): _id = "epsilon" - def evaluate(self, node): + def evaluate(self, node, skip_typecheck=False): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 95cfd3f505..553c2b18df 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -353,23 +353,24 @@ def visit_For(self, node): if len(args) == 1: # range(CONSTANT) n = args[0] + folded_n = n._metadata.get("folded_value") + bound = kwargs.pop("bound", None) validate_expected_type(n, IntegerT.any()) if bound is None: - n_val = n._metadata.get("folded_value") - if not isinstance(n_val, int): + if not isinstance(folded_n, vy_ast.Num): raise StateAccessViolation("Value must be a literal integer", n) - if n_val <= 0: - raise StructureException("For loop must have at least 1 iteration", args[0]) + if folded_n.value <= 0: + raise StructureException("For loop must have at least 1 iteration", n) type_list = get_possible_types_from_node(n) else: - bound_val = bound._metadata.get("folded_value") - if bound_val is None: + folded_bound = bound._metadata.get("folded_value") + if folded_bound is None: raise StateAccessViolation("bound must be a literal", bound) - if bound_val <= 0: - raise StructureException("bound must be at least 1", args[0]) + if folded_bound.value <= 0: + raise StructureException("bound must be at least 1", bound) type_list = get_common_types(n, bound) else: @@ -382,8 +383,8 @@ def visit_For(self, node): validate_expected_type(args[0], IntegerT.any()) type_list = get_common_types(*args) - arg0_val = args[0]._metadata.get("folded_value") - if not isinstance(arg0_val, int): + folded_arg0 = args[0]._metadata.get("folded_value") + if not isinstance(folded_arg0, vy_ast.Constant): # range(x, x + CONSTANT) if not isinstance(args[1], vy_ast.BinOp) or not isinstance( args[1].op, vy_ast.Add @@ -396,22 +397,22 @@ def visit_For(self, node): "First and second variable must be the same", args[1].left ) - right_val = args[1].right._metadata.get("folded_value") - if not isinstance(right_val, int): + folded_right = args[1].right._metadata.get("folded_value") + if not isinstance(folded_right, vy_ast.Int): raise InvalidLiteral("Literal must be an integer", args[1].right) - if right_val < 1: + if folded_right.value < 1: raise StructureException( - f"For loop has invalid number of iterations ({right_val})," + f"For loop has invalid number of iterations ({folded_right.value})," " the value must be greater than zero", args[1].right, ) else: # range(CONSTANT, CONSTANT) - arg1_val = args[1]._metadata.get("folded_value") - if not isinstance(arg1_val, int): + folded_arg1 = args[1]._metadata.get("folded_value") + if not isinstance(folded_arg1, vy_ast.Int): raise InvalidType("Value must be a literal integer", args[1]) - validate_expected_type(args[1], IntegerT.any()) - if arg0_val >= arg1_val: + validate_expected_type(folded_arg1, IntegerT.any()) + if folded_arg0.value >= folded_arg1.value: raise StructureException("Second value must be > first value", args[1]) if not type_list: @@ -419,8 +420,8 @@ def visit_For(self, node): else: # iteration over a variable or literal list - iter_ = node.iter._metadata.get("folded_value") - if isinstance(iter_, list) and len(iter_) == 0: + folded_iter = node.iter._metadata.get("folded_value") + if isinstance(folded_iter, vy_ast.List) and len(folded_iter.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) type_list = [ diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py index c2e8d64d6b..e2747d3920 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -1,4 +1,4 @@ -from decimal import Decimal +from typing import Optional from vyper import ast as vy_ast from vyper.exceptions import UnfoldableNode, VyperException @@ -41,7 +41,7 @@ def __init__(self, node: vy_ast.VyperNode) -> None: self.visit(c.value) - val = c.value._metadata.get("folded_value") + val = get_folded_value(c.value) # note that if a constant is redefined, its value will be overwritten, # but it is okay because the syntax error is handled downstream @@ -62,7 +62,8 @@ def visit(self, node): def visit_EventDef(self, node): for n in node.body: - self.visit(n.annotation) + if isinstance(n, vy_ast.AnnAssign): + self.visit(n.annotation) def visit_FunctionDef(self, node): # visit type annotations of arguments @@ -89,7 +90,8 @@ def visit_Module(self, node): def visit_StructDef(self, node): for n in node.body: - self.visit(n.annotation) + if isinstance(node, vy_ast.AnnAssign): + self.visit(n.annotation) def visit_VariableDecl(self, node): self.visit(node.annotation) @@ -147,28 +149,37 @@ def visit_Return(self, node): # Expr + def visit_keyword(self, node): + self.visit(node.arg) + self.visit(node.value) + def visit_Attribute(self, node): self.visit(node.value) - value_node_val = node.value._metadata.get("folded_value") - if isinstance(value_node_val, dict): - node._metadata["folded_value"] = value_node_val[node.attr] + value_node = get_folded_value(node.value) + if isinstance(value_node, vy_ast.Dict): + for k, v in zip(node.keys, node.values): + if k.id == node.attr: + node._metadata["folded_value"] = v + return def visit_BinOp(self, node): self.visit(node.left) self.visit(node.right) - left = node.left._metadata.get("folded_value") - right = node.right._metadata.get("folded_value") - if isinstance(left, type(right)) and isinstance(left, (int, Decimal)): - node._metadata["folded_value"] = node.op._op(left, right) + left = get_folded_value(node.left) + right = get_folded_value(node.right) + if isinstance(left, type(right)) and isinstance(left, (vy_ast.Int, vy_ast.Decimal)): + value = node.op._op(left.value, right.value) + node._metadata["folded_value"] = type(left).from_node(node, value=value) def visit_BoolOp(self, node): for i in node.values: self.visit(i) - values = [i._metadata.get("folded_value") for i in node.values] - if all(isinstance(v, bool) for v in values): - node._metadata["folded_value"] = node.op._op(values) + values = [get_folded_value(i) for i in node.values] + if all(isinstance(v, vy_ast.NameConstant) for v in values): + value = node.op._op([v.value for v in values]) + node._metadata["folded_value"] = vy_ast.NameConstant.from_node(node, value=value) def visit_Call(self, node): for arg in node.args: @@ -179,7 +190,7 @@ def visit_Call(self, node): # constant structs if len(node.args) == 1 and isinstance(node.args[0], vy_ast.Dict): self.visit(node.args[0]) - node._metadata["folded_value"] = node.args[0]._metadata.get("folded_value") + node._metadata["folded_value"] = get_folded_value(node.args[0]) from vyper.builtins.functions import DISPATCH_TABLE @@ -190,7 +201,9 @@ def visit_Call(self, node): call_type = DISPATCH_TABLE.get(func_name) if call_type and hasattr(call_type, "evaluate"): try: - node._metadata["folded_value"] = call_type.evaluate(node).value # type: ignore + node._metadata["folded_value"] = call_type.evaluate( + node, skip_typecheck=True + ) # type: ignore return except (UnfoldableNode, VyperException): pass @@ -199,44 +212,49 @@ def visit_Compare(self, node): self.visit(node.left) self.visit(node.right) - left = node.left._metadata.get("folded_value") + left = get_folded_value(node.left) if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): if not isinstance(node.right, (vy_ast.List, vy_ast.Tuple)): return - right = [i._metadata.get("folded_value") for i in node.right.elements] + right = [get_folded_value(i) for i in node.right.elements] if left is None or len(set([type(i) for i in right])) > 1: return - node._metadata["folded_value"] = node.op._op(left, right) + value = node.op._op(left.value, [i.value for i in right]) + node._metadata["folded_value"] = vy_ast.NameConstant.from_node(value=value) + return - right = node.right._metadata.get("folded_value") - if isinstance(left, type(right)) and isinstance(left, (int, Decimal)): - node._metadata["folded_value"] = node.op._op(left, right) + right = get_folded_value(node) + if isinstance(left, type(right)) and isinstance(left, (vy_ast.Int, vy_ast.Decimal)): + value = node.op._op(left.value, right.value) + node._metadata["folded_value"] = vy_ast.NameConstant.from_node(value=value) def visit_Constant(self, node): - node._metadata["folded_value"] = node.value + node._metadata["folded_value"] = node def visit_Dict(self, node): for v in node.values: self.visit(v) - values = [v._metadata.get("folded_value") for v in node.values] + values = [get_folded_value(v) for v in node.values] if not any(v is None for v in values): - node._metadata["folded_value"] = {k.id: v for (k, v) in zip(node.keys, values)} + node._metadata["folded_value"] = vy_ast.Dict.from_node( + node, keys=node.keys, values=values + ) def visit_Index(self, node): self.visit(node.value) - index_val = node.value._metadata.get("folded_value") - if index_val is not None: - node._metadata["folded_value"] = index_val + index = get_folded_value(node.value) + if isinstance(index, vy_ast.Constant): + node._metadata["folded_value"] = index # repeated code for List and Tuple def _subscriptable_helper(self, node): for e in node.elements: self.visit(e) - values = [e._metadata.get("folded_value") for e in node.elements] + values = [get_folded_value(e) for e in node.elements] if None not in values: node._metadata["folded_value"] = values @@ -251,21 +269,33 @@ def visit_Subscript(self, node): self.visit(node.slice) self.visit(node.value) - slice_val = node.slice._metadata.get("folded_value") - sliced_val = node.value._metadata.get("folded_value") - if None not in (slice_val, sliced_val): - node._metadata["folded_value"] = sliced_val[slice_val] + sliced = get_folded_value(node.slice) + index = get_folded_value(node.value) + if None not in (sliced, index): + node._metadata["folded_value"] = sliced.elements[index.value] def visit_Tuple(self, node): self._subscriptable_helper(node) def visit_UnaryOp(self, node): self.visit(node.operand) - val = node.operand._metadata.get("folded_value") - if isinstance(val, int): - node._metadata["folded_value"] = node.op._op(val) + val = get_folded_value(node.operand) + if isinstance(val, (vy_ast.Int, vy_ast.Decimal)): + value = node.op._op(val.value) + node._metadata["folded_value"] = type(val).from_node(node, value=value) def visit_IfExp(self, node): self.visit(node.test) self.visit(node.body) self.visit(node.orelse) + + +def get_folded_value(node: vy_ast.VyperNode) -> Optional[vy_ast.VyperNode]: + if isinstance(node, vy_ast.Constant): + return node + elif isinstance(node, (vy_ast.List, vy_ast.Tuple)): + values = [get_folded_value(i) for i in node.elements] + if None not in values: + return values + + return node._metadata.get("folded_value") diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 6b802f541e..1752f971f1 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -128,11 +128,12 @@ def validate_index_type(self, node): # TODO break this cycle from vyper.semantics.analysis.utils import validate_expected_type - index_val = node._metadata.get("folded_value") - if isinstance(index_val, int): - if index_val < 0: + index = node._metadata.get("folded_value") + if isinstance(index, vy_ast.Int): + value = index.value + if value < 0: raise ArrayIndexException("Vyper does not support negative indexing", node) - if index_val >= self.length: + if value >= self.length: raise ArrayIndexException("Index out of range", node) validate_expected_type(node, IntegerT.any()) @@ -286,11 +287,12 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": node, ) - max_length = node.slice.value.elements[1]._metadata.get("folded_value") - if not isinstance(max_length, int): + folded_max_length = node.slice.value.elements[1]._metadata.get("folded_value") + if not isinstance(folded_max_length, vy_ast.Int): raise StructureException( "DynArray must have a max length of integer type, e.g. DynArray[bool, 5]", node ) + max_length = folded_max_length.value value_type = type_from_annotation(node.slice.value.elements[0]) if not value_type._as_darray: diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 38f160bf7c..b140451cfa 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -139,11 +139,12 @@ def get_index_value(node: vy_ast.Index) -> int: int Literal integer value. """ - val = node.value._metadata.get("folded_value") + folded_node = node.value._metadata.get("folded_value") - if not isinstance(val, int): + if not isinstance(folded_node, vy_ast.Int): raise InvalidType("Subscript must be a literal integer", node) + val = folded_node.value if val <= 0: raise ArrayIndexException("Subscript must be greater than 0", node)