From 0dcb40ef545d658cd2aa599592f874b8c1125021 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 8 Sep 2023 00:05:42 +0800 Subject: [PATCH 01/54] wip --- vyper/ast/annotation.py | 7 ++++- vyper/ast/pre_parser.py | 43 +++++++++++++++++++++++++++- vyper/ast/utils.py | 19 +++++++++--- vyper/compiler/phases.py | 13 +++++---- vyper/semantics/analysis/__init__.py | 4 +-- vyper/semantics/analysis/local.py | 16 ++++++++--- 6 files changed, 84 insertions(+), 18 deletions(-) diff --git a/vyper/ast/annotation.py b/vyper/ast/annotation.py index 9c7b1e063f..fa1742859d 100644 --- a/vyper/ast/annotation.py +++ b/vyper/ast/annotation.py @@ -1,7 +1,7 @@ import ast as python_ast import tokenize from decimal import Decimal -from typing import Optional, cast +from typing import Any, Optional, cast import asttokens @@ -249,6 +249,7 @@ def annotate_python_ast( parsed_ast: python_ast.AST, source_code: str, modification_offsets: Optional[ModificationOffsets] = None, + loop_var_annotations: Optional[dict[int, python_ast.AST]] = None, source_id: int = 0, contract_name: Optional[str] = None, ) -> python_ast.AST: @@ -272,5 +273,9 @@ def annotate_python_ast( tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) visitor = AnnotatingVisitor(source_code, modification_offsets, tokens, source_id, contract_name) visitor.visit(parsed_ast) + for k, v in loop_var_annotations.items(): + tokens = asttokens.ASTTokens(v["source_code"], tree=cast(Optional[python_ast.Module], v["parsed_ast"])) + visitor = AnnotatingVisitor(v["source_code"], {}, tokens, source_id, contract_name) + visitor.visit(v["parsed_ast"]) return parsed_ast diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 0ead889787..b4bdf15cd8 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -78,11 +78,16 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: result = [] modification_offsets: ModificationOffsets = {} settings = Settings() + loop_var_annotations = {} try: code_bytes = code.encode("utf-8") token_list = list(tokenize(io.BytesIO(code_bytes).readline)) + is_for_loop = False + after_loop_var = False + loop_var_annotation = [] + for i in range(len(token_list)): token = token_list[i] toks = [token] @@ -142,8 +147,44 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: if (typ, string) == (OP, ";"): raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1]) + + if typ == NAME and string == "for": + is_for_loop = True + #print("for loop!") + #print(token) + + if is_for_loop: + if typ == NAME and string == "in": + loop_var_annotations[start[0]] = loop_var_annotation + + is_for_loop = False + after_loop_var = False + loop_var_annotation = [] + + elif (typ, string) == (OP, ":"): + after_loop_var = True + continue + + elif after_loop_var and not (typ == NAME and string == "for"): + #print("adding to loop var: ", toks) + loop_var_annotation.extend(toks) + continue + result.extend(toks) except TokenError as e: raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e - return settings, modification_offsets, untokenize(result).decode("utf-8") + for k, v in loop_var_annotations.items(): + + + updated_v = untokenize(v) + #print("untokenized v: ", updated_v) + updated_v = updated_v.replace("\\", "") + updated_v = updated_v.replace("\n", "") + import textwrap + #print("updated v: ", textwrap.dedent(updated_v)) + loop_var_annotations[k] = {"source_code": textwrap.dedent(updated_v)} + + #print("untokenized result: ", type(untokenize(result))) + #print("untokenized result decoded: ", untokenize(result).decode("utf-8")) + return settings, modification_offsets, loop_var_annotations, untokenize(result).decode("utf-8") diff --git a/vyper/ast/utils.py b/vyper/ast/utils.py index 4e669385ab..5d83c7996a 100644 --- a/vyper/ast/utils.py +++ b/vyper/ast/utils.py @@ -17,7 +17,7 @@ def parse_to_ast_with_settings( source_id: int = 0, contract_name: Optional[str] = None, add_fn_node: Optional[str] = None, -) -> tuple[Settings, vy_ast.Module]: +) -> tuple[Settings, vy_ast.Module, dict[int, dict[str, Any]]]: """ Parses a Vyper source string and generates basic Vyper AST nodes. @@ -39,9 +39,16 @@ def parse_to_ast_with_settings( """ if "\x00" in source_code: raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, class_types, reformatted_code = pre_parse(source_code) + settings, class_types, loop_var_annotations, reformatted_code = pre_parse(source_code) try: py_ast = python_ast.parse(reformatted_code) + + print("loop vars: ", loop_var_annotations) + for k, v in loop_var_annotations.items(): + print("v: ", v) + parsed_v = python_ast.parse(v["source_code"]) + print("parsed v: ", parsed_v.body[0].value) + loop_var_annotations[k]["parsed_ast"] = parsed_v except SyntaxError as e: # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e @@ -53,12 +60,16 @@ def parse_to_ast_with_settings( fn_node.body = py_ast.body fn_node.args = python_ast.arguments(defaults=[]) py_ast.body = [fn_node] - annotate_python_ast(py_ast, source_code, class_types, source_id, contract_name) + annotate_python_ast(py_ast, source_code, class_types, loop_var_annotations, source_id, contract_name) # Convert to Vyper AST. module = vy_ast.get_node(py_ast) + + for k, v in loop_var_annotations.items(): + loop_var_annotations[k]["vy_ast"] = vy_ast.get_node(v["parsed_ast"]) + assert isinstance(module, vy_ast.Module) # mypy hint - return settings, module + return settings, module, loop_var_annotations def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index a1c7342320..00d124e3d4 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -1,7 +1,7 @@ import copy import warnings from functools import cached_property -from typing import Optional, Tuple +from typing import Any, Optional, Tuple from vyper import ast as vy_ast from vyper.codegen import module @@ -90,7 +90,7 @@ def __init__( @cached_property def _generate_ast(self): - settings, ast = generate_ast(self.source_code, self.source_id, self.contract_name) + settings, ast, loop_var_annotations = generate_ast(self.source_code, self.source_id, self.contract_name) # validate the compiler settings # XXX: this is a bit ugly, clean up later if settings.evm_version is not None: @@ -117,7 +117,7 @@ def _generate_ast(self): if self.settings.optimize is None: self.settings.optimize = OptimizationLevel.default() - return ast + return ast, loop_var_annotations @cached_property def vyper_module(self): @@ -128,12 +128,12 @@ def vyper_module_unfolded(self) -> vy_ast.Module: # This phase is intended to generate an AST for tooling use, and is not # used in the compilation process. - return generate_unfolded_ast(self.vyper_module, self.interface_codes) + return generate_unfolded_ast(self.vyper_module[0], self.interface_codes) @cached_property def _folded_module(self): return generate_folded_ast( - self.vyper_module, self.interface_codes, self.storage_layout_override + self.vyper_module[0], self.vyper_module[1], self.interface_codes, self.storage_layout_override ) @property @@ -240,6 +240,7 @@ def generate_unfolded_ast( def generate_folded_ast( vyper_module: vy_ast.Module, + loop_var_annotations: dict[int, dict[str, Any]], interface_codes: Optional[InterfaceImports], storage_layout_overrides: StorageLayout = None, ) -> Tuple[vy_ast.Module, StorageLayout]: @@ -262,7 +263,7 @@ def generate_folded_ast( vyper_module_folded = copy.deepcopy(vyper_module) vy_ast.folding.fold(vyper_module_folded) - validate_semantics(vyper_module_folded, interface_codes) + validate_semantics(vyper_module_folded, loop_var_annotations, interface_codes) symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides) return vyper_module_folded, symbol_tables diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 9e987d1cd0..960bf530ba 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -7,11 +7,11 @@ from .utils import _ExprAnalyser -def validate_semantics(vyper_ast, interface_codes): +def validate_semantics(vyper_ast, loop_var_annotations, interface_codes): # validate semantics and annotate AST with type/semantics information namespace = get_namespace() with namespace.enter_scope(): add_module_namespace(vyper_ast, interface_codes) vy_ast.expansion.expand_annotated_ast(vyper_ast) - validate_functions(vyper_ast) + validate_functions(vyper_ast, loop_var_annotations) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index c10df3b8fd..af39a01283 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from vyper import ast as vy_ast from vyper.ast.metadata import NodeMetadata @@ -50,7 +50,7 @@ from vyper.semantics.types.utils import type_from_annotation -def validate_functions(vy_module: vy_ast.Module) -> None: +def validate_functions(vy_module: vy_ast.Module, loop_var_annotations: dict[int, dict[str, Any]]) -> None: """Analyzes a vyper ast and validates the function-level namespaces.""" err_list = ExceptionList() @@ -58,7 +58,7 @@ def validate_functions(vy_module: vy_ast.Module) -> None: for node in vy_module.get_children(vy_ast.FunctionDef): with namespace.enter_scope(): try: - FunctionNodeVisitor(vy_module, node, namespace) + FunctionNodeVisitor(vy_module, loop_var_annotations, node, namespace) except VyperException as e: err_list.append(e) @@ -165,9 +165,10 @@ class FunctionNodeVisitor(VyperNodeVisitorBase): scope_name = "function" def __init__( - self, vyper_module: vy_ast.Module, fn_node: vy_ast.FunctionDef, namespace: dict + self, vyper_module: vy_ast.Module, loop_var_annotations: dict[int, dict[str, Any]], fn_node: vy_ast.FunctionDef, namespace: dict ) -> None: self.vyper_module = vyper_module + self.loop_var_annotations = loop_var_annotations self.fn_node = fn_node self.namespace = namespace self.func = fn_node._metadata["type"] @@ -340,6 +341,11 @@ def visit_For(self, node): if isinstance(node.iter, vy_ast.Subscript): raise StructureException("Cannot iterate over a nested list", node.iter) + print("visit_For: ", node.lineno) + iter_type = type_from_annotation(self.loop_var_annotations[node.lineno]["vy_ast"].body[0].value) + print("iter type: ", iter_type) + node.target._metadata["type"] = iter_type + if isinstance(node.iter, vy_ast.Call): # iteration via range() if node.iter.get("func.id") != "range": @@ -468,6 +474,7 @@ def visit_For(self, node): if not isinstance(node.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target) + """ for_loop_exceptions = [] iter_name = node.target.id for type_ in type_list: @@ -514,6 +521,7 @@ def visit_For(self, node): for type_, exc in zip(type_list, for_loop_exceptions) ), ) + """ def visit_Expr(self, node): if not isinstance(node.value, vy_ast.Call): From dabe7e7a9d925ee7f6b47f985e61c65fb5232d86 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 15:47:38 +0800 Subject: [PATCH 02/54] apply bts suggestion --- vyper/ast/parse.py | 24 +++++++++++++++++++----- vyper/compiler/phases.py | 10 ++++------ vyper/semantics/analysis/module.py | 15 +++++++-------- 3 files changed, 30 insertions(+), 19 deletions(-) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index b4dd9531d0..70eedf54e9 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -12,9 +12,9 @@ from vyper.typing import ModificationOffsets -def parse_to_ast(*args: Any, **kwargs: Any) -> tuple[vy_ast.Module, dict[int, dict[str, Any]]]: - _settings, ast, loop_var_annotations = parse_to_ast_with_settings(*args, **kwargs) - return ast, loop_var_annotations +def parse_to_ast(*args: Any, **kwargs: Any) -> vy_ast.Module: + _settings, ast = parse_to_ast_with_settings(*args, **kwargs) + return ast def parse_to_ast_with_settings( @@ -57,6 +57,12 @@ def parse_to_ast_with_settings( settings, class_types, loop_var_annotations, reformatted_code = pre_parse(source_code) try: py_ast = python_ast.parse(reformatted_code) + + for k, v in loop_var_annotations.items(): + print("v: ", v) + parsed_v = python_ast.parse(v["source_code"]) + print("parsed v: ", parsed_v.body[0].value) + loop_var_annotations[k]["parsed_ast"] = parsed_v except SyntaxError as e: # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e @@ -83,7 +89,13 @@ def parse_to_ast_with_settings( module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint - return settings, module, loop_var_annotations + for k, v in loop_var_annotations.items(): + loop_var_vy_ast = vy_ast.get_node(v["parsed_ast"]) + loop_var_annotations[k]["vy_ast"] = loop_var_vy_ast + + module._metadata["loop_var_annotations"] = loop_var_annotations + + return settings, module def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: @@ -390,8 +402,10 @@ def annotate_python_ast( ) visitor.visit(parsed_ast) for k, v in loop_var_annotations.items(): + print("k: ", k) + print("v: ", v) tokens = asttokens.ASTTokens(v["source_code"], tree=cast(Optional[python_ast.Module], v["parsed_ast"])) - visitor = AnnotatingVisitor(v["source_code"], {}, tokens, source_id, contract_name) + visitor = AnnotatingVisitor(v["source_code"], {}, tokens, source_id, module_path=module_path, resolved_path=resolved_path) visitor.visit(v["parsed_ast"]) return parsed_ast diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 0e621c62f1..4e6cc9df86 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -129,7 +129,7 @@ def contract_path(self): @cached_property def _generate_ast(self): - settings, ast, loop_var_annotations = vy_ast.parse_to_ast_with_settings( + settings, ast = vy_ast.parse_to_ast_with_settings( self.source_code, self.source_id, module_path=str(self.contract_path), @@ -145,7 +145,7 @@ def _generate_ast(self): # note self.settings.compiler_version is erased here as it is # not used after pre-parsing - return ast, loop_var_annotations + return ast @cached_property def vyper_module(self): @@ -153,9 +153,8 @@ def vyper_module(self): @cached_property def _annotated_module(self): - ast, loop_var_annotations = self.vyper_module return generate_annotated_ast( - ast, loop_var_annotations, self.input_bundle, self.storage_layout_override + self.vyper_module, self.input_bundle, self.storage_layout_override ) @property @@ -245,7 +244,6 @@ def blueprint_bytecode(self) -> bytes: def generate_annotated_ast( vyper_module: vy_ast.Module, - loop_var_annotations: dict[int, dict[str, Any]], input_bundle: InputBundle, storage_layout_overrides: StorageLayout = None, ) -> tuple[vy_ast.Module, StorageLayout]: @@ -267,7 +265,7 @@ def generate_annotated_ast( vyper_module = copy.deepcopy(vyper_module) with input_bundle.search_path(Path(vyper_module.resolved_path).parent): # note: validate_semantics does type inference on the AST - validate_semantics(vyper_module, loop_var_annotations, input_bundle) + validate_semantics(vyper_module, input_bundle) symbol_tables = set_data_positions(vyper_module, storage_layout_overrides) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index da30a261db..92b9186412 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -39,13 +39,12 @@ from vyper.semantics.types.utils import type_from_annotation -def validate_semantics(module_ast, loop_var_annotations, input_bundle, is_interface=False) -> ModuleT: - return validate_semantics_r(module_ast, loop_var_annotations, input_bundle, ImportGraph(), is_interface) +def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: + return validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) def validate_semantics_r( module_ast: vy_ast.Module, - loop_var_annotations: dict[int, dict[str, Any]], input_bundle: InputBundle, import_graph: ImportGraph, is_interface: bool, @@ -70,7 +69,7 @@ def validate_semantics_r( # if this is an interface, the function is already validated # in `ContractFunction.from_vyi()` if not is_interface: - validate_functions(module_ast, loop_var_annotations) + validate_functions(module_ast) return ret @@ -486,13 +485,13 @@ def _load_import_helper( def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: - ast, loop_var_annotations = vy_ast.parse_to_ast( + ast = vy_ast.parse_to_ast( file.source_code, source_id=file.source_id, module_path=str(file.path), resolved_path=str(file.resolved_path), ) - return ast, loop_var_annotations + return ast # convert an import to a path (without suffix) @@ -543,8 +542,8 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: raise ModuleNotFoundError(f"Not a builtin: {module_str}") from None # TODO: it might be good to cache this computation - interface_ast, loop_var_annotations = _parse_and_fold_ast(file) + interface_ast = _parse_and_fold_ast(file) with override_global_namespace(Namespace()): - module_t = validate_semantics(interface_ast, loop_var_annotations, input_bundle, is_interface=True) + module_t = validate_semantics(interface_ast, input_bundle, is_interface=True) return module_t.interface From 899699eb267181aa046c54b63f48cc41b5bc8f4d Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 15:47:45 +0800 Subject: [PATCH 03/54] fix for visit --- vyper/semantics/analysis/local.py | 86 +++++++++---------------------- 1 file changed, 25 insertions(+), 61 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index b499b91d8a..1e47417083 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -53,7 +53,7 @@ from vyper.semantics.types.utils import type_from_annotation -def validate_functions(vy_module: vy_ast.Module, loop_var_annotations: dict[int, dict[str, Any]]) -> None: +def validate_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" err_list = ExceptionList() @@ -61,7 +61,7 @@ def validate_functions(vy_module: vy_ast.Module, loop_var_annotations: dict[int, for node in vy_module.get_children(vy_ast.FunctionDef): with namespace.enter_scope(): try: - analyzer = FunctionNodeVisitor(vy_module, loop_var_annotations, node, namespace) + analyzer = FunctionNodeVisitor(vy_module, node, namespace) analyzer.analyze() except VyperException as e: err_list.append(e) @@ -180,10 +180,9 @@ class FunctionNodeVisitor(VyperNodeVisitorBase): scope_name = "function" def __init__( - self, vyper_module: vy_ast.Module, loop_var_annotations: dict[int, dict[str, Any]], fn_node: vy_ast.FunctionDef, namespace: dict + self, vyper_module: vy_ast.Module, fn_node: vy_ast.FunctionDef, namespace: dict ) -> None: self.vyper_module = vyper_module - self.loop_var_annotations = loop_var_annotations self.fn_node = fn_node self.namespace = namespace self.func = fn_node._metadata["func_type"] @@ -229,6 +228,7 @@ def visit_AnnAssign(self, node): "Memory variables must be declared with an initial value", node ) + print("visit_AnnAssign - typ: ", type(node.annotation)) typ = type_from_annotation(node.annotation, DataLocation.MEMORY) validate_expected_type(node.value, typ) @@ -352,7 +352,11 @@ def visit_For(self, node): raise StructureException("Cannot iterate over a nested list", node.iter) print("visit_For: ", node.lineno) - iter_type = type_from_annotation(self.loop_var_annotations[node.lineno]["vy_ast"].body[0].value) + loop_var_annotations = self.vyper_module._metadata.get("loop_var_annotations") + print("visit_For - loop vars: ", loop_var_annotations) + iter_annotation_node = loop_var_annotations[node.lineno]["vy_ast"].body[0].value + print("visit_For - type annotation node type: ", type(iter_annotation_node)) + iter_type = type_from_annotation(iter_annotation_node, DataLocation.MEMORY) print("iter type: ", iter_type) node.target._metadata["type"] = iter_type @@ -424,62 +428,22 @@ def visit_For(self, node): if not isinstance(node.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target) - # for_loop_exceptions = [] - # iter_name = node.target.id - # for possible_target_type in type_list: - # # type check the for loop body using each possible type for iterator value - - # with self.namespace.enter_scope(): - # self.namespace[iter_name] = VarInfo( - # possible_target_type, modifiability=Modifiability.RUNTIME_CONSTANT - # ) - - # try: - # with NodeMetadata.enter_typechecker_speculation(): - # for stmt in node.body: - # self.visit(stmt) - - # self.expr_visitor.visit(node.target, possible_target_type) - - # if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - # iter_type = get_exact_type_from_node(node.iter) - # # note CMC 2023-10-23: slightly redundant with how type_list is computed - # validate_expected_type(node.target, iter_type.value_type) - # self.expr_visitor.visit(node.iter, iter_type) - # if isinstance(node.iter, vy_ast.List): - # len_ = len(node.iter.elements) - # self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) - # if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - # for a in node.iter.args: - # self.expr_visitor.visit(a, possible_target_type) - # for a in node.iter.keywords: - # if a.arg == "bound": - # self.expr_visitor.visit(a.value, possible_target_type) - - # except (TypeMismatch, InvalidOperation) as exc: - # for_loop_exceptions.append(exc) - # else: - # # success -- do not enter error handling section - # return - - # # failed to find a good type. bail out - # if len(set(str(i) for i in for_loop_exceptions)) == 1: - # # if every attempt at type checking raised the same exception - # raise for_loop_exceptions[0] - - # # return an aggregate TypeMismatch that shows all possible exceptions - # # depending on which type is used - # types_str = [str(i) for i in type_list] - # given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}" - # raise TypeMismatch( - # f"Iterator value '{iter_name}' may be cast as {given_str}, " - # "but type checking fails with all possible types:", - # node, - # *( - # (f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0]) - # for typ, exc in zip(type_list, for_loop_exceptions) - # ), - # ) + self.expr_visitor.visit(node.target, iter_type) + + if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): + #iter_type = get_exact_type_from_node(node.iter) + # note CMC 2023-10-23: slightly redundant with how type_list is computed + #validate_expected_type(node.target, iter_type.value_type) + self.expr_visitor.visit(node.iter, iter_type) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) + if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + for a in node.iter.args: + self.expr_visitor.visit(a, iter_type) + for a in node.iter.keywords: + if a.arg == "bound": + self.expr_visitor.visit(a.value, iter_type) def visit_If(self, node): validate_expected_type(node.test, BoolT()) From 8d9b2ef21881e78a61b5cf96732fd89c198ea758 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 15:59:28 +0800 Subject: [PATCH 04/54] clean up prints --- vyper/ast/parse.py | 2 - vyper/semantics/analysis/local.py | 77 ++++++++++++++++++++++--------- 2 files changed, 56 insertions(+), 23 deletions(-) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 70eedf54e9..a0d252ddc7 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -59,9 +59,7 @@ def parse_to_ast_with_settings( py_ast = python_ast.parse(reformatted_code) for k, v in loop_var_annotations.items(): - print("v: ", v) parsed_v = python_ast.parse(v["source_code"]) - print("parsed v: ", parsed_v.body[0].value) loop_var_annotations[k]["parsed_ast"] = parsed_v except SyntaxError as e: # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 1e47417083..ca4922e694 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -228,7 +228,6 @@ def visit_AnnAssign(self, node): "Memory variables must be declared with an initial value", node ) - print("visit_AnnAssign - typ: ", type(node.annotation)) typ = type_from_annotation(node.annotation, DataLocation.MEMORY) validate_expected_type(node.value, typ) @@ -351,13 +350,13 @@ def visit_For(self, node): if isinstance(node.iter, vy_ast.Subscript): raise StructureException("Cannot iterate over a nested list", node.iter) - print("visit_For: ", node.lineno) loop_var_annotations = self.vyper_module._metadata.get("loop_var_annotations") - print("visit_For - loop vars: ", loop_var_annotations) - iter_annotation_node = loop_var_annotations[node.lineno]["vy_ast"].body[0].value - print("visit_For - type annotation node type: ", type(iter_annotation_node)) + iter_annotation = loop_var_annotations.get(node.lineno).get("vy_ast") + if not iter_annotation: + raise StructureException("Iterator needs type annotation", node.iter) + + iter_annotation_node = iter_annotation.body[0].value iter_type = type_from_annotation(iter_annotation_node, DataLocation.MEMORY) - print("iter type: ", iter_type) node.target._metadata["type"] = iter_type if isinstance(node.iter, vy_ast.Call): @@ -428,23 +427,59 @@ def visit_For(self, node): if not isinstance(node.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target) - self.expr_visitor.visit(node.target, iter_type) + iter_name = node.target.id + with self.namespace.enter_scope(): + self.namespace[iter_name] = VarInfo( + iter_type, modifiability=Modifiability.RUNTIME_CONSTANT + ) - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - #iter_type = get_exact_type_from_node(node.iter) - # note CMC 2023-10-23: slightly redundant with how type_list is computed - #validate_expected_type(node.target, iter_type.value_type) - self.expr_visitor.visit(node.iter, iter_type) - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, iter_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, iter_type) + try: + with NodeMetadata.enter_typechecker_speculation(): + for stmt in node.body: + self.visit(stmt) + + self.expr_visitor.visit(node.target, iter_type) + + if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): + iter_type = get_exact_type_from_node(node.iter) + # note CMC 2023-10-23: slightly redundant with how type_list is computed + validate_expected_type(node.target, iter_type.value_type) + self.expr_visitor.visit(node.iter, iter_type) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) + if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + for a in node.iter.args: + self.expr_visitor.visit(a, iter_type) + for a in node.iter.keywords: + if a.arg == "bound": + self.expr_visitor.visit(a.value, iter_type) + + except (TypeMismatch, InvalidOperation) as exc: + for_loop_exceptions.append(exc) + else: + # success -- do not enter error handling section + return + # failed to find a good type. bail out + if len(set(str(i) for i in for_loop_exceptions)) == 1: + # if every attempt at type checking raised the same exception + raise for_loop_exceptions[0] + + # return an aggregate TypeMismatch that shows all possible exceptions + # depending on which type is used + types_str = [str(i) for i in type_list] + given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}" + raise TypeMismatch( + f"Iterator value '{iter_name}' may be cast as {given_str}, " + "but type checking fails with all possible types:", + node, + *( + (f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0]) + for typ, exc in zip(type_list, for_loop_exceptions) + ), + ) + def visit_If(self, node): validate_expected_type(node.test, BoolT()) self.expr_visitor.visit(node.test, BoolT()) From bc5422a0bdbac02501e5db61d9c7199a002b0e4b Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 15:59:36 +0800 Subject: [PATCH 05/54] update examples --- examples/auctions/blind_auction.vy | 2 +- examples/tokens/ERC1155ownable.vy | 8 ++++---- examples/voting/ballot.vy | 6 +++--- examples/wallet/wallet.vy | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/auctions/blind_auction.vy b/examples/auctions/blind_auction.vy index 04f908f6d0..597aed57c7 100644 --- a/examples/auctions/blind_auction.vy +++ b/examples/auctions/blind_auction.vy @@ -107,7 +107,7 @@ def reveal(_numBids: int128, _values: uint256[128], _fakes: bool[128], _secrets: # Calculate refund for sender refund: uint256 = 0 - for i in range(MAX_BIDS): + for i: int128 in range(MAX_BIDS): # Note that loop may break sooner than 128 iterations if i >= _numBids if (i >= _numBids): break diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index 30057582e8..e105a79133 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -205,7 +205,7 @@ def balanceOfBatch(accounts: DynArray[address, BATCH_SIZE], ids: DynArray[uint25 assert len(accounts) == len(ids), "ERC1155: accounts and ids length mismatch" batchBalances: DynArray[uint256, BATCH_SIZE] = [] j: uint256 = 0 - for i in ids: + for i: uint256 in ids: batchBalances.append(self.balanceOf[accounts[j]][i]) j += 1 return batchBalances @@ -243,7 +243,7 @@ def mintBatch(receiver: address, ids: DynArray[uint256, BATCH_SIZE], amounts: Dy assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break self.balanceOf[receiver][ids[i]] += amounts[i] @@ -277,7 +277,7 @@ def burnBatch(ids: DynArray[uint256, BATCH_SIZE], amounts: DynArray[uint256, BAT assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break self.balanceOf[msg.sender][ids[i]] -= amounts[i] @@ -333,7 +333,7 @@ def safeBatchTransferFrom(sender: address, receiver: address, ids: DynArray[uint assert sender == msg.sender or self.isApprovedForAll[sender][msg.sender], "Caller is neither owner nor approved operator for this ID" assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break id: uint256 = ids[i] diff --git a/examples/voting/ballot.vy b/examples/voting/ballot.vy index 0b568784a9..107716accf 100644 --- a/examples/voting/ballot.vy +++ b/examples/voting/ballot.vy @@ -54,7 +54,7 @@ def directlyVoted(addr: address) -> bool: def __init__(_proposalNames: bytes32[2]): self.chairperson = msg.sender self.voterCount = 0 - for i in range(2): + for i: int128 in range(2): self.proposals[i] = Proposal({ name: _proposalNames[i], voteCount: 0 @@ -82,7 +82,7 @@ def _forwardWeight(delegate_with_weight_to_forward: address): assert self.voters[delegate_with_weight_to_forward].weight > 0 target: address = self.voters[delegate_with_weight_to_forward].delegate - for i in range(4): + for i: int128 in range(4): if self._delegated(target): target = self.voters[target].delegate # The following effectively detects cycles of length <= 5, @@ -157,7 +157,7 @@ def vote(proposal: int128): def _winningProposal() -> int128: winning_vote_count: int128 = 0 winning_proposal: int128 = 0 - for i in range(2): + for i: int128 in range(2): if self.proposals[i].voteCount > winning_vote_count: winning_vote_count = self.proposals[i].voteCount winning_proposal = i diff --git a/examples/wallet/wallet.vy b/examples/wallet/wallet.vy index e2515d9e62..231f538ecf 100644 --- a/examples/wallet/wallet.vy +++ b/examples/wallet/wallet.vy @@ -14,7 +14,7 @@ seq: public(int128) @external def __init__(_owners: address[5], _threshold: int128): - for i in range(5): + for i: uint256 in range(5): if _owners[i] != empty(address): self.owners[i] = _owners[i] self.threshold = _threshold @@ -47,7 +47,7 @@ def approve(_seq: int128, to: address, _value: uint256, data: Bytes[4096], sigda assert self.seq == _seq # # Iterates through all the owners and verifies that there signatures, # # given as the sigdata argument are correct - for i in range(5): + for i: uint256 in range(5): if sigdata[i][0] != 0: # If an invalid signature is given for an owner then the contract throws assert ecrecover(h2, sigdata[i][0], sigdata[i][1], sigdata[i][2]) == self.owners[i] From deb860f5211919b209379b718ef27c8e2208f8d0 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 16:13:55 +0800 Subject: [PATCH 06/54] delete py ast key --- vyper/ast/parse.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index a0d252ddc7..41d2fec330 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -90,6 +90,7 @@ def parse_to_ast_with_settings( for k, v in loop_var_annotations.items(): loop_var_vy_ast = vy_ast.get_node(v["parsed_ast"]) loop_var_annotations[k]["vy_ast"] = loop_var_vy_ast + del loop_var_annotations[k]["parsed_ast"] module._metadata["loop_var_annotations"] = loop_var_annotations From 2c2792f629e45cc005716cafc15461ada91d617e Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 16:23:18 +0800 Subject: [PATCH 07/54] remove prints --- vyper/ast/parse.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 41d2fec330..38be672dac 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -401,8 +401,6 @@ def annotate_python_ast( ) visitor.visit(parsed_ast) for k, v in loop_var_annotations.items(): - print("k: ", k) - print("v: ", v) tokens = asttokens.ASTTokens(v["source_code"], tree=cast(Optional[python_ast.Module], v["parsed_ast"])) visitor = AnnotatingVisitor(v["source_code"], {}, tokens, source_id, module_path=module_path, resolved_path=resolved_path) visitor.visit(v["parsed_ast"]) From daef0b7fc0508b96f07bdd58d8016d9d1fe5e78e Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 16:23:26 +0800 Subject: [PATCH 08/54] fix exc in for --- vyper/semantics/analysis/local.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index ca4922e694..7231ee6d5a 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -427,6 +427,7 @@ def visit_For(self, node): if not isinstance(node.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target) + for_loop_exceptions = [] iter_name = node.target.id with self.namespace.enter_scope(): self.namespace[iter_name] = VarInfo( From f644f8a8fe271d647db3f22b5f1c7810263de0d8 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 16:23:34 +0800 Subject: [PATCH 09/54] update grammar --- vyper/ast/grammar.lark | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 7889473b19..4a826153df 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -179,7 +179,7 @@ cond_exec: _expr ":" body default_exec: body if_stmt: "if" cond_exec ("elif" cond_exec)* ["else" ":" default_exec] // TODO: make this into a variable definition e.g. `for i: uint256 in range(0, 5): ...` -loop_variable: NAME [":" NAME] +loop_variable: NAME ":" type loop_iterator: _expr for_stmt: "for" loop_variable "in" loop_iterator ":" body From 49ad2cd18c1302b206226815fbdb1f1a32f63db2 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 16:23:39 +0800 Subject: [PATCH 10/54] update tests --- .../codegen/features/iteration/test_break.py | 12 +- .../features/iteration/test_continue.py | 10 +- .../features/iteration/test_for_in_list.py | 136 +++++++++--------- 3 files changed, 79 insertions(+), 79 deletions(-) diff --git a/tests/functional/codegen/features/iteration/test_break.py b/tests/functional/codegen/features/iteration/test_break.py index 8a08a11cc2..4abde9c617 100644 --- a/tests/functional/codegen/features/iteration/test_break.py +++ b/tests/functional/codegen/features/iteration/test_break.py @@ -11,7 +11,7 @@ def test_break_test(get_contract_with_gas_estimation): def foo(n: decimal) -> int128: c: decimal = n * 1.0 output: int128 = 0 - for i in range(400): + for i: int128 in range(400): c = c / 1.2589 if c < 1.0: output = i @@ -35,12 +35,12 @@ def test_break_test_2(get_contract_with_gas_estimation): def foo(n: decimal) -> int128: c: decimal = n * 1.0 output: int128 = 0 - for i in range(40): + for i: int128 in range(40): if c < 10.0: output = i * 10 break c = c / 10.0 - for i in range(10): + for i: int128 in range(10): c = c / 1.2589 if c < 1.0: output = output + i @@ -63,12 +63,12 @@ def test_break_test_3(get_contract_with_gas_estimation): def foo(n: int128) -> int128: c: decimal = convert(n, decimal) output: int128 = 0 - for i in range(40): + for i: int128 in range(40): if c < 10.0: output = i * 10 break c /= 10.0 - for i in range(10): + for i: int128 in range(10): c /= 1.2589 if c < 1.0: output = output + i @@ -108,7 +108,7 @@ def foo(): """ @external def foo(): - for i in [1, 2, 3]: + for i: uint256 in [1, 2, 3]: b: uint256 = i if True: break diff --git a/tests/functional/codegen/features/iteration/test_continue.py b/tests/functional/codegen/features/iteration/test_continue.py index 5f4f82a2de..1b2fcab460 100644 --- a/tests/functional/codegen/features/iteration/test_continue.py +++ b/tests/functional/codegen/features/iteration/test_continue.py @@ -7,7 +7,7 @@ def test_continue1(get_contract_with_gas_estimation): code = """ @external def foo() -> bool: - for i in range(2): + for i: uint256 in range(2): continue return False return True @@ -21,7 +21,7 @@ def test_continue2(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(3): + for i: int128 in range(3): x += 1 continue x -= 1 @@ -36,7 +36,7 @@ def test_continue3(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(3): + for i: int128 in range(3): x += i continue return x @@ -50,7 +50,7 @@ def test_continue4(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(6): + for i: int128 in range(6): if i % 2 == 0: continue x += 1 @@ -83,7 +83,7 @@ def foo(): """ @external def foo(): - for i in [1, 2, 3]: + for i: uint256 in [1, 2, 3]: b: uint256 = i if True: continue diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index bc1a12ae9e..33ad59370e 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -21,7 +21,7 @@ @external def data() -> int128: s: int128[5] = [1, 2, 3, 4, 5] - for i in s: + for i: int128 in s: if i >= 3: return i return -1""", @@ -33,7 +33,7 @@ def data() -> int128: @external def data() -> int128: s: DynArray[int128, 10] = [1, 2, 3, 4, 5] - for i in s: + for i: int128 in s: if i >= 3: return i return -1""", @@ -53,8 +53,8 @@ def data() -> int128: [S({x:3, y:4}), S({x:5, y:6}), S({x:7, y:8}), S({x:9, y:10})] ] ret: int128 = 0 - for ss in sss: - for s in ss: + for ss: DynArray[S, 10] in sss: + for s: S in ss: ret += s.x + s.y return ret""", sum(range(1, 11)), @@ -64,7 +64,7 @@ def data() -> int128: """ @external def data() -> int128: - for i in [3, 5, 7, 9]: + for i: int128 in [3, 5, 7, 9]: if i > 5: return i return -1""", @@ -76,7 +76,7 @@ def data() -> int128: @external def data() -> String[33]: xs: DynArray[String[33], 3] = ["hello", ",", "world"] - for x in xs: + for x: String[33] in xs: if x == ",": return x return "" @@ -88,7 +88,7 @@ def data() -> String[33]: """ @external def data() -> String[33]: - for x in ["hello", ",", "world"]: + for x: String[33] in ["hello", ",", "world"]: if x == ",": return x return "" @@ -100,7 +100,7 @@ def data() -> String[33]: """ @external def data() -> DynArray[String[33], 2]: - for x in [["hello", "world"], ["goodbye", "world!"]]: + for x: DynArray[String[33], 2] in [["hello", "world"], ["goodbye", "world!"]]: if x[1] == "world": return x return [] @@ -114,8 +114,8 @@ def data() -> DynArray[String[33], 2]: def data() -> int128: ret: int128 = 0 xss: int128[3][3] = [[1,2,3],[4,5,6],[7,8,9]] - for xs in xss: - for x in xs: + for xs: int128[3] in xss: + for x: int128 in xs: ret += x return ret""", sum(range(1, 10)), @@ -130,8 +130,8 @@ def data() -> int128: @external def data() -> int128: ret: int128 = 0 - for ss in [[S({x:1, y:2})]]: - for s in ss: + for ss: S[1] in [[S({x:1, y:2})]]: + for s: S in ss: ret += s.x + s.y return ret""", 1 + 2, @@ -147,7 +147,7 @@ def data() -> address: 0xDCEceAF3fc5C0a63d195d69b1A90011B7B19650D ] count: int128 = 0 - for i in addresses: + for i: address in addresses: count += 1 if count == 2: return i @@ -174,7 +174,7 @@ def set(): @external def data() -> int128: - for i in self.x: + for i: int128 in self.x: if i > 5: return i return -1 @@ -198,7 +198,7 @@ def set(xs: DynArray[int128, 4]): @external def data() -> int128: t: int128 = 0 - for i in self.x: + for i: int128 in self.x: t += i return t """ @@ -227,7 +227,7 @@ def ret(i: int128) -> address: @external def iterate_return_second() -> address: count: int128 = 0 - for i in self.addresses: + for i: address in self.addresses: count += 1 if count == 2: return i @@ -258,7 +258,7 @@ def ret(i: int128) -> decimal: @external def i_return(break_count: int128) -> decimal: count: int128 = 0 - for i in self.readings: + for i: decimal in self.readings: if count == break_count: return i count += 1 @@ -284,7 +284,7 @@ def func(amounts: uint256[3]) -> uint256: total: uint256 = as_wei_value(0, "wei") # calculate total - for amount in amounts: + for amount: uint256 in amounts: total += amount return total @@ -303,7 +303,7 @@ def func(amounts: DynArray[uint256, 3]) -> uint256: total: uint256 = 0 # calculate total - for amount in amounts: + for amount: uint256 in amounts: total += amount return total @@ -321,42 +321,42 @@ def func(amounts: DynArray[uint256, 3]) -> uint256: @external def foo(x: int128): p: int128 = 0 - for i in range(3): + for i: int128 in range(3): p += i - for i in range(4): + for i: int128 in range(4): p += i """, """ @external def foo(x: int128): p: int128 = 0 - for i in range(3): + for i: int128 in range(3): p += i - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i """, """ @external def foo(x: int128): p: int128 = 0 - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i """, """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass - for i in range(20): + for i: uint256 in range(20): pass """, # using index variable after loop """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass i: int128 = 100 # create new variable i i = 200 # look up the variable i and check whether it is in forvars @@ -372,25 +372,25 @@ def test_good_code(code, get_contract): RANGE_CONSTANT_CODE = [ ( """ -TREE_FIDDY: constant(int128) = 350 +TREE_FIDDY: constant(uint256) = 350 @external def a() -> uint256: x: uint256 = 0 - for i in range(TREE_FIDDY): + for i: uint256 in range(TREE_FIDDY): x += 1 return x""", 350, ), ( """ -ONE_HUNDRED: constant(int128) = 100 +ONE_HUNDRED: constant(uint256) = 100 @external def a() -> uint256: x: uint256 = 0 - for i in range(1, 1 + ONE_HUNDRED): + for i: uint256 in range(1, 1 + ONE_HUNDRED): x += 1 return x""", 100, @@ -401,9 +401,9 @@ def a() -> uint256: END: constant(int128) = 199 @external -def a() -> uint256: - x: uint256 = 0 - for i in range(START, END): +def a() -> int128: + x: int128 = 0 + for i: int128 in range(START, END): x += 1 return x""", 99, @@ -413,7 +413,7 @@ def a() -> uint256: @external def a() -> int128: x: int128 = 0 - for i in range(-5, -1): + for i: int128 in range(-5, -1): x += i return x""", -14, @@ -436,7 +436,7 @@ def test_range_constant(get_contract, code, result): def data() -> int128: s: int128[6] = [1, 2, 3, 4, 5, 6] count: int128 = 0 - for i in s: + for i: int128 in s: s[count] = 1 # this should not be allowed. if i >= 3: return i @@ -451,7 +451,7 @@ def data() -> int128: def foo(): s: int128[6] = [1, 2, 3, 4, 5, 6] count: int128 = 0 - for i in s: + for i: int128 in s: s[count] += 1 """, ImmutableViolation, @@ -468,7 +468,7 @@ def set(): @external def data() -> int128: count: int128 = 0 - for i in self.s: + for i: int128 in self.s: self.s[count] = 1 # this should not be allowed. if i >= 3: return i @@ -493,7 +493,7 @@ def doStuff(i: uint256) -> uint256: @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.foo: + for item: uint256 in self.my_array2.foo: self.doStuff(i) i += 1 """, @@ -519,7 +519,7 @@ def doStuff(i: uint256) -> uint256: @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.bar.foo: + for item: uint256 in self.my_array2.bar.foo: self.doStuff(i) i += 1 """, @@ -545,7 +545,7 @@ def doStuff(): @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.foo: + for item: uint256 in self.my_array2.foo: self.doStuff() i += 1 """, @@ -556,8 +556,8 @@ def _helper(): """ @external def foo(x: int128): - for i in range(4): - for i in range(5): + for i: int128 in range(4): + for i: int128 in range(5): pass """, NamespaceCollision, @@ -566,8 +566,8 @@ def foo(x: int128): """ @external def foo(x: int128): - for i in [1,2]: - for i in [1,2]: + for i: int128 in [1,2]: + for i: int128 in [1,2]: pass """, NamespaceCollision, @@ -577,7 +577,7 @@ def foo(x: int128): """ @external def foo(x: int128): - for i in [1,2]: + for i: int128 in [1,2]: i = 2 """, ImmutableViolation, @@ -588,7 +588,7 @@ def foo(x: int128): @external def foo(): xs: DynArray[uint256, 5] = [1,2,3] - for x in xs: + for x: uint256 in xs: xs.pop() """, ImmutableViolation, @@ -599,7 +599,7 @@ def foo(): @external def foo(): xs: DynArray[uint256, 5] = [1,2,3] - for x in xs: + for x: uint256 in xs: xs.append(x) """, ImmutableViolation, @@ -610,7 +610,7 @@ def foo(): @external def foo(): xs: DynArray[DynArray[uint256, 5], 5] = [[1,2,3]] - for x in xs: + for x: DynArray[uint256, 5] in xs: x.pop() """, ImmutableViolation, @@ -629,7 +629,7 @@ def b(): @external def foo(): - for x in self.array: + for x: uint256 in self.array: self.a() """, ImmutableViolation, @@ -638,7 +638,7 @@ def foo(): """ @external def foo(x: int128): - for i in [1,2]: + for i: int128 in [1,2]: i += 2 """, ImmutableViolation, @@ -648,7 +648,7 @@ def foo(x: int128): """ @external def foo(): - for i in range(-3): + for i: uint256 in range(-3): pass """, StructureException, @@ -656,13 +656,13 @@ def foo(): """ @external def foo(): - for i in range(0): + for i: uint256 in range(0): pass """, """ @external def foo(): - for i in []: + for i: uint256 in []: pass """, """ @@ -670,14 +670,14 @@ def foo(): @external def foo(): - for i in FOO: + for i: uint256 in FOO: pass """, ( """ @external def foo(): - for i in range(5,3): + for i: uint256 in range(5,3): pass """, StructureException, @@ -686,7 +686,7 @@ def foo(): """ @external def foo(): - for i in range(5,3,-1): + for i: int128 in range(5,3,-1): pass """, ArgumentException, @@ -696,7 +696,7 @@ def foo(): @external def foo(): a: uint256 = 2 - for i in range(a): + for i: uint256 in range(a): pass """, StateAccessViolation, @@ -706,7 +706,7 @@ def foo(): @external def foo(): a: int128 = 6 - for i in range(a,a-3): + for i: int128 in range(a,a-3): pass """, StateAccessViolation, @@ -716,7 +716,7 @@ def foo(): """ @external def foo(): - for i in range(): + for i: uint256 in range(): pass """, ArgumentException, @@ -725,7 +725,7 @@ def foo(): """ @external def foo(): - for i in range(0,1,2): + for i: uint256 in range(0,1,2): pass """, ArgumentException, @@ -735,7 +735,7 @@ def foo(): """ @external def foo(): - for i in b"asdf": + for i: Bytes[1] in b"asdf": pass """, InvalidType, @@ -744,7 +744,7 @@ def foo(): """ @external def foo(): - for i in 31337: + for i: uint256 in 31337: pass """, InvalidType, @@ -753,7 +753,7 @@ def foo(): """ @external def foo(): - for i in bar(): + for i: uint256 in bar(): pass """, IteratorException, @@ -762,7 +762,7 @@ def foo(): """ @external def foo(): - for i in self.bar(): + for i: uint256 in self.bar(): pass """, IteratorException, @@ -772,7 +772,7 @@ def foo(): @external def test_for() -> int128: a: int128 = 0 - for i in range(max_value(int128), max_value(int128)+2): + for i: int128 in range(max_value(int128), max_value(int128)+2): a = i return a """, @@ -784,7 +784,7 @@ def test_for() -> int128: def test_for() -> int128: a: int128 = 0 b: uint256 = 0 - for i in range(5): + for i: int128 in range(5): a = i b = i return a From 635e0c053a3019ee0e5cea6a657bbd5a76afe6ad Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 17:17:22 +0800 Subject: [PATCH 11/54] fix tests --- .../functional/builtins/codegen/test_empty.py | 4 +- .../builtins/codegen/test_mulmod.py | 2 +- .../functional/builtins/codegen/test_slice.py | 2 +- .../features/iteration/test_for_range.py | 56 +++++++++---------- .../codegen/features/test_assert.py | 4 +- .../codegen/features/test_internal_call.py | 2 +- .../codegen/integration/test_crowdfund.py | 4 +- .../codegen/types/numbers/test_decimals.py | 2 +- .../codegen/types/test_dynamic_array.py | 4 +- tests/functional/codegen/types/test_lists.py | 4 +- tests/functional/syntax/test_blockscope.py | 4 +- tests/functional/syntax/test_constants.py | 2 +- tests/functional/syntax/test_for_range.py | 50 ++++++++--------- tests/unit/ast/test_pre_parser.py | 2 +- 14 files changed, 71 insertions(+), 71 deletions(-) diff --git a/tests/functional/builtins/codegen/test_empty.py b/tests/functional/builtins/codegen/test_empty.py index c3627785dc..896c845da2 100644 --- a/tests/functional/builtins/codegen/test_empty.py +++ b/tests/functional/builtins/codegen/test_empty.py @@ -423,7 +423,7 @@ def test_empty(xs: int128[111], ys: Bytes[1024], zs: Bytes[31]) -> bool: view @internal def write_junk_to_memory(): xs: int128[1024] = empty(int128[1024]) - for i in range(1024): + for i: uint256 in range(1024): xs[i] = -(i + 1) @internal def priv(xs: int128[111], ys: Bytes[1024], zs: Bytes[31]) -> bool: @@ -469,7 +469,7 @@ def test_return_empty(get_contract_with_gas_estimation): @internal def write_junk_to_memory(): xs: int128[1024] = empty(int128[1024]) - for i in range(1024): + for i: uint256 in range(1024): xs[i] = -(i + 1) @external diff --git a/tests/functional/builtins/codegen/test_mulmod.py b/tests/functional/builtins/codegen/test_mulmod.py index ba82ebd5b8..31de1d9f22 100644 --- a/tests/functional/builtins/codegen/test_mulmod.py +++ b/tests/functional/builtins/codegen/test_mulmod.py @@ -20,7 +20,7 @@ def test_uint256_mulmod_complex(get_contract_with_gas_estimation): @external def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256: o: uint256 = 1 - for i in range(256): + for i: uint256 in range(256): o = uint256_mulmod(o, o, modulus) if exponent & shift(1, 255 - i) != 0: o = uint256_mulmod(o, base, modulus) diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index a15a3eeb35..80936bbf82 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -17,7 +17,7 @@ def test_basic_slice(get_contract_with_gas_estimation): @external def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: inp: Bytes[50] = inp1 - for i in range(1, 11): + for i: uint256 in range(1, 11): inp = slice(inp, 1, 30 - i * 2) return inp """ diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index e946447285..c661c46553 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -6,7 +6,7 @@ def test_basic_repeater(get_contract_with_gas_estimation): @external def repeat(z: int128) -> int128: x: int128 = 0 - for i in range(6): + for i: int128 in range(6): x = x + z return(x) """ @@ -19,7 +19,7 @@ def test_range_bound(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(n, bound=6): + for i: uint256 in range(n, bound=6): x += i + 1 return x """ @@ -37,7 +37,7 @@ def test_range_bound_constant_end(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(n, 7, bound=6): + for i: uint256 in range(n, 7, bound=6): x += i + 1 return x """ @@ -58,7 +58,7 @@ def test_range_bound_two_args(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(1, n, bound=6): + for i: uint256 in range(1, n, bound=6): x += i + 1 return x """ @@ -80,7 +80,7 @@ def test_range_bound_two_runtime_args(get_contract, tx_failed): @external def repeat(start: uint256, end: uint256) -> uint256: x: uint256 = 0 - for i in range(start, end, bound=6): + for i: uint256 in range(start, end, bound=6): x += i return x """ @@ -109,7 +109,7 @@ def test_range_overflow(get_contract, tx_failed): @external def get_last(start: uint256, end: uint256) -> uint256: x: uint256 = 0 - for i in range(start, end, bound=6): + for i: uint256 in range(start, end, bound=6): x = i return x """ @@ -134,11 +134,11 @@ def test_digit_reverser(get_contract_with_gas_estimation): def reverse_digits(x: int128) -> int128: dig: int128[6] = [0, 0, 0, 0, 0, 0] z: int128 = x - for i in range(6): + for i: uint256 in range(6): dig[i] = z % 10 z = z / 10 o: int128 = 0 - for i in range(6): + for i: uint256 in range(6): o = o * 10 + dig[i] return o @@ -153,9 +153,9 @@ def test_more_complex_repeater(get_contract_with_gas_estimation): @external def repeat() -> int128: out: int128 = 0 - for i in range(6): + for i: uint256 in range(6): out = out * 10 - for j in range(4): + for j: int128 in range(4): out = out + j return(out) """ @@ -170,7 +170,7 @@ def test_offset_repeater(get_contract_with_gas_estimation, typ): @external def sum() -> {typ}: out: {typ} = 0 - for i in range(80, 121): + for i: {typ} in range(80, 121): out = out + i return out """ @@ -185,7 +185,7 @@ def test_offset_repeater_2(get_contract_with_gas_estimation, typ): @external def sum(frm: {typ}, to: {typ}) -> {typ}: out: {typ} = 0 - for i in range(frm, frm + 101, bound=101): + for i: {typ} in range(frm, frm + 101, bound=101): if i == to: break out = out + i @@ -205,7 +205,7 @@ def _bar() -> bool: @external def foo() -> bool: - for i in range(3): + for i: uint256 in range(3): self._bar() return True """ @@ -219,8 +219,8 @@ def test_return_inside_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for j in range(10): + for i: {typ} in range(10): + for j: {typ} in range(10): if j > 5: if i > a: return i @@ -254,14 +254,14 @@ def test_for_range_edge(get_contract, typ): def test(): found: bool = False x: {typ} = max_value({typ}) - for i in range(x - 1, x, bound=1): + for i: {typ} in range(x - 1, x, bound=1): if i + 1 == max_value({typ}): found = True assert found found = False x = max_value({typ}) - 1 - for i in range(x - 1, x + 1, bound=2): + for i: {typ} in range(x - 1, x + 1, bound=2): if i + 1 == max_value({typ}): found = True assert found @@ -276,7 +276,7 @@ def test_for_range_oob_check(get_contract, tx_failed, typ): @external def test(): x: {typ} = max_value({typ}) - for i in range(x, x + 2, bound=2): + for i: {typ} in range(x, x + 2, bound=2): pass """ c = get_contract(code) @@ -289,8 +289,8 @@ def test_return_inside_nested_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: return i + x return 31337 @@ -318,8 +318,8 @@ def test_return_void_nested_repeater(get_contract, typ, val): result: {typ} @internal def _final(a: {typ}): - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: self.result = i + x return @@ -347,8 +347,8 @@ def test_external_nested_repeater(get_contract, typ, val): code = f""" @external def foo(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: return i + x return 31337 @@ -368,8 +368,8 @@ def test_external_void_nested_repeater(get_contract, typ, val): result: public({typ}) @external def foo(a: {typ}): - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: self.result = i + x return @@ -388,8 +388,8 @@ def test_breaks_and_returns_inside_nested_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if a < 2: break return 6 diff --git a/tests/functional/codegen/features/test_assert.py b/tests/functional/codegen/features/test_assert.py index af189e6dca..df379d3f16 100644 --- a/tests/functional/codegen/features/test_assert.py +++ b/tests/functional/codegen/features/test_assert.py @@ -159,7 +159,7 @@ def test_assert_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: - for i in range(3): + for i: uint256 in range(3): assert x[i] < 5 return True """ @@ -179,7 +179,7 @@ def test_assert_with_reason_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: - for i in range(3): + for i: uint256 in range(3): assert x[i] < 5, "because reasons" return True """ diff --git a/tests/functional/codegen/features/test_internal_call.py b/tests/functional/codegen/features/test_internal_call.py index f10d22ec99..422f53fdeb 100644 --- a/tests/functional/codegen/features/test_internal_call.py +++ b/tests/functional/codegen/features/test_internal_call.py @@ -152,7 +152,7 @@ def _increment(): @external def returnten() -> int128: - for i in range(10): + for i: uint256 in range(10): self._increment() return self.counter """ diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 671d424d60..891ed5aebe 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -52,7 +52,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30, bound=30): + for i: int128 in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return @@ -147,7 +147,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30, bound=30): + for i: int128 in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index fcf71f12f0..72171dd4b5 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -125,7 +125,7 @@ def test_harder_decimal_test(get_contract_with_gas_estimation): @external def phooey(inp: decimal) -> decimal: x: decimal = 10000.0 - for i in range(4): + for i: uint256 in range(4): x = x * inp return x diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 70a68e3206..171c1c5394 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -1328,7 +1328,7 @@ def test_list_of_structs_arg(get_contract): @external def bar(_baz: DynArray[Foo, 3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): e: Foobar = _baz[i].z f: uint256 = convert(e, uint256) sum += _baz[i].x * _baz[i].y + f @@ -1397,7 +1397,7 @@ def test_list_of_nested_struct_arrays(get_contract): @external def bar(_bar: DynArray[Bar, 3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _bar[i].f[0].e.a[0] * _bar[i].f[1].e.a[1] return sum """ diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index b5b9538c20..ee287064e8 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -566,7 +566,7 @@ def test_list_of_structs_arg(get_contract): @external def bar(_baz: Foo[3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _baz[i].x * _baz[i].y return sum """ @@ -608,7 +608,7 @@ def test_list_of_nested_struct_arrays(get_contract): @external def bar(_bar: Bar[3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _bar[i].f[0].e.a[0] * _bar[i].f[1].e.a[1] return sum """ diff --git a/tests/functional/syntax/test_blockscope.py b/tests/functional/syntax/test_blockscope.py index 942aa3fa68..466b5509ca 100644 --- a/tests/functional/syntax/test_blockscope.py +++ b/tests/functional/syntax/test_blockscope.py @@ -33,7 +33,7 @@ def foo(choice: bool): @external def foo(choice: bool): - for i in range(4): + for i: int128 in range(4): a: int128 = 0 a = 1 """, @@ -41,7 +41,7 @@ def foo(choice: bool): @external def foo(choice: bool): - for i in range(4): + for i: int128 in range(4): a: int128 = 0 a += 1 """, diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index ffd2f1faa0..7089dee3bb 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -240,7 +240,7 @@ def test1(): @external @view def test(): - for i in range(CONST / 4): + for i: uint256 in range(CONST / 4): pass """, """ diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index a9c3ad5cab..2ba562ac1f 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -15,7 +15,7 @@ """ @external def foo(): - for a[1] in range(10): + for a[1]: uint256 in range(10): pass """, StructureException, @@ -26,7 +26,7 @@ def foo(): """ @external def bar(): - for i in range(1,2,bound=0): + for i: uint256 in range(1,2,bound=0): pass """, StructureException, @@ -38,7 +38,7 @@ def bar(): @external def foo(): x: uint256 = 100 - for _ in range(10, bound=x): + for _: uint256 in range(10, bound=x): pass """, StateAccessViolation, @@ -49,7 +49,7 @@ def foo(): """ @external def foo(): - for _ in range(10, 20, bound=5): + for _: uint256 in range(10, 20, bound=5): pass """, StructureException, @@ -60,7 +60,7 @@ def foo(): """ @external def foo(): - for _ in range(10, 20, bound=0): + for _: uint256 in range(10, 20, bound=0): pass """, StructureException, @@ -72,7 +72,7 @@ def foo(): @external def bar(): x:uint256 = 1 - for i in range(x,x+1,bound=2,extra=3): + for i: uint256 in range(x,x+1,bound=2,extra=3): pass """, ArgumentException, @@ -83,7 +83,7 @@ def bar(): """ @external def bar(): - for i in range(0): + for i: uint256 in range(0): pass """, StructureException, @@ -95,7 +95,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(x): + for i: uint256 in range(x): pass """, StateAccessViolation, @@ -107,7 +107,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(0, x): + for i: uint256 in range(0, x): pass """, StateAccessViolation, @@ -118,7 +118,7 @@ def bar(): """ @external def repeat(n: uint256) -> uint256: - for i in range(0, n * 10): + for i: uint256 in range(0, n * 10): pass return n """, @@ -131,7 +131,7 @@ def repeat(n: uint256) -> uint256: @external def bar(): x:uint256 = 1 - for i in range(0, x + 1): + for i: uint256 in range(0, x + 1): pass """, StateAccessViolation, @@ -142,7 +142,7 @@ def bar(): """ @external def bar(): - for i in range(2, 1): + for i: uint256 in range(2, 1): pass """, StructureException, @@ -154,7 +154,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(x, x): + for i: uint256 in range(x, x): pass """, StateAccessViolation, @@ -166,7 +166,7 @@ def bar(): @external def foo(): x: int128 = 5 - for i in range(x, x + 10): + for i: int128 in range(x, x + 10): pass """, StateAccessViolation, @@ -177,7 +177,7 @@ def foo(): """ @external def repeat(n: uint256) -> uint256: - for i in range(n, 6): + for i: uint256 in range(n, 6): pass return x """, @@ -190,7 +190,7 @@ def repeat(n: uint256) -> uint256: @external def foo(x: int128): y: int128 = 7 - for i in range(x, x + y): + for i: int128 in range(x, x + y): pass """, StateAccessViolation, @@ -201,7 +201,7 @@ def foo(x: int128): """ @external def bar(x: uint256): - for i in range(3, x): + for i: uint256 in range(3, x): pass """, StateAccessViolation, @@ -215,7 +215,7 @@ def bar(x: uint256): @external def foo(): - for i in range(FOO, BAR): + for i: uint256 in range(FOO, BAR): pass """, TypeMismatch, @@ -228,7 +228,7 @@ def foo(): @external def foo(): - for i in range(10, bound=FOO): + for i: int128 in range(10, bound=FOO): pass """, StructureException, @@ -259,34 +259,34 @@ def test_range_fail(bad_code, error_type, message, source_code): """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass """, """ @external def foo(): - for i in range(10, 20): + for i: uint256 in range(10, 20): pass """, """ @external def foo(): x: int128 = 5 - for i in range(1, x, bound=4): + for i: int128 in range(1, x, bound=4): pass """, """ @external def foo(): x: int128 = 5 - for i in range(x, bound=4): + for i: int128 in range(x, bound=4): pass """, """ @external def foo(): x: int128 = 5 - for i in range(0, x, bound=4): + for i: int128 in range(0, x, bound=4): pass """, """ @@ -295,7 +295,7 @@ def kick(): nonpayable foos: Foo[3] @external def kick_foos(): - for foo in self.foos: + for foo: Foo in self.foos: foo.kick() """, ] diff --git a/tests/unit/ast/test_pre_parser.py b/tests/unit/ast/test_pre_parser.py index 682c13ca84..020e83627c 100644 --- a/tests/unit/ast/test_pre_parser.py +++ b/tests/unit/ast/test_pre_parser.py @@ -173,7 +173,7 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): @pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples) def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version): mock_version("0.3.10") - settings, _, _ = pre_parse(code) + settings, _, _, _ = pre_parse(code) assert settings == pre_parse_settings From 42e06f5187ac8f3092a8b5c5895192e0e2851c9e Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 17:18:41 +0800 Subject: [PATCH 12/54] fix lint --- vyper/ast/parse.py | 16 +++++++++++++--- vyper/ast/pre_parser.py | 21 ++++++++++----------- vyper/compiler/phases.py | 2 +- vyper/semantics/analysis/local.py | 8 ++++---- 4 files changed, 28 insertions(+), 19 deletions(-) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 38be672dac..243d8b648f 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -400,9 +400,19 @@ def annotate_python_ast( resolved_path=resolved_path, ) visitor.visit(parsed_ast) - for k, v in loop_var_annotations.items(): - tokens = asttokens.ASTTokens(v["source_code"], tree=cast(Optional[python_ast.Module], v["parsed_ast"])) - visitor = AnnotatingVisitor(v["source_code"], {}, tokens, source_id, module_path=module_path, resolved_path=resolved_path) + + for _, v in loop_var_annotations.items(): + tokens = asttokens.ASTTokens( + v["source_code"], tree=cast(Optional[python_ast.Module], v["parsed_ast"]) + ) + visitor = AnnotatingVisitor( + v["source_code"], + {}, + tokens, + source_id, + module_path=module_path, + resolved_path=resolved_path, + ) visitor.visit(v["parsed_ast"]) return parsed_ast diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index da01dce080..507b4230e6 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -154,9 +154,9 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: if typ == NAME and string == "for": is_for_loop = True - #print("for loop!") - #print(token) - + # print("for loop!") + # print(token) + if is_for_loop: if typ == NAME and string == "in": loop_var_annotations[start[0]] = loop_var_annotation @@ -170,7 +170,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: continue elif after_loop_var and not (typ == NAME and string == "for"): - #print("adding to loop var: ", toks) + # print("adding to loop var: ", toks) loop_var_annotation.extend(toks) continue @@ -179,16 +179,15 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e for k, v in loop_var_annotations.items(): - - updated_v = untokenize(v) - #print("untokenized v: ", updated_v) + # print("untokenized v: ", updated_v) updated_v = updated_v.replace("\\", "") updated_v = updated_v.replace("\n", "") import textwrap - #print("updated v: ", textwrap.dedent(updated_v)) + + # print("updated v: ", textwrap.dedent(updated_v)) loop_var_annotations[k] = {"source_code": textwrap.dedent(updated_v)} - - #print("untokenized result: ", type(untokenize(result))) - #print("untokenized result decoded: ", untokenize(result).decode("utf-8")) + + # print("untokenized result: ", type(untokenize(result))) + # print("untokenized result decoded: ", untokenize(result).decode("utf-8")) return settings, modification_offsets, loop_var_annotations, untokenize(result).decode("utf-8") diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 4e6cc9df86..850adcfea3 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -2,7 +2,7 @@ import warnings from functools import cached_property from pathlib import Path, PurePath -from typing import Any, Optional +from typing import Optional from vyper import ast as vy_ast from vyper.codegen import module diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 7231ee6d5a..17351cbaf1 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Optional from vyper import ast as vy_ast from vyper.ast.metadata import NodeMetadata @@ -354,10 +354,10 @@ def visit_For(self, node): iter_annotation = loop_var_annotations.get(node.lineno).get("vy_ast") if not iter_annotation: raise StructureException("Iterator needs type annotation", node.iter) - + iter_annotation_node = iter_annotation.body[0].value iter_type = type_from_annotation(iter_annotation_node, DataLocation.MEMORY) - node.target._metadata["type"] = iter_type + node.target._metadata["type"] = iter_type if isinstance(node.iter, vy_ast.Call): # iteration via range() @@ -480,7 +480,7 @@ def visit_For(self, node): for typ, exc in zip(type_list, for_loop_exceptions) ), ) - + def visit_If(self, node): validate_expected_type(node.test, BoolT()) self.expr_visitor.visit(node.test, BoolT()) From e7a46127c32c993cc55b837492a75563e4259d1e Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 6 Jan 2024 08:40:41 -0500 Subject: [PATCH 13/54] revert a change --- vyper/semantics/analysis/module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 92b9186412..8e435f870f 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -485,13 +485,13 @@ def _load_import_helper( def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: - ast = vy_ast.parse_to_ast( + ret = vy_ast.parse_to_ast( file.source_code, source_id=file.source_id, module_path=str(file.path), resolved_path=str(file.resolved_path), ) - return ast + return ret # convert an import to a path (without suffix) From 6f6aceab95cee6c4928888d12d71e9e178a9e4e0 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 23:51:35 +0800 Subject: [PATCH 14/54] add visit_For in py ast parse --- vyper/ast/nodes.py | 2 +- vyper/ast/parse.py | 50 ++++++++++++++++++------------- vyper/semantics/analysis/local.py | 8 +---- 3 files changed, 31 insertions(+), 29 deletions(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index efab5117d4..9ad556b470 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -1546,7 +1546,7 @@ class IfExp(ExprNode): class For(Stmt): - __slots__ = ("iter", "target", "body") + __slots__ = ("iter", "iter_type", "target", "body") _only_empty_fields = ("orelse",) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 243d8b648f..05c8e0bb72 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -87,13 +87,6 @@ def parse_to_ast_with_settings( module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint - for k, v in loop_var_annotations.items(): - loop_var_vy_ast = vy_ast.get_node(v["parsed_ast"]) - loop_var_annotations[k]["vy_ast"] = loop_var_vy_ast - del loop_var_annotations[k]["parsed_ast"] - - module._metadata["loop_var_annotations"] = loop_var_annotations - return settings, module @@ -125,12 +118,14 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _modification_offsets: ModificationOffsets + _loop_var_annotations: dict[int, python_ast.AST] def __init__( self, source_code: str, modification_offsets: Optional[ModificationOffsets], tokens: asttokens.ASTTokens, + loop_var_annotations: dict[int, python_ast.AST], source_id: int, module_path: Optional[str] = None, resolved_path: Optional[str] = None, @@ -142,6 +137,7 @@ def __init__( self._source_code: str = source_code self.counter: int = 0 self._modification_offsets = {} + self._loop_var_annotations = loop_var_annotations if modification_offsets is not None: self._modification_offsets = modification_offsets @@ -244,6 +240,28 @@ def visit_Expr(self, node): return node + def visit_For(self, node): + """ + Annotate `For` nodes with the iterator's type annotation that was extracted + during pre-parsing. + """ + iter_type_info = self._loop_var_annotations.get(node.lineno) + if not iter_type_info: + raise SyntaxException( + "For loop iterator requires type annotation", + self._source_code, + node.iter.lineno, + node.iter.col_offset, + ) + + iter_type_ast = iter_type_info["parsed_ast"] + + self.generic_visit(node) + self.generic_visit(iter_type_ast) + node.iter_type = iter_type_ast.body[0].value + + return node + def visit_Subscript(self, node): """ Maintain consistency of `Subscript.slice` across python versions. @@ -384,6 +402,9 @@ def annotate_python_ast( The originating source code of the AST. modification_offsets : dict, optional A mapping of class names to their original class types. + loop_var_annotations: dict, optional + A mapping of line numbers of `For` nodes to the type annotation of the iterator + extracted during pre-parsing. Returns ------- @@ -395,24 +416,11 @@ def annotate_python_ast( source_code, modification_offsets, tokens, + loop_var_annotations, source_id, module_path=module_path, resolved_path=resolved_path, ) visitor.visit(parsed_ast) - for _, v in loop_var_annotations.items(): - tokens = asttokens.ASTTokens( - v["source_code"], tree=cast(Optional[python_ast.Module], v["parsed_ast"]) - ) - visitor = AnnotatingVisitor( - v["source_code"], - {}, - tokens, - source_id, - module_path=module_path, - resolved_path=resolved_path, - ) - visitor.visit(v["parsed_ast"]) - return parsed_ast diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 17351cbaf1..24fafc8ba7 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -350,13 +350,7 @@ def visit_For(self, node): if isinstance(node.iter, vy_ast.Subscript): raise StructureException("Cannot iterate over a nested list", node.iter) - loop_var_annotations = self.vyper_module._metadata.get("loop_var_annotations") - iter_annotation = loop_var_annotations.get(node.lineno).get("vy_ast") - if not iter_annotation: - raise StructureException("Iterator needs type annotation", node.iter) - - iter_annotation_node = iter_annotation.body[0].value - iter_type = type_from_annotation(iter_annotation_node, DataLocation.MEMORY) + iter_type = type_from_annotation(node.iter_type, DataLocation.MEMORY) node.target._metadata["type"] = iter_type if isinstance(node.iter, vy_ast.Call): From 578d47109e9c15d03e4135f5ff0c2a4ac8a72323 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 23:52:54 +0800 Subject: [PATCH 15/54] remove typechecker speculation --- vyper/semantics/analysis/local.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 24fafc8ba7..0e3fc2572f 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -429,9 +429,8 @@ def visit_For(self, node): ) try: - with NodeMetadata.enter_typechecker_speculation(): - for stmt in node.body: - self.visit(stmt) + for stmt in node.body: + self.visit(stmt) self.expr_visitor.visit(node.target, iter_type) From 07fcde2b99af87dfc17e25ca692ad521c513916a Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 6 Jan 2024 23:52:59 +0800 Subject: [PATCH 16/54] remove prints --- vyper/ast/pre_parser.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 507b4230e6..3079996c98 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -154,8 +154,6 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: if typ == NAME and string == "for": is_for_loop = True - # print("for loop!") - # print(token) if is_for_loop: if typ == NAME and string == "in": @@ -170,7 +168,6 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: continue elif after_loop_var and not (typ == NAME and string == "for"): - # print("adding to loop var: ", toks) loop_var_annotation.extend(toks) continue @@ -180,14 +177,10 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: for k, v in loop_var_annotations.items(): updated_v = untokenize(v) - # print("untokenized v: ", updated_v) updated_v = updated_v.replace("\\", "") updated_v = updated_v.replace("\n", "") import textwrap - # print("updated v: ", textwrap.dedent(updated_v)) loop_var_annotations[k] = {"source_code": textwrap.dedent(updated_v)} - # print("untokenized result: ", type(untokenize(result))) - # print("untokenized result decoded: ", untokenize(result).decode("utf-8")) return settings, modification_offsets, loop_var_annotations, untokenize(result).decode("utf-8") From 76c3d2d6011bcce620d2b9143192c2287ecc6b16 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 00:27:41 +0800 Subject: [PATCH 17/54] fix sqrt --- .../unit/semantics/analysis/test_for_loop.py | 55 +++---------------- vyper/builtins/functions.py | 2 +- 2 files changed, 8 insertions(+), 49 deletions(-) diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index e2c0f555af..e282f26b0f 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -22,7 +22,7 @@ def foo(): @internal def bar(): self.foo() - for i in self.a: + for i: uint256 in self.a: pass """ vyper_module = parse_to_ast(code) @@ -42,7 +42,7 @@ def foo(a: uint256[3]) -> uint256[3]: @internal def bar(): a: uint256[3] = [1,2,3] - for i in a: + for i: uint256 in a: self.foo(a) """ vyper_module = parse_to_ast(code) @@ -56,7 +56,7 @@ def test_modify_iterator(dummy_input_bundle): @internal def bar(): - for i in self.a: + for i: uint256 in self.a: self.a[0] = 1 """ vyper_module = parse_to_ast(code) @@ -70,7 +70,7 @@ def test_bad_keywords(dummy_input_bundle): @internal def bar(n: uint256): x: uint256 = 0 - for i in range(n, boundddd=10): + for i: uint256 in range(n, boundddd=10): x += i """ vyper_module = parse_to_ast(code) @@ -84,7 +84,7 @@ def test_bad_bound(dummy_input_bundle): @internal def bar(n: uint256): x: uint256 = 0 - for i in range(n, bound=n): + for i: uint256 in range(n, bound=n): x += i """ vyper_module = parse_to_ast(code) @@ -103,7 +103,7 @@ def foo(): @internal def bar(): - for i in self.a: + for i: uint256 in self.a: self.foo() """ vyper_module = parse_to_ast(code) @@ -126,50 +126,9 @@ def bar(): @internal def baz(): - for i in self.a: + for i: uint256 in self.a: self.bar() """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): validate_semantics(vyper_module, dummy_input_bundle) - - -iterator_inference_codes = [ - """ -@external -def main(): - for j in range(3): - x: uint256 = j - y: uint16 = j - """, # GH issue 3212 - """ -@external -def foo(): - for i in [1]: - a:uint256 = i - b:uint16 = i - """, # GH issue 3374 - """ -@external -def foo(): - for i in [1]: - for j in [1]: - a:uint256 = i - b:uint16 = i - """, # GH issue 3374 - """ -@external -def foo(): - for i in [1,2,3]: - for j in [1,2,3]: - b:uint256 = j + i - c:uint16 = i - """, # GH issue 3374 -] - - -@pytest.mark.parametrize("code", iterator_inference_codes) -def test_iterator_type_inference_checker(code, dummy_input_bundle): - vyper_module = parse_to_ast(code) - with pytest.raises(TypeMismatch): - validate_semantics(vyper_module, dummy_input_bundle) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index c896fc7ef6..39d97c4abe 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2157,7 +2157,7 @@ def build_IR(self, expr, args, kwargs, context): z = x / 2.0 + 0.5 y: decimal = x - for i in range(256): + for i: uint256 in range(256): if z == y: break y = z From 9db7a369964243e42b58fb18236a8c256c2a9458 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 00:28:41 +0800 Subject: [PATCH 18/54] fix visit_Num --- vyper/ast/parse.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 05c8e0bb72..9184a8773a 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -255,9 +255,9 @@ def visit_For(self, node): ) iter_type_ast = iter_type_info["parsed_ast"] - - self.generic_visit(node) self.generic_visit(iter_type_ast) + self.generic_visit(node) + node.iter_type = iter_type_ast.body[0].value return node @@ -321,7 +321,11 @@ def visit_Num(self, node): """ # modify vyper AST type according to the format of the literal value self.generic_visit(node) - value = node.node_source_code + + # the type annotation of a for loop iterator is removed from the source + # code during pre-parsing, and therefore the `node_source_code` attribute + # of an integer in the type annotation would not be available e.g. DynArray[uint256, 3] + value = node.node_source_code if hasattr(node, "node_source_code") else node.n # deduce non base-10 types based on prefix if value.lower()[:2] == "0x": From e055ee911df747d25562f228b66dfcb0474c45f0 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 00:28:48 +0800 Subject: [PATCH 19/54] update comments --- vyper/ast/pre_parser.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 3079996c98..edd0aa149c 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -51,7 +51,7 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: VYPER_EXPRESSION_TYPES = {"log"} -def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: +def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict[int, dict[str, str]], str]: """ Re-formats a vyper source string into a python source string and performs some validation. More specifically, @@ -60,9 +60,11 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: * Validates "@version" pragma against current compiler version * Prevents direct use of python "class" keyword * Prevents use of python semi-colon statement separator + * Extracts type annotation of for loop iterators into a separate dictionary Also returns a mapping of detected interface and struct names to their - respective vyper class types ("interface" or "struct"). + respective vyper class types ("interface" or "struct"), and a mapping of line numbers + of for loops to the type annotation of their iterators. Parameters ---------- @@ -71,8 +73,12 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: Returns ------- - dict - Mapping of offsets where source was modified. + Settings + Compilation settings based on the directives in the source code + ModificationOffsets + A mapping of class names to their original class types. + dict[int, dict[str, Any]] + A mapping of line numbers of `For` nodes to the type annotation of the iterator str Reformatted python source string. """ From c7acdaae90e01e821fb959f5c79d4af19b55f883 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 00:28:55 +0800 Subject: [PATCH 20/54] fix tests --- tests/functional/codegen/types/test_bytes.py | 2 +- .../codegen/types/test_dynamic_array.py | 24 +++++++++---------- .../exceptions/test_argument_exception.py | 4 ++-- .../exceptions/test_constancy_exception.py | 6 ++--- tests/functional/syntax/test_list.py | 2 +- tests/unit/ast/nodes/test_hex.py | 2 +- .../ast/test_annotate_and_optimize_ast.py | 4 ++-- 7 files changed, 22 insertions(+), 22 deletions(-) diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 1ee9b8d835..882629de65 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -268,7 +268,7 @@ def test_zero_padding_with_private(get_contract): def to_little_endian_64(_value: uint256) -> Bytes[8]: y: uint256 = 0 x: uint256 = _value - for _ in range(8): + for _: uint256 in range(8): y = (y << 8) | (x & 255) x >>= 8 return slice(convert(y, bytes32), 24, 8) diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 171c1c5394..e47eda6042 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -969,7 +969,7 @@ def foo() -> (uint256, uint256, uint256, uint256, uint256): my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array """, @@ -981,7 +981,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: some_var: uint256 @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.some_var = x # test that typechecker for append args works self.my_array.append(self.some_var) @@ -994,9 +994,9 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) - for x in xs: + for x: uint256 in xs: self.my_array.pop() return self.my_array """, @@ -1008,7 +1008,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> (DynArray[uint256, 5], uint256): - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array, self.my_array.pop() """, @@ -1020,7 +1020,7 @@ def foo(xs: DynArray[uint256, 5]) -> (DynArray[uint256, 5], uint256): my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> (uint256, DynArray[uint256, 5]): - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array.pop(), self.my_array """, @@ -1033,7 +1033,7 @@ def foo(xs: DynArray[uint256, 5]) -> (uint256, DynArray[uint256, 5]): def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] i: uint256 = 0 - for x in xs: + for x: uint256 in xs: if i >= len(xs) - 1: break ys.append(x) @@ -1049,7 +1049,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 6]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array """, @@ -1061,9 +1061,9 @@ def foo(xs: DynArray[uint256, 6]) -> DynArray[uint256, 5]: @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] - for x in xs: + for x: uint256 in xs: ys.append(x) - for x in xs: + for x: uint256 in xs: ys.pop() return ys """, @@ -1075,9 +1075,9 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] - for x in xs: + for x: uint256 in xs: ys.append(x) - for x in xs: + for x: uint256 in xs: ys.pop() ys.pop() # fail return ys diff --git a/tests/functional/syntax/exceptions/test_argument_exception.py b/tests/functional/syntax/exceptions/test_argument_exception.py index 0b7ec21bdb..4240aec8d2 100644 --- a/tests/functional/syntax/exceptions/test_argument_exception.py +++ b/tests/functional/syntax/exceptions/test_argument_exception.py @@ -80,13 +80,13 @@ def foo(): """ @external def foo(): - for i in range(): + for i: uint256 in range(): pass """, """ @external def foo(): - for i in range(1, 2, 3, 4): + for i: uint256 in range(1, 2, 3, 4): pass """, ] diff --git a/tests/functional/syntax/exceptions/test_constancy_exception.py b/tests/functional/syntax/exceptions/test_constancy_exception.py index 4bd0b4fcb9..7adf9538c7 100644 --- a/tests/functional/syntax/exceptions/test_constancy_exception.py +++ b/tests/functional/syntax/exceptions/test_constancy_exception.py @@ -57,7 +57,7 @@ def foo() -> int128: return 5 @external def bar(): - for i in range(self.foo(), self.foo() + 1): + for i: int128 in range(self.foo(), self.foo() + 1): pass""", """ glob: int128 @@ -67,13 +67,13 @@ def foo() -> int128: return 5 @external def bar(): - for i in [1,2,3,4,self.foo()]: + for i: int128 in [1,2,3,4,self.foo()]: pass""", """ @external def foo(): x: int128 = 5 - for i in range(x): + for i: int128 in range(x): pass""", """ f:int128 diff --git a/tests/functional/syntax/test_list.py b/tests/functional/syntax/test_list.py index db41de5526..3936f8c220 100644 --- a/tests/functional/syntax/test_list.py +++ b/tests/functional/syntax/test_list.py @@ -306,7 +306,7 @@ def foo(): @external def foo(): x: DynArray[uint256, 3] = [1, 2, 3] - for i in [[], []]: + for i: DynArray[uint256, 3] in [[], []]: x = i """, ] diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index d413340083..a6bc3147e6 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -24,7 +24,7 @@ def foo(): """ @external def foo(): - for i in [0x6b175474e89094c44da98b954eedeac495271d0F]: + for i: address in [0x6b175474e89094c44da98b954eedeac495271d0F]: pass """, """ diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 16ce6fe631..4147ee77b7 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -28,10 +28,10 @@ def foo() -> int128: def get_contract_info(source_code): - _, class_types, reformatted_code = pre_parse(source_code) + _, class_types, loop_var_annotations, reformatted_code = pre_parse(source_code) py_ast = python_ast.parse(reformatted_code) - annotate_python_ast(py_ast, reformatted_code, class_types) + annotate_python_ast(py_ast, reformatted_code, class_types, loop_var_annotations) return py_ast, reformatted_code From 0dd86fdb8b43ac267cfdfbf4671fbd08601355ce Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 00:29:50 +0800 Subject: [PATCH 21/54] fix lint --- tests/unit/semantics/analysis/test_for_loop.py | 7 +------ vyper/semantics/analysis/local.py | 1 - 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index e282f26b0f..ccd501e101 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -1,12 +1,7 @@ import pytest from vyper.ast import parse_to_ast -from vyper.exceptions import ( - ArgumentException, - ImmutableViolation, - StateAccessViolation, - TypeMismatch, -) +from vyper.exceptions import ArgumentException, ImmutableViolation, StateAccessViolation from vyper.semantics.analysis import validate_semantics diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 0e3fc2572f..0f6af0c01e 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,7 +1,6 @@ from typing import Optional from vyper import ast as vy_ast -from vyper.ast.metadata import NodeMetadata from vyper.ast.validation import validate_call_args from vyper.exceptions import ( ExceptionList, From f64e6f1c119cbb007d0634bc269d6a4c7f199af4 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 00:45:02 +0800 Subject: [PATCH 22/54] fix mypy --- .../ast/test_annotate_and_optimize_ast.py | 4 ++-- vyper/ast/parse.py | 24 +++++++++---------- vyper/ast/pre_parser.py | 14 ++++++----- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 4147ee77b7..b202f6d8a3 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -28,10 +28,10 @@ def foo() -> int128: def get_contract_info(source_code): - _, class_types, loop_var_annotations, reformatted_code = pre_parse(source_code) + _, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code) py_ast = python_ast.parse(reformatted_code) - annotate_python_ast(py_ast, reformatted_code, class_types, loop_var_annotations) + annotate_python_ast(py_ast, reformatted_code, loop_var_annotations, class_types) return py_ast, reformatted_code diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 9184a8773a..3bbc24c073 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -23,7 +23,7 @@ def parse_to_ast_with_settings( module_path: Optional[str] = None, resolved_path: Optional[str] = None, add_fn_node: Optional[str] = None, -) -> tuple[Settings, vy_ast.Module, dict[int, dict[str, Any]]]: +) -> tuple[Settings, vy_ast.Module]: """ Parses a Vyper source string and generates basic Vyper AST nodes. @@ -54,7 +54,7 @@ def parse_to_ast_with_settings( """ if "\x00" in source_code: raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, class_types, loop_var_annotations, reformatted_code = pre_parse(source_code) + settings, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code) try: py_ast = python_ast.parse(reformatted_code) @@ -76,8 +76,8 @@ def parse_to_ast_with_settings( annotate_python_ast( py_ast, source_code, - class_types, loop_var_annotations, + class_types, source_id, module_path=module_path, resolved_path=resolved_path, @@ -118,14 +118,14 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _modification_offsets: ModificationOffsets - _loop_var_annotations: dict[int, python_ast.AST] + _loop_var_annotations: dict[int, dict[str, Any]] def __init__( self, source_code: str, + loop_var_annotations: dict[int, dict[str, Any]], modification_offsets: Optional[ModificationOffsets], tokens: asttokens.ASTTokens, - loop_var_annotations: dict[int, python_ast.AST], source_id: int, module_path: Optional[str] = None, resolved_path: Optional[str] = None, @@ -325,10 +325,10 @@ def visit_Num(self, node): # the type annotation of a for loop iterator is removed from the source # code during pre-parsing, and therefore the `node_source_code` attribute # of an integer in the type annotation would not be available e.g. DynArray[uint256, 3] - value = node.node_source_code if hasattr(node, "node_source_code") else node.n + value = node.node_source_code if hasattr(node, "node_source_code") else None # deduce non base-10 types based on prefix - if value.lower()[:2] == "0x": + if value and value.lower()[:2] == "0x": if len(value) % 2: raise SyntaxException( "Hex notation requires an even number of digits", @@ -339,7 +339,7 @@ def visit_Num(self, node): node.ast_type = "Hex" node.n = value - elif value.lower()[:2] == "0b": + elif value and value.lower()[:2] == "0b": node.ast_type = "Bytes" mod = (len(value) - 2) % 8 if mod: @@ -389,8 +389,8 @@ def visit_UnaryOp(self, node): def annotate_python_ast( parsed_ast: python_ast.AST, source_code: str, + loop_var_annotations: dict[int, dict[str, Any]], modification_offsets: Optional[ModificationOffsets] = None, - loop_var_annotations: Optional[dict[int, python_ast.AST]] = None, source_id: int = 0, module_path: Optional[str] = None, resolved_path: Optional[str] = None, @@ -404,11 +404,11 @@ def annotate_python_ast( The AST to be annotated and optimized. source_code : str The originating source code of the AST. - modification_offsets : dict, optional - A mapping of class names to their original class types. loop_var_annotations: dict, optional A mapping of line numbers of `For` nodes to the type annotation of the iterator extracted during pre-parsing. + modification_offsets : dict, optional + A mapping of class names to their original class types. Returns ------- @@ -418,9 +418,9 @@ def annotate_python_ast( tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) visitor = AnnotatingVisitor( source_code, + loop_var_annotations, modification_offsets, tokens, - loop_var_annotations, source_id, module_path=module_path, resolved_path=resolved_path, diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index edd0aa149c..efe7caa135 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -1,6 +1,7 @@ import io import re from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize +from typing import Any from packaging.specifiers import InvalidSpecifier, SpecifierSet @@ -51,7 +52,7 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: VYPER_EXPRESSION_TYPES = {"log"} -def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict[int, dict[str, str]], str]: +def pre_parse(code: str) -> tuple[Settings, dict[int, dict[str, Any]], ModificationOffsets, str]: """ Re-formats a vyper source string into a python source string and performs some validation. More specifically, @@ -85,7 +86,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict[int, dict[ result = [] modification_offsets: ModificationOffsets = {} settings = Settings() - loop_var_annotations = {} + loop_var_annotation_tokens = {} try: code_bytes = code.encode("utf-8") @@ -93,7 +94,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict[int, dict[ is_for_loop = False after_loop_var = False - loop_var_annotation = [] + loop_var_annotation: list = [] for i in range(len(token_list)): token = token_list[i] @@ -163,7 +164,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict[int, dict[ if is_for_loop: if typ == NAME and string == "in": - loop_var_annotations[start[0]] = loop_var_annotation + loop_var_annotation_tokens[start[0]] = loop_var_annotation is_for_loop = False after_loop_var = False @@ -181,7 +182,8 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict[int, dict[ except TokenError as e: raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e - for k, v in loop_var_annotations.items(): + loop_var_annotations: dict[int, dict[str, Any]] = {} + for k, v in loop_var_annotation_tokens.items(): updated_v = untokenize(v) updated_v = updated_v.replace("\\", "") updated_v = updated_v.replace("\n", "") @@ -189,4 +191,4 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict[int, dict[ loop_var_annotations[k] = {"source_code": textwrap.dedent(updated_v)} - return settings, modification_offsets, loop_var_annotations, untokenize(result).decode("utf-8") + return settings, loop_var_annotations, modification_offsets, untokenize(result).decode("utf-8") From 5bc54c0c834c3a63dcdb86911aad99527dbf1954 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 01:06:38 +0800 Subject: [PATCH 23/54] simpliy visit_For --- .../features/iteration/test_for_in_list.py | 4 +- vyper/semantics/analysis/local.py | 92 ++++++------------- 2 files changed, 28 insertions(+), 68 deletions(-) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 33ad59370e..22fd7ccb43 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -648,7 +648,7 @@ def foo(x: int128): """ @external def foo(): - for i: uint256 in range(-3): + for i: int128 in range(-3): pass """, StructureException, @@ -776,7 +776,7 @@ def test_for() -> int128: a = i return a """, - TypeMismatch, + InvalidType, ), ( """ diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 0f6af0c01e..76b139b055 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -6,7 +6,6 @@ ExceptionList, FunctionDeclarationException, ImmutableViolation, - InvalidOperation, InvalidType, IteratorException, NonPayableViolation, @@ -39,7 +38,6 @@ EventT, FlagT, HashMapT, - IntegerT, SArrayT, StringT, StructT, @@ -350,7 +348,6 @@ def visit_For(self, node): raise StructureException("Cannot iterate over a nested list", node.iter) iter_type = type_from_annotation(node.iter_type, DataLocation.MEMORY) - node.target._metadata["type"] = iter_type if isinstance(node.iter, vy_ast.Call): # iteration via range() @@ -358,7 +355,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - type_list = _analyse_range_call(node.iter) + _analyse_range_call(node.iter, iter_type) else: # iteration over a variable or literal list @@ -366,14 +363,10 @@ def visit_For(self, node): if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) - type_list = [ - i.value_type - for i in get_possible_types_from_node(node.iter) - if isinstance(i, (DArrayT, SArrayT)) - ] - - if not type_list: - raise InvalidType("Not an iterable type", node.iter) + if not any( + isinstance(i, (DArrayT, SArrayT)) for i in get_possible_types_from_node(node.iter) + ): + raise InvalidType("Not an iterable type", node.iter) if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): # check for references to the iterated value within the body of the loop @@ -420,58 +413,31 @@ def visit_For(self, node): if not isinstance(node.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target) - for_loop_exceptions = [] iter_name = node.target.id with self.namespace.enter_scope(): self.namespace[iter_name] = VarInfo( iter_type, modifiability=Modifiability.RUNTIME_CONSTANT ) - try: - for stmt in node.body: - self.visit(stmt) - - self.expr_visitor.visit(node.target, iter_type) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - iter_type = get_exact_type_from_node(node.iter) - # note CMC 2023-10-23: slightly redundant with how type_list is computed - validate_expected_type(node.target, iter_type.value_type) - self.expr_visitor.visit(node.iter, iter_type) - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, iter_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, iter_type) - - except (TypeMismatch, InvalidOperation) as exc: - for_loop_exceptions.append(exc) - else: - # success -- do not enter error handling section - return - - # failed to find a good type. bail out - if len(set(str(i) for i in for_loop_exceptions)) == 1: - # if every attempt at type checking raised the same exception - raise for_loop_exceptions[0] - - # return an aggregate TypeMismatch that shows all possible exceptions - # depending on which type is used - types_str = [str(i) for i in type_list] - given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}" - raise TypeMismatch( - f"Iterator value '{iter_name}' may be cast as {given_str}, " - "but type checking fails with all possible types:", - node, - *( - (f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0]) - for typ, exc in zip(type_list, for_loop_exceptions) - ), - ) + for stmt in node.body: + self.visit(stmt) + + self.expr_visitor.visit(node.target, iter_type) + + if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): + iter_type = get_exact_type_from_node(node.iter) + # note CMC 2023-10-23: slightly redundant with how type_list is computed + validate_expected_type(node.target, iter_type.value_type) + self.expr_visitor.visit(node.iter, iter_type) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) + if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + for a in node.iter.args: + self.expr_visitor.visit(a, iter_type) + for a in node.iter.keywords: + if a.arg == "bound": + self.expr_visitor.visit(a.value, iter_type) def visit_If(self, node): validate_expected_type(node.test, BoolT()) @@ -748,7 +714,7 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.orelse, typ) -def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: +def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType) -> list[VyperType]: """ Check that the arguments to a range() call are valid. :param node: call to range() @@ -761,11 +727,7 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: all_args = (start, end, *kwargs.values()) for arg1 in all_args: - validate_expected_type(arg1, IntegerT.any()) - - type_list = get_common_types(*all_args) - if not type_list: - raise TypeMismatch("Iterator values are of different types", node) + validate_expected_type(arg1, iter_type) if "bound" in kwargs: bound = kwargs["bound"] @@ -785,5 +747,3 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: raise StateAccessViolation(error, arg) if end.value <= start.value: raise StructureException("End must be greater than start", end) - - return type_list From b951b47833ebfe8ba72c9f2d1366ff92f55f653b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 6 Jan 2024 10:07:57 -0500 Subject: [PATCH 24/54] rewrite for loop slurper with a small state machine --- vyper/ast/pre_parser.py | 97 +++++++++++++++++++++++++++-------------- 1 file changed, 64 insertions(+), 33 deletions(-) diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index efe7caa135..d4115ab2b3 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -1,4 +1,5 @@ import io +import enum import re from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize from typing import Any @@ -43,6 +44,54 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: start, ) +class ForParserState(enum.Enum): + NOT_RUNNING = enum.auto() + START_SOON = enum.auto() + RUNNING = enum.auto() + +# a simple state machine which allows us to handle loop variable annotations +# (which are rejected by the python parser due to pep-526, so we scoop up the +# tokens between `:` and `in` and parse them and add them back in later). +class ForParser: + def __init__(self): + self.annotations = {} + self._current_annotation = None + + self._state = ForParserState.NOT_RUNNING + self._current_for_loop = None + + def consume(self, token): + # state machine: we can start slurping tokens soon + if token.type == NAME and token.string == "for": + # note: self._is_running should be false here, but we don't sanity + # check here as that should be an error the parser will handle. + self._state = ForParserState.START_SOON + self._current_for_loop = token.start + + if self._state == ForParserState.NOT_RUNNING: + return False + + # state machine: start slurping tokens + if token.type == OP and token.string == ":": + self._state = ForParserState.RUNNING + assert self._current_annotation is None, (self._current_for_loop, self._current_annotation) + self._current_annotation = [] + return False + + if self._state != ForParserState.RUNNING: + return False + + # state machine: end slurping tokens + if token.type == NAME and token.string == "in": + self._state = ForParserState.NOT_RUNNING + self.annotations[self._current_for_loop] = self._current_annotation + self._current_annotation = None + return False + + # slurp the token + self._current_annotation.append(token) + return True + # compound statements that are replaced with `class` # TODO remove enum in favor of flag @@ -86,18 +135,13 @@ def pre_parse(code: str) -> tuple[Settings, dict[int, dict[str, Any]], Modificat result = [] modification_offsets: ModificationOffsets = {} settings = Settings() - loop_var_annotation_tokens = {} + for_parser = ForParser() try: code_bytes = code.encode("utf-8") token_list = list(tokenize(io.BytesIO(code_bytes).readline)) - is_for_loop = False - after_loop_var = False - loop_var_annotation: list = [] - - for i in range(len(token_list)): - token = token_list[i] + for token in token_list: toks = [token] typ = token.type @@ -159,36 +203,23 @@ def pre_parse(code: str) -> tuple[Settings, dict[int, dict[str, Any]], Modificat if (typ, string) == (OP, ";"): raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1]) - if typ == NAME and string == "for": - is_for_loop = True - - if is_for_loop: - if typ == NAME and string == "in": - loop_var_annotation_tokens[start[0]] = loop_var_annotation - - is_for_loop = False - after_loop_var = False - loop_var_annotation = [] - - elif (typ, string) == (OP, ":"): - after_loop_var = True - continue - - elif after_loop_var and not (typ == NAME and string == "for"): - loop_var_annotation.extend(toks) - continue + if not for_parser.consume(token): + result.extend(toks) - result.extend(toks) except TokenError as e: raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e - loop_var_annotations: dict[int, dict[str, Any]] = {} - for k, v in loop_var_annotation_tokens.items(): + for_loop_annotations = {} + for k, v in for_parser.annotations.items(): updated_v = untokenize(v) - updated_v = updated_v.replace("\\", "") - updated_v = updated_v.replace("\n", "") - import textwrap + # print("untokenized v: ", updated_v) + # updated_v = updated_v.replace("\\", "") + # updated_v = updated_v.replace("\n", "") + # import textwrap - loop_var_annotations[k] = {"source_code": textwrap.dedent(updated_v)} + # print("updated v: ", textwrap.dedent(updated_v)) + for_loop_annotations[k] = updated_v - return settings, loop_var_annotations, modification_offsets, untokenize(result).decode("utf-8") + # print("untokenized result: ", type(untokenize(result))) + # print("untokenized result decoded: ", untokenize(result).decode("utf-8")) + return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8") From 3c5c0cb40e98fd0a1d121d7a1cbef6fb1948b8a0 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 6 Jan 2024 13:08:54 -0500 Subject: [PATCH 25/54] rewrite visit_For, use AnnAssign for the target add some more error messages --- vyper/ast/nodes.py | 2 +- vyper/ast/parse.py | 100 +++++++++++++++++------------- vyper/ast/pre_parser.py | 46 +++++++------- vyper/codegen/stmt.py | 4 +- vyper/semantics/analysis/local.py | 14 ++--- 5 files changed, 90 insertions(+), 76 deletions(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 9ad556b470..fffd3ca7cd 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -1546,7 +1546,7 @@ class IfExp(ExprNode): class For(Stmt): - __slots__ = ("iter", "iter_type", "target", "body") + __slots__ = ("target", "iter", "body") _only_empty_fields = ("orelse",) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 3bbc24c073..1e869dfb87 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -54,13 +54,9 @@ def parse_to_ast_with_settings( """ if "\x00" in source_code: raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code) + settings, class_types, for_loop_annotations, reformatted_code = pre_parse(source_code) try: py_ast = python_ast.parse(reformatted_code) - - for k, v in loop_var_annotations.items(): - parsed_v = python_ast.parse(v["source_code"]) - loop_var_annotations[k]["parsed_ast"] = parsed_v except SyntaxError as e: # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e @@ -76,13 +72,16 @@ def parse_to_ast_with_settings( annotate_python_ast( py_ast, source_code, - loop_var_annotations, class_types, + for_loop_annotations, source_id, module_path=module_path, resolved_path=resolved_path, ) + # postcondition: consumed all the for loop annotations + assert len(for_loop_annotations) == 0 + # Convert to Vyper AST. module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint @@ -123,8 +122,8 @@ class AnnotatingVisitor(python_ast.NodeTransformer): def __init__( self, source_code: str, - loop_var_annotations: dict[int, dict[str, Any]], - modification_offsets: Optional[ModificationOffsets], + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, tokens: asttokens.ASTTokens, source_id: int, module_path: Optional[str] = None, @@ -134,12 +133,11 @@ def __init__( self._source_id = source_id self._module_path = module_path self._resolved_path = resolved_path - self._source_code: str = source_code + self._source_code = source_code + self._modification_offsets = modification_offsets + self._for_loop_annotations = for_loop_annotations + self.counter: int = 0 - self._modification_offsets = {} - self._loop_var_annotations = loop_var_annotations - if modification_offsets is not None: - self._modification_offsets = modification_offsets def generic_visit(self, node): """ @@ -221,6 +219,45 @@ def visit_ClassDef(self, node): node.ast_type = self._modification_offsets[(node.lineno, node.col_offset)] return node + def visit_For(self, node): + """ + Visit a For node, splicing in the loop variable annotation provided by + the pre-parser + """ + raw_annotation = self._for_loop_annotations.pop((node.lineno, node.col_offset)) + + if not raw_annotation: + # a common case for people migrating to 0.4.0, provide a more + # specific error message than "invalid type annotation" + raise SyntaxException( + "missing type annotation\n\n" + "(hint: did you mean something like " + f"`for {node.target.id}: uint256 in ...`?)\n", + self._source_code, + node.lineno, + node.col_offset, + ) + + try: + annotation = python_ast.parse(raw_annotation, mode="eval") + except SyntaxError as e: + raise SyntaxException( + "invalid type annotation", self._source_code, node.lineno, node.col_offset + ) from e + + assert isinstance(annotation, python_ast.Expression) + annotation = annotation.body + + node.target_annotation = annotation + + old_target = node.target + new_target = python_ast.AnnAssign(target=old_target, annotation=annotation, simple=1) + node.target = new_target + + self.generic_visit(node) + + return node + def visit_Expr(self, node): """ Convert the `Yield` node into a Vyper-specific node type. @@ -240,28 +277,6 @@ def visit_Expr(self, node): return node - def visit_For(self, node): - """ - Annotate `For` nodes with the iterator's type annotation that was extracted - during pre-parsing. - """ - iter_type_info = self._loop_var_annotations.get(node.lineno) - if not iter_type_info: - raise SyntaxException( - "For loop iterator requires type annotation", - self._source_code, - node.iter.lineno, - node.iter.col_offset, - ) - - iter_type_ast = iter_type_info["parsed_ast"] - self.generic_visit(iter_type_ast) - self.generic_visit(node) - - node.iter_type = iter_type_ast.body[0].value - - return node - def visit_Subscript(self, node): """ Maintain consistency of `Subscript.slice` across python versions. @@ -322,13 +337,10 @@ def visit_Num(self, node): # modify vyper AST type according to the format of the literal value self.generic_visit(node) - # the type annotation of a for loop iterator is removed from the source - # code during pre-parsing, and therefore the `node_source_code` attribute - # of an integer in the type annotation would not be available e.g. DynArray[uint256, 3] - value = node.node_source_code if hasattr(node, "node_source_code") else None + value = node.node_source_code # deduce non base-10 types based on prefix - if value and value.lower()[:2] == "0x": + if value.lower()[:2] == "0x": if len(value) % 2: raise SyntaxException( "Hex notation requires an even number of digits", @@ -339,7 +351,7 @@ def visit_Num(self, node): node.ast_type = "Hex" node.n = value - elif value and value.lower()[:2] == "0b": + elif value.lower()[:2] == "0b": node.ast_type = "Bytes" mod = (len(value) - 2) % 8 if mod: @@ -389,8 +401,8 @@ def visit_UnaryOp(self, node): def annotate_python_ast( parsed_ast: python_ast.AST, source_code: str, - loop_var_annotations: dict[int, dict[str, Any]], - modification_offsets: Optional[ModificationOffsets] = None, + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, source_id: int = 0, module_path: Optional[str] = None, resolved_path: Optional[str] = None, @@ -418,8 +430,8 @@ def annotate_python_ast( tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) visitor = AnnotatingVisitor( source_code, - loop_var_annotations, modification_offsets, + for_loop_annotations, tokens, source_id, module_path=module_path, diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index d4115ab2b3..10f895d9b0 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -1,5 +1,5 @@ -import io import enum +import io import re from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize from typing import Any @@ -44,16 +44,19 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: start, ) + class ForParserState(enum.Enum): NOT_RUNNING = enum.auto() START_SOON = enum.auto() RUNNING = enum.auto() + # a simple state machine which allows us to handle loop variable annotations # (which are rejected by the python parser due to pep-526, so we scoop up the # tokens between `:` and `in` and parse them and add them back in later). class ForParser: - def __init__(self): + def __init__(self, code): + self._code = code self.annotations = {} self._current_annotation = None @@ -74,20 +77,27 @@ def consume(self, token): # state machine: start slurping tokens if token.type == OP and token.string == ":": self._state = ForParserState.RUNNING - assert self._current_annotation is None, (self._current_for_loop, self._current_annotation) - self._current_annotation = [] - return False - if self._state != ForParserState.RUNNING: - return False + # sanity check -- this should never really happen, but if it does, + # try to raise an exception which pinpoints the source. + if self._current_annotation is not None: + raise SyntaxException( + "for loop parse error", self._code, token.start[0], token.start[1] + ) + + self._current_annotation = [] + return True # do not add ":" to tokens. # state machine: end slurping tokens if token.type == NAME and token.string == "in": self._state = ForParserState.NOT_RUNNING - self.annotations[self._current_for_loop] = self._current_annotation + self.annotations[self._current_for_loop] = self._current_annotation or [] self._current_annotation = None return False + if self._state != ForParserState.RUNNING: + return False + # slurp the token self._current_annotation.append(token) return True @@ -101,7 +111,7 @@ def consume(self, token): VYPER_EXPRESSION_TYPES = {"log"} -def pre_parse(code: str) -> tuple[Settings, dict[int, dict[str, Any]], ModificationOffsets, str]: +def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: """ Re-formats a vyper source string into a python source string and performs some validation. More specifically, @@ -127,15 +137,15 @@ def pre_parse(code: str) -> tuple[Settings, dict[int, dict[str, Any]], Modificat Compilation settings based on the directives in the source code ModificationOffsets A mapping of class names to their original class types. - dict[int, dict[str, Any]] - A mapping of line numbers of `For` nodes to the type annotation of the iterator + dict[tuple[int, int], str] + A mapping of line/column offsets of `For` nodes to the annotation of the for loop target str Reformatted python source string. """ result = [] modification_offsets: ModificationOffsets = {} settings = Settings() - for_parser = ForParser() + for_parser = ForParser(code) try: code_bytes = code.encode("utf-8") @@ -211,15 +221,7 @@ def pre_parse(code: str) -> tuple[Settings, dict[int, dict[str, Any]], Modificat for_loop_annotations = {} for k, v in for_parser.annotations.items(): - updated_v = untokenize(v) - # print("untokenized v: ", updated_v) - # updated_v = updated_v.replace("\\", "") - # updated_v = updated_v.replace("\n", "") - # import textwrap - - # print("updated v: ", textwrap.dedent(updated_v)) - for_loop_annotations[k] = updated_v + v_source = untokenize(v).replace("\\", "").strip() + for_loop_annotations[k] = v_source - # print("untokenized result: ", type(untokenize(result))) - # print("untokenized result decoded: ", untokenize(result).decode("utf-8")) return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8") diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index bc29a79734..5487421a06 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -297,11 +297,11 @@ def _parse_For_list(self): with self.context.range_scope(): iter_list = Expr(self.stmt.iter, self.context).ir_node - target_type = self.stmt.target._metadata["type"] + target_type = self.stmt.target.target._metadata["type"] assert target_type == iter_list.typ.value_type # user-supplied name for loop variable - varname = self.stmt.target.id + varname = self.stmt.target.target.id loop_var = IRnode.from_list( self.context.new_variable(varname, target_type), typ=target_type, location=MEMORY ) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 76b139b055..effc545a0c 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -347,7 +347,10 @@ def visit_For(self, node): if isinstance(node.iter, vy_ast.Subscript): raise StructureException("Cannot iterate over a nested list", node.iter) - iter_type = type_from_annotation(node.iter_type, DataLocation.MEMORY) + if not isinstance(node.target, vy_ast.AnnAssign): + raise StructureException("Invalid syntax for loop iterator", node.target) + + iter_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) if isinstance(node.iter, vy_ast.Call): # iteration via range() @@ -410,10 +413,7 @@ def visit_For(self, node): call_node, ) - if not isinstance(node.target, vy_ast.Name): - raise StructureException("Invalid syntax for loop iterator", node.target) - - iter_name = node.target.id + iter_name = node.target.target.id with self.namespace.enter_scope(): self.namespace[iter_name] = VarInfo( iter_type, modifiability=Modifiability.RUNTIME_CONSTANT @@ -422,7 +422,7 @@ def visit_For(self, node): for stmt in node.body: self.visit(stmt) - self.expr_visitor.visit(node.target, iter_type) + self.expr_visitor.visit(node.target.target, iter_type) if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): iter_type = get_exact_type_from_node(node.iter) @@ -714,7 +714,7 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.orelse, typ) -def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType) -> list[VyperType]: +def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType): """ Check that the arguments to a range() call are valid. :param node: call to range() From fe6721c730c15d0c732eb34a74c137ff8411f31c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 6 Jan 2024 13:23:42 -0500 Subject: [PATCH 26/54] add a comment --- vyper/ast/pre_parser.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 10f895d9b0..b39e1b5d0e 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -221,7 +221,9 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: for_loop_annotations = {} for k, v in for_parser.annotations.items(): - v_source = untokenize(v).replace("\\", "").strip() + v_source = untokenize(v) + # untokenize adds backslashes and whitespace, strip them. + v_source = v_source.replace("\\", "").strip() for_loop_annotations[k] = v_source return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8") From f74fe50841add896404057ebd9fe908a19cd6f9d Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 6 Jan 2024 13:24:08 -0500 Subject: [PATCH 27/54] fix lint --- vyper/ast/pre_parser.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index b39e1b5d0e..aa6dd37271 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -2,7 +2,6 @@ import io import re from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize -from typing import Any from packaging.specifiers import InvalidSpecifier, SpecifierSet From c5bcb9bfd74dc0fb372aa4fab004555d0e983dce Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 10:11:30 +0800 Subject: [PATCH 28/54] fix comment in for parser --- vyper/ast/pre_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index aa6dd37271..c7e6f3698f 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -65,7 +65,7 @@ def __init__(self, code): def consume(self, token): # state machine: we can start slurping tokens soon if token.type == NAME and token.string == "for": - # note: self._is_running should be false here, but we don't sanity + # note: self._state should be NOT_RUNNING here, but we don't sanity # check here as that should be an error the parser will handle. self._state = ForParserState.START_SOON self._current_for_loop = token.start From 01cc34b4c2bd059914a2c6c802bf466e098a912a Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 10:53:57 +0800 Subject: [PATCH 29/54] fix For codegen --- vyper/codegen/stmt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 5487421a06..806d9cc9c4 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -270,7 +270,7 @@ def _parse_For_range(self): if rounds_bound < 1: # pragma: nocover raise TypeCheckFailure("unreachable: unchecked 0 bound") - varname = self.stmt.target.id + varname = self.stmt.target.target.id i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=UINT256_T) iptr = self.context.new_variable(varname, iter_typ) From 2948dc723ae60f8aaa970dc691fb4b73cf220208 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 10:54:11 +0800 Subject: [PATCH 30/54] call ASTTokens ctor for side effects --- vyper/ast/parse.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 1e869dfb87..b92c9d543c 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -240,6 +240,12 @@ def visit_For(self, node): try: annotation = python_ast.parse(raw_annotation, mode="eval") + # call ASTTokens ctor for its side effects of enhancing the Python AST tree + # with token and source code information, specifically the `first_token` and + # `last_token` attributes that are accessed in `generic_visit`. + asttokens.ASTTokens( + raw_annotation, tree=cast(Optional[python_ast.Module], annotation) + ) except SyntaxError as e: raise SyntaxException( "invalid type annotation", self._source_code, node.lineno, node.col_offset @@ -248,8 +254,6 @@ def visit_For(self, node): assert isinstance(annotation, python_ast.Expression) annotation = annotation.body - node.target_annotation = annotation - old_target = node.target new_target = python_ast.AnnAssign(target=old_target, annotation=annotation, simple=1) node.target = new_target From 54c31b8906512ac0a6d6cf48f75fa794d23543a4 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 10:54:19 +0800 Subject: [PATCH 31/54] fix visit_For semantics --- vyper/semantics/analysis/local.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index effc545a0c..e6b86ff135 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -350,7 +350,7 @@ def visit_For(self, node): if not isinstance(node.target, vy_ast.AnnAssign): raise StructureException("Invalid syntax for loop iterator", node.target) - iter_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) + iter_item_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) if isinstance(node.iter, vy_ast.Call): # iteration via range() @@ -358,7 +358,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - _analyse_range_call(node.iter, iter_type) + _analyse_range_call(node.iter, iter_item_type) else: # iteration over a variable or literal list @@ -416,28 +416,28 @@ def visit_For(self, node): iter_name = node.target.target.id with self.namespace.enter_scope(): self.namespace[iter_name] = VarInfo( - iter_type, modifiability=Modifiability.RUNTIME_CONSTANT + iter_item_type, modifiability=Modifiability.RUNTIME_CONSTANT ) for stmt in node.body: self.visit(stmt) - self.expr_visitor.visit(node.target.target, iter_type) + self.expr_visitor.visit(node.target.target, iter_item_type) if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): iter_type = get_exact_type_from_node(node.iter) # note CMC 2023-10-23: slightly redundant with how type_list is computed - validate_expected_type(node.target, iter_type.value_type) + validate_expected_type(node.target.target, iter_type.value_type) self.expr_visitor.visit(node.iter, iter_type) if isinstance(node.iter, vy_ast.List): len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) + self.expr_visitor.visit(node.iter, SArrayT(iter_item_type, len_)) if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": for a in node.iter.args: - self.expr_visitor.visit(a, iter_type) + self.expr_visitor.visit(a, iter_item_type) for a in node.iter.keywords: if a.arg == "bound": - self.expr_visitor.visit(a.value, iter_type) + self.expr_visitor.visit(a.value, iter_item_type) def visit_If(self, node): validate_expected_type(node.test, BoolT()) @@ -714,7 +714,7 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.orelse, typ) -def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType): +def _analyse_range_call(node: vy_ast.Call, iter_item_type: VyperType): """ Check that the arguments to a range() call are valid. :param node: call to range() @@ -727,7 +727,7 @@ def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType): all_args = (start, end, *kwargs.values()) for arg1 in all_args: - validate_expected_type(arg1, iter_type) + validate_expected_type(arg1, iter_item_type) if "bound" in kwargs: bound = kwargs["bound"] From ef7841fed51ec50424e0569bc2f5628a4f7da09a Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 11:34:01 +0800 Subject: [PATCH 32/54] replace ctor with mark_tokens --- vyper/ast/parse.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index b92c9d543c..21b801fd50 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -240,12 +240,10 @@ def visit_For(self, node): try: annotation = python_ast.parse(raw_annotation, mode="eval") - # call ASTTokens ctor for its side effects of enhancing the Python AST tree - # with token and source code information, specifically the `first_token` and - # `last_token` attributes that are accessed in `generic_visit`. - asttokens.ASTTokens( - raw_annotation, tree=cast(Optional[python_ast.Module], annotation) - ) + # enhance the Python AST tree with token and source code information, specifically the + # `first_token` and `last_token` attributes that are accessed in `generic_visit`. + tokens = asttokens.ASTTokens(raw_annotation) + tokens.mark_tokens(annotation) except SyntaxError as e: raise SyntaxException( "invalid type annotation", self._source_code, node.lineno, node.col_offset From 1caba88b44afd524236b961acacd2123525744ce Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 11:48:38 +0800 Subject: [PATCH 33/54] fix more codegen --- vyper/codegen/stmt.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 806d9cc9c4..03f3e4450d 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -33,7 +33,7 @@ ) from vyper.semantics.types import DArrayT, MemberFunctionT from vyper.semantics.types.function import ContractFunctionT -from vyper.semantics.types.shortcuts import INT256_T, UINT256_T +from vyper.semantics.types.shortcuts import UINT256_T class Stmt: @@ -231,11 +231,8 @@ def parse_For(self): return self._parse_For_list() def _parse_For_range(self): - # TODO make sure type always gets annotated - if "type" in self.stmt.target._metadata: - iter_typ = self.stmt.target._metadata["type"] - else: - iter_typ = INT256_T + assert "type" in self.stmt.target.target._metadata + iter_typ = self.stmt.target.target._metadata["type"] # Get arg0 for_iter: vy_ast.Call = self.stmt.iter @@ -271,7 +268,7 @@ def _parse_For_range(self): raise TypeCheckFailure("unreachable: unchecked 0 bound") varname = self.stmt.target.target.id - i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=UINT256_T) + i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=iter_typ) iptr = self.context.new_variable(varname, iter_typ) self.context.forvars[varname] = True From 62208d5f512f86e2b3a4a01b418f99597a315b3b Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 11:48:43 +0800 Subject: [PATCH 34/54] fix tests --- tests/unit/compiler/asm/test_asm_optimizer.py | 2 +- tests/unit/compiler/test_source_map.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index 44b823757c..b2851e908a 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -58,7 +58,7 @@ def ctor_only(): @internal def runtime_only(): - for i in range(10): + for i: uint256 in range(10): self.s += 1 @external diff --git a/tests/unit/compiler/test_source_map.py b/tests/unit/compiler/test_source_map.py index c9a152b09c..5b478dd2aa 100644 --- a/tests/unit/compiler/test_source_map.py +++ b/tests/unit/compiler/test_source_map.py @@ -6,7 +6,7 @@ @internal def _baz(a: int128) -> int128: b: int128 = a - for i in range(2, 5): + for i: int128 in range(2, 5): b *= i if b > 31337: break From c140574816eafa9460ef73c0d9571c7abb7188c0 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 13:00:43 +0800 Subject: [PATCH 35/54] improve diagnostics for For --- vyper/semantics/analysis/local.py | 33 ++++++++++++++++++------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index e6b86ff135..4246fa6760 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -347,8 +347,8 @@ def visit_For(self, node): if isinstance(node.iter, vy_ast.Subscript): raise StructureException("Cannot iterate over a nested list", node.iter) - if not isinstance(node.target, vy_ast.AnnAssign): - raise StructureException("Invalid syntax for loop iterator", node.target) + if not isinstance(node.target.target, vy_ast.Name): + raise StructureException("Invalid syntax for loop iterator", node.target.target) iter_item_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) @@ -723,27 +723,32 @@ def _analyse_range_call(node: vy_ast.Call, iter_item_type: VyperType): validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args - start, end = [i.get_folded_value() if i.has_folded_value else i for i in (start, end)] + folded_start, folded_end = [ + i.get_folded_value() if i.has_folded_value else i for i in (start, end) + ] - all_args = (start, end, *kwargs.values()) - for arg1 in all_args: - validate_expected_type(arg1, iter_item_type) + all_args_unfolded = (start, end, *kwargs.values()) + all_args_folded = (folded_start, folded_end, *kwargs.values()) + for unfolded_arg, folded_arg in zip(all_args_unfolded, all_args_folded): + try: + validate_expected_type(folded_arg, iter_item_type) + except VyperException as e: + raise InvalidType(str(e), unfolded_arg) if "bound" in kwargs: bound = kwargs["bound"] - if bound.has_folded_value: - bound = bound.get_folded_value() - if not isinstance(bound, vy_ast.Num): + folded_bound = bound.get_folded_value() if bound.has_folded_value else bound + if not isinstance(folded_bound, vy_ast.Num): raise StateAccessViolation("Bound must be a literal", bound) - if bound.value <= 0: + if folded_bound.value <= 0: raise StructureException("Bound must be at least 1", bound) - if isinstance(start, vy_ast.Num) and isinstance(end, vy_ast.Num): + if isinstance(folded_start, vy_ast.Num) and isinstance(folded_end, vy_ast.Num): error = "Please remove the `bound=` kwarg when using range with constants" raise StructureException(error, bound) else: - for arg in (start, end): - if not isinstance(arg, vy_ast.Num): + for arg, folded_arg in zip((start, end), (folded_start, folded_end)): + if not isinstance(folded_arg, vy_ast.Num): error = "Value must be a literal integer, unless a bound is specified" raise StateAccessViolation(error, arg) - if end.value <= start.value: + if folded_end.value <= folded_start.value: raise StructureException("End must be greater than start", end) From 58c135694c48b729bc1d567bb29c5f382fd3f71b Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 13:00:47 +0800 Subject: [PATCH 36/54] update tests --- tests/functional/syntax/test_for_range.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index 2ba562ac1f..f0c47035e4 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -5,9 +5,9 @@ from vyper import compiler from vyper.exceptions import ( ArgumentException, + InvalidType, StateAccessViolation, StructureException, - TypeMismatch, ) fail_list = [ @@ -218,9 +218,15 @@ def foo(): for i: uint256 in range(FOO, BAR): pass """, - TypeMismatch, - "Iterator values are of different types", - "range(FOO, BAR)", + InvalidType, + """Expected uint256 but literal can only be cast as int128. + line 2:24 + 1 + ---> 2 FOO: constant(int128) = 3 + -------------------------------^ + 3 BAR: constant(uint256) = 7 +""", # noqa: W291 + "FOO", ), ( """ @@ -233,7 +239,7 @@ def foo(): """, StructureException, "Bound must be at least 1", - "-1", + "FOO", ), ] From d068ebb172a4eabc58902b5f5f212b6be1e689e1 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 13:05:05 +0800 Subject: [PATCH 37/54] revert var name --- vyper/semantics/analysis/local.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 4246fa6760..a6c8eb5b3c 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -350,7 +350,7 @@ def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target.target) - iter_item_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) + iter_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) if isinstance(node.iter, vy_ast.Call): # iteration via range() @@ -358,7 +358,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - _analyse_range_call(node.iter, iter_item_type) + _analyse_range_call(node.iter, iter_type) else: # iteration over a variable or literal list @@ -416,13 +416,13 @@ def visit_For(self, node): iter_name = node.target.target.id with self.namespace.enter_scope(): self.namespace[iter_name] = VarInfo( - iter_item_type, modifiability=Modifiability.RUNTIME_CONSTANT + iter_type, modifiability=Modifiability.RUNTIME_CONSTANT ) for stmt in node.body: self.visit(stmt) - self.expr_visitor.visit(node.target.target, iter_item_type) + self.expr_visitor.visit(node.target.target, iter_type) if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): iter_type = get_exact_type_from_node(node.iter) @@ -431,13 +431,13 @@ def visit_For(self, node): self.expr_visitor.visit(node.iter, iter_type) if isinstance(node.iter, vy_ast.List): len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(iter_item_type, len_)) + self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": for a in node.iter.args: - self.expr_visitor.visit(a, iter_item_type) + self.expr_visitor.visit(a, iter_type) for a in node.iter.keywords: if a.arg == "bound": - self.expr_visitor.visit(a.value, iter_item_type) + self.expr_visitor.visit(a.value, iter_type) def visit_If(self, node): validate_expected_type(node.test, BoolT()) @@ -714,7 +714,7 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.orelse, typ) -def _analyse_range_call(node: vy_ast.Call, iter_item_type: VyperType): +def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType): """ Check that the arguments to a range() call are valid. :param node: call to range() @@ -731,7 +731,7 @@ def _analyse_range_call(node: vy_ast.Call, iter_item_type: VyperType): all_args_folded = (folded_start, folded_end, *kwargs.values()) for unfolded_arg, folded_arg in zip(all_args_unfolded, all_args_folded): try: - validate_expected_type(folded_arg, iter_item_type) + validate_expected_type(folded_arg, iter_type) except VyperException as e: raise InvalidType(str(e), unfolded_arg) From 296ea7ce9a4a69b97afb45af519583187278f90b Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 13:09:39 +0800 Subject: [PATCH 38/54] minor refactor --- vyper/semantics/analysis/local.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index a6c8eb5b3c..9abd46a0d6 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -426,8 +426,6 @@ def visit_For(self, node): if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): iter_type = get_exact_type_from_node(node.iter) - # note CMC 2023-10-23: slightly redundant with how type_list is computed - validate_expected_type(node.target.target, iter_type.value_type) self.expr_visitor.visit(node.iter, iter_type) if isinstance(node.iter, vy_ast.List): len_ = len(node.iter.elements) From 7e45133b4841a83add404ff13330f430cc91dd03 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 14:26:01 +0800 Subject: [PATCH 39/54] fix test --- tests/functional/codegen/types/test_bytes_zero_padding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/codegen/types/test_bytes_zero_padding.py b/tests/functional/codegen/types/test_bytes_zero_padding.py index f9fcf37b25..6597facd1b 100644 --- a/tests/functional/codegen/types/test_bytes_zero_padding.py +++ b/tests/functional/codegen/types/test_bytes_zero_padding.py @@ -10,7 +10,7 @@ def little_endian_contract(get_contract_module): def to_little_endian_64(_value: uint256) -> Bytes[8]: y: uint256 = 0 x: uint256 = _value - for _ in range(8): + for _: uint256 in range(8): y = (y << 8) | (x & 255) x >>= 8 return slice(convert(y, bytes32), 24, 8) From 2722b1061dcf0a98b374007f487603ae16376ded Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 14:56:47 +0800 Subject: [PATCH 40/54] fix grammar test --- tests/functional/grammar/test_grammar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index 7dd8c35929..652102c376 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -106,6 +106,6 @@ def has_no_docstrings(c): @hypothesis.settings(max_examples=500) def test_grammar_bruteforce(code): if utf8_encodable(code): - _, _, reformatted_code = pre_parse(code + "\n") + _, _, _, reformatted_code = pre_parse(code + "\n") tree = parse_to_ast(reformatted_code) assert isinstance(tree, Module) From 5bbbb96fbb4e42331f43fa6901d52871efd11521 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 15:03:41 +0800 Subject: [PATCH 41/54] remove TODO --- vyper/ast/grammar.lark | 1 - 1 file changed, 1 deletion(-) diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 4a826153df..234e96e552 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -178,7 +178,6 @@ body: _NEWLINE _INDENT ([COMMENT] _NEWLINE | _stmt)+ _DEDENT cond_exec: _expr ":" body default_exec: body if_stmt: "if" cond_exec ("elif" cond_exec)* ["else" ":" default_exec] -// TODO: make this into a variable definition e.g. `for i: uint256 in range(0, 5): ...` loop_variable: NAME ":" type loop_iterator: _expr for_stmt: "for" loop_variable "in" loop_iterator ":" body From 6b72c38af1c340c03bf5eaf151ef96fec872400c Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 15:06:35 +0800 Subject: [PATCH 42/54] revert empty line --- vyper/ast/parse.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 21b801fd50..31977ca52b 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -338,7 +338,6 @@ def visit_Num(self, node): """ # modify vyper AST type according to the format of the literal value self.generic_visit(node) - value = node.node_source_code # deduce non base-10 types based on prefix From 8fe23c9a1a528cca61015d077213a6578fbe5a44 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 16:12:03 +0800 Subject: [PATCH 43/54] clean up For visit --- vyper/semantics/analysis/local.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 9abd46a0d6..ee4bdf005d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -431,11 +431,10 @@ def visit_For(self, node): len_ = len(node.iter.elements) self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, iter_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, iter_type) + args = node.iter.args + kwargs = [s.value for s in node.iter.keywords] + for arg in (*args, *kwargs): + self.expr_visitor.visit(arg, iter_type) def visit_If(self, node): validate_expected_type(node.test, BoolT()) From f329bb4b05f0e04d96ae9536307eda41f18c0b17 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 16:12:07 +0800 Subject: [PATCH 44/54] add test --- .../codegen/features/iteration/test_for_in_list.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 22fd7ccb43..34f511a245 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -791,6 +791,13 @@ def test_for() -> int128: """, TypeMismatch, ), + """ +@external +def foo(): + a: DynArray[DynArray[uint256, 2], 3] = [[0, 1], [2, 3], [4, 5]] + for i: uint256 in a[2]: + pass + """, ] BAD_CODE = [code if isinstance(code, tuple) else (code, StructureException) for code in BAD_CODE] From 77d9bd370c48ebe08ceae8bc0c1ebb831fc7f19b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 7 Jan 2024 09:39:19 -0500 Subject: [PATCH 45/54] rename iter_type to target_type to be consistent with AST roll back some stylistic changes to error messages --- vyper/semantics/analysis/local.py | 51 +++++++++++++------------------ 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index ee4bdf005d..f0312331e4 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -350,7 +350,7 @@ def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target.target) - iter_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) + target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) if isinstance(node.iter, vy_ast.Call): # iteration via range() @@ -358,7 +358,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - _analyse_range_call(node.iter, iter_type) + _analyse_range_call(node.iter, target_type) else: # iteration over a variable or literal list @@ -413,28 +413,28 @@ def visit_For(self, node): call_node, ) - iter_name = node.target.target.id + target_name = node.target.target.id with self.namespace.enter_scope(): - self.namespace[iter_name] = VarInfo( - iter_type, modifiability=Modifiability.RUNTIME_CONSTANT + self.namespace[target_name] = VarInfo( + target_type, modifiability=Modifiability.RUNTIME_CONSTANT ) for stmt in node.body: self.visit(stmt) - self.expr_visitor.visit(node.target.target, iter_type) + self.expr_visitor.visit(node.target.target, target_type) - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - iter_type = get_exact_type_from_node(node.iter) - self.expr_visitor.visit(node.iter, iter_type) if isinstance(node.iter, vy_ast.List): len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + self.expr_visitor.visit(node.iter, SArrayT(target_type, len_)) + elif isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": args = node.iter.args kwargs = [s.value for s in node.iter.keywords] for arg in (*args, *kwargs): - self.expr_visitor.visit(arg, iter_type) + self.expr_visitor.visit(arg, target_type) + else: + iter_type = get_exact_type_from_node(node.iter) + self.expr_visitor.visit(node.iter, iter_type) def visit_If(self, node): validate_expected_type(node.test, BoolT()) @@ -720,32 +720,23 @@ def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType): validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args - folded_start, folded_end = [ - i.get_folded_value() if i.has_folded_value else i for i in (start, end) - ] - - all_args_unfolded = (start, end, *kwargs.values()) - all_args_folded = (folded_start, folded_end, *kwargs.values()) - for unfolded_arg, folded_arg in zip(all_args_unfolded, all_args_folded): - try: - validate_expected_type(folded_arg, iter_type) - except VyperException as e: - raise InvalidType(str(e), unfolded_arg) + start, end = [i.get_folded_value() if i.has_folded_value else i for i in (start, end)] if "bound" in kwargs: bound = kwargs["bound"] - folded_bound = bound.get_folded_value() if bound.has_folded_value else bound - if not isinstance(folded_bound, vy_ast.Num): + if bound.has_folded_value: + bound = bound.get_folded_value() + if not isinstance(bound, vy_ast.Num): raise StateAccessViolation("Bound must be a literal", bound) - if folded_bound.value <= 0: + if bound.value <= 0: raise StructureException("Bound must be at least 1", bound) - if isinstance(folded_start, vy_ast.Num) and isinstance(folded_end, vy_ast.Num): + if isinstance(start, vy_ast.Num) and isinstance(end, vy_ast.Num): error = "Please remove the `bound=` kwarg when using range with constants" raise StructureException(error, bound) else: - for arg, folded_arg in zip((start, end), (folded_start, folded_end)): - if not isinstance(folded_arg, vy_ast.Num): + for arg in (start, end): + if not isinstance(arg, vy_ast.Num): error = "Value must be a literal integer, unless a bound is specified" raise StateAccessViolation(error, arg) - if folded_end.value <= folded_start.value: + if end.value <= start.value: raise StructureException("End must be greater than start", end) From 7dbdee6f39ff34aeee3685315488e61bf009b6e5 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 7 Jan 2024 09:43:41 -0500 Subject: [PATCH 46/54] remove a dead parameter --- vyper/semantics/analysis/local.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index f0312331e4..042593a7f4 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -711,12 +711,13 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.orelse, typ) -def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType): +def _analyse_range_call(node: vy_ast.Call): """ Check that the arguments to a range() call are valid. :param node: call to range() :return: None """ + assert node.iter.func.id == "range" validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args From 7cbc8541d028b33cdae7a56fbcc2da88db15b5e9 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 7 Jan 2024 16:30:54 +0000 Subject: [PATCH 47/54] fix a couple small lint bugs --- vyper/semantics/analysis/local.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 042593a7f4..7a542b980a 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -358,7 +358,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - _analyse_range_call(node.iter, target_type) + _analyse_range_call(node.iter) else: # iteration over a variable or literal list @@ -717,7 +717,7 @@ def _analyse_range_call(node: vy_ast.Call): :param node: call to range() :return: None """ - assert node.iter.func.id == "range" + assert node.func.id == "range" validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args From b4f3c39739a367e56437ac72fa3d10f9774b361e Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 7 Jan 2024 16:32:24 +0000 Subject: [PATCH 48/54] fix exception messages for folded nodes --- tests/functional/syntax/test_for_range.py | 13 ++++--------- vyper/ast/nodes.py | 13 +++++++++---- vyper/exceptions.py | 5 +++++ 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index f0c47035e4..b7a8133db3 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -5,6 +5,7 @@ from vyper import compiler from vyper.exceptions import ( ArgumentException, + TypeMismatch, InvalidType, StateAccessViolation, StructureException, @@ -218,14 +219,8 @@ def foo(): for i: uint256 in range(FOO, BAR): pass """, - InvalidType, - """Expected uint256 but literal can only be cast as int128. - line 2:24 - 1 - ---> 2 FOO: constant(int128) = 3 - -------------------------------^ - 3 BAR: constant(uint256) = 7 -""", # noqa: W291 + TypeMismatch, + "Given reference has type int128, expected uint256", "FOO", ), ( @@ -258,7 +253,7 @@ def test_range_fail(bad_code, error_type, message, source_code): with pytest.raises(error_type) as exc_info: compiler.compile_code(bad_code) assert message == exc_info.value.message - assert source_code == exc_info.value.args[1].node_source_code + assert source_code == exc_info.value.args[1].get_original_node().node_source_code valid_list = [ diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index fffd3ca7cd..9e4eb19561 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -24,7 +24,7 @@ ) from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code -NODE_BASE_ATTRIBUTES = ("_children", "_depth", "_parent", "ast_type", "node_id", "_metadata") +NODE_BASE_ATTRIBUTES = ("_children", "_depth", "_parent", "ast_type", "node_id", "_metadata", "_original_node") NODE_SRC_ATTRIBUTES = ( "col_offset", "end_col_offset", @@ -257,6 +257,7 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): self.set_parent(parent) self._children: set = set() self._metadata: NodeMetadata = NodeMetadata() + self._original_node = None for field_name in NODE_SRC_ATTRIBUTES: # when a source offset is not available, use the parent's source offset @@ -411,12 +412,16 @@ def _set_folded_value(self, node: "VyperNode") -> None: # sanity check this is only called once assert "folded_value" not in self._metadata - # set the folded node's parent so that get_ancestor works - # this is mainly important for error messages. - node._parent = self._parent + # set the "original node" so that exceptions can point to the original + # node and not the folded node + node = copy.copy(node) + node._original_node = self self._metadata["folded_value"] = node + def get_original_node(self) -> "VyperNode": + return self._original_node or self + def _try_fold(self) -> "VyperNode": """ Attempt to constant-fold the content of a node, returning the result of diff --git a/vyper/exceptions.py b/vyper/exceptions.py index f216069eab..6fc1ace734 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -60,6 +60,7 @@ def __init__(self, message="Error Message not found.", *items): # annotation (in case it is only available optionally) self.annotations = [k for k in items if k is not None] + def with_annotation(self, *annotations): """ Creates a copy of this exception with a modified source annotation. @@ -92,6 +93,10 @@ def __str__(self): node = value[1] if isinstance(value, tuple) else value node_msg = "" + if isinstance(node, vy_ast.VyperNode): + # folded AST nodes contain pointers to the original source + node = node.get_original_node() + try: source_annotation = annotate_source_code( # add trailing space because EOF exceptions point one char beyond the length From 0d759c7d45ca3324e483825947166d2f4a3ae5e6 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 7 Jan 2024 16:41:49 +0000 Subject: [PATCH 49/54] fix lint --- tests/functional/syntax/test_for_range.py | 3 +-- vyper/ast/nodes.py | 10 +++++++++- vyper/exceptions.py | 1 - vyper/semantics/analysis/local.py | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index b7a8133db3..66981a90de 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -5,10 +5,9 @@ from vyper import compiler from vyper.exceptions import ( ArgumentException, - TypeMismatch, - InvalidType, StateAccessViolation, StructureException, + TypeMismatch, ) fail_list = [ diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 9e4eb19561..7a8c7443b7 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -24,7 +24,15 @@ ) from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code -NODE_BASE_ATTRIBUTES = ("_children", "_depth", "_parent", "ast_type", "node_id", "_metadata", "_original_node") +NODE_BASE_ATTRIBUTES = ( + "_children", + "_depth", + "_parent", + "ast_type", + "node_id", + "_metadata", + "_original_node", +) NODE_SRC_ATTRIBUTES = ( "col_offset", "end_col_offset", diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 6fc1ace734..51f3fea14c 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -60,7 +60,6 @@ def __init__(self, message="Error Message not found.", *items): # annotation (in case it is only available optionally) self.annotations = [k for k in items if k is not None] - def with_annotation(self, *annotations): """ Creates a copy of this exception with a modified source annotation. diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 7a542b980a..d5af3e243a 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -717,7 +717,7 @@ def _analyse_range_call(node: vy_ast.Call): :param node: call to range() :return: None """ - assert node.func.id == "range" + assert node.func.get("id") == "range" validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args From 9dcac6ceefbc83aabf3a34dce8791a870deaf946 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 7 Jan 2024 12:30:20 -0500 Subject: [PATCH 50/54] allow iterating over subscript --- .../features/iteration/test_for_in_list.py | 19 ++++++++++++------- vyper/semantics/analysis/local.py | 3 --- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 34f511a245..5c7b5c6b1b 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -418,6 +418,18 @@ def a() -> int128: return x""", -14, ), + ( + """ +@external +def a() -> uint256: + a: DynArray[DynArray[uint256, 2], 3] = [[0, 1], [2, 3], [4, 5]] + x: uint256 = 0 + for i: uint256 in a[2]: + x += i + return x + """, + 9, + ), ] @@ -791,13 +803,6 @@ def test_for() -> int128: """, TypeMismatch, ), - """ -@external -def foo(): - a: DynArray[DynArray[uint256, 2], 3] = [[0, 1], [2, 3], [4, 5]] - for i: uint256 in a[2]: - pass - """, ] BAD_CODE = [code if isinstance(code, tuple) else (code, StructureException) for code in BAD_CODE] diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index d5af3e243a..d66ec1f591 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -344,9 +344,6 @@ def visit_Expr(self, node): self.expr_visitor.visit(node.value, fn_type) def visit_For(self, node): - if isinstance(node.iter, vy_ast.Subscript): - raise StructureException("Cannot iterate over a nested list", node.iter) - if not isinstance(node.target.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target.target) From 2351bb22ce4054298b89cc6b52dd7fdb23e9d8a9 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 7 Jan 2024 12:54:00 -0500 Subject: [PATCH 51/54] revert removed tests --- .../unit/semantics/analysis/test_for_loop.py | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index ccd501e101..607587cc28 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -1,7 +1,12 @@ import pytest from vyper.ast import parse_to_ast -from vyper.exceptions import ArgumentException, ImmutableViolation, StateAccessViolation +from vyper.exceptions import ( + ArgumentException, + ImmutableViolation, + StateAccessViolation, + TypeMismatch, +) from vyper.semantics.analysis import validate_semantics @@ -127,3 +132,44 @@ def baz(): vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): validate_semantics(vyper_module, dummy_input_bundle) + + +iterator_inference_codes = [ + """ +@external +def main(): + for j: uint256 in range(3): + x: uint256 = j + y: uint16 = j + """, # GH issue 3212 + """ +@external +def foo(): + for i: uint256 in [1]: + a: uint256 = i + b: uint16 = i + """, # GH issue 3374 + """ +@external +def foo(): + for i: uint256 in [1]: + for j: uint256 in [1]: + a: uint256 = i + b: uint16 = i + """, # GH issue 3374 + """ +@external +def foo(): + for i: uint256 in [1,2,3]: + for j: uint256 in [1,2,3]: + b: uint256 = j + i + c: uint16 = i + """, # GH issue 3374 +] + + +@pytest.mark.parametrize("code", iterator_inference_codes) +def test_iterator_type_inference_checker(code, dummy_input_bundle): + vyper_module = parse_to_ast(code) + with pytest.raises(TypeMismatch): + validate_semantics(vyper_module, dummy_input_bundle) From 6e04ed5c1ce585b9621c1c901d27fdfbdad253d9 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 7 Jan 2024 13:00:29 -0500 Subject: [PATCH 52/54] rename some variables --- vyper/codegen/stmt.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 03f3e4450d..a47faefeb1 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -232,15 +232,16 @@ def parse_For(self): def _parse_For_range(self): assert "type" in self.stmt.target.target._metadata - iter_typ = self.stmt.target.target._metadata["type"] + target_type = self.stmt.target.target._metadata["type"] # Get arg0 - for_iter: vy_ast.Call = self.stmt.iter - args_len = len(for_iter.args) + range_call: vy_ast.Call = self.stmt.iter + assert isinstance(range_call, vy_ast.Call) + args_len = len(range_call.args) if args_len == 1: - arg0, arg1 = (IRnode.from_list(0, typ=iter_typ), for_iter.args[0]) + arg0, arg1 = (IRnode.from_list(0, typ=target_type), range_call.args[0]) elif args_len == 2: - arg0, arg1 = for_iter.args + arg0, arg1 = range_call.args else: # pragma: nocover raise TypeCheckFailure("unreachable: bad # of arguments to range()") @@ -248,7 +249,7 @@ def _parse_For_range(self): start = Expr.parse_value_expr(arg0, self.context) end = Expr.parse_value_expr(arg1, self.context) kwargs = { - s.arg: Expr.parse_value_expr(s.value, self.context) for s in for_iter.keywords + s.arg: Expr.parse_value_expr(s.value, self.context) for s in range_call.keywords } if "bound" in kwargs: @@ -268,8 +269,8 @@ def _parse_For_range(self): raise TypeCheckFailure("unreachable: unchecked 0 bound") varname = self.stmt.target.target.id - i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=iter_typ) - iptr = self.context.new_variable(varname, iter_typ) + i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=target_type) + iptr = self.context.new_variable(varname, target_type) self.context.forvars[varname] = True From afb321524bdbf46c52441e3e0960725ef8126761 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 7 Jan 2024 13:00:34 -0500 Subject: [PATCH 53/54] clarify a comment --- vyper/ast/parse.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 31977ca52b..b657cf2245 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -240,8 +240,8 @@ def visit_For(self, node): try: annotation = python_ast.parse(raw_annotation, mode="eval") - # enhance the Python AST tree with token and source code information, specifically the - # `first_token` and `last_token` attributes that are accessed in `generic_visit`. + # annotate with token and source code information. `first_token` + # and `last_token` attributes are accessed in `generic_visit`. tokens = asttokens.ASTTokens(raw_annotation) tokens.mark_tokens(annotation) except SyntaxError as e: From 9678f915e60c57d166857fa6104cf3a96f6b9ab7 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 7 Jan 2024 13:04:13 -0500 Subject: [PATCH 54/54] rename a function --- vyper/semantics/analysis/local.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index d66ec1f591..169c71269d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -355,7 +355,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - _analyse_range_call(node.iter) + _validate_range_call(node.iter) else: # iteration over a variable or literal list @@ -708,7 +708,7 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.orelse, typ) -def _analyse_range_call(node: vy_ast.Call): +def _validate_range_call(node: vy_ast.Call): """ Check that the arguments to a range() call are valid. :param node: call to range()