Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: require type annotations for loop variables #3596

Merged
merged 56 commits into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
0dcb40e
wip
tserg Sep 7, 2023
6aacc2c
Merge branch 'master' of https://github.com/vyperlang/vyper into feat…
tserg Jan 5, 2024
dabe7e7
apply bts suggestion
tserg Jan 6, 2024
899699e
fix for visit
tserg Jan 6, 2024
8d9b2ef
clean up prints
tserg Jan 6, 2024
bc5422a
update examples
tserg Jan 6, 2024
deb860f
delete py ast key
tserg Jan 6, 2024
2c2792f
remove prints
tserg Jan 6, 2024
daef0b7
fix exc in for
tserg Jan 6, 2024
f644f8a
update grammar
tserg Jan 6, 2024
49ad2cd
update tests
tserg Jan 6, 2024
635e0c0
fix tests
tserg Jan 6, 2024
42e06f5
fix lint
tserg Jan 6, 2024
e7a4612
revert a change
charles-cooper Jan 6, 2024
6f6acea
add visit_For in py ast parse
tserg Jan 6, 2024
578d471
remove typechecker speculation
tserg Jan 6, 2024
07fcde2
remove prints
tserg Jan 6, 2024
76c3d2d
fix sqrt
tserg Jan 6, 2024
9db7a36
fix visit_Num
tserg Jan 6, 2024
e055ee9
update comments
tserg Jan 6, 2024
c7acdaa
fix tests
tserg Jan 6, 2024
0dd86fd
fix lint
tserg Jan 6, 2024
f64e6f1
fix mypy
tserg Jan 6, 2024
b913d45
Merge branch 'feat/loop_var_annotation2' of https://github.com/tserg/…
tserg Jan 6, 2024
5bc54c0
simpliy visit_For
tserg Jan 6, 2024
b951b47
rewrite for loop slurper with a small state machine
charles-cooper Jan 6, 2024
3c5c0cb
rewrite visit_For, use AnnAssign for the target
charles-cooper Jan 6, 2024
fe6721c
add a comment
charles-cooper Jan 6, 2024
f74fe50
fix lint
charles-cooper Jan 6, 2024
c5bcb9b
fix comment in for parser
tserg Jan 7, 2024
01cc34b
fix For codegen
tserg Jan 7, 2024
2948dc7
call ASTTokens ctor for side effects
tserg Jan 7, 2024
54c31b8
fix visit_For semantics
tserg Jan 7, 2024
ef7841f
replace ctor with mark_tokens
tserg Jan 7, 2024
1caba88
fix more codegen
tserg Jan 7, 2024
62208d5
fix tests
tserg Jan 7, 2024
c140574
improve diagnostics for For
tserg Jan 7, 2024
58c1356
update tests
tserg Jan 7, 2024
d068ebb
revert var name
tserg Jan 7, 2024
296ea7c
minor refactor
tserg Jan 7, 2024
7e45133
fix test
tserg Jan 7, 2024
2722b10
fix grammar test
tserg Jan 7, 2024
5bbbb96
remove TODO
tserg Jan 7, 2024
6b72c38
revert empty line
tserg Jan 7, 2024
8fe23c9
clean up For visit
tserg Jan 7, 2024
f329bb4
add test
tserg Jan 7, 2024
77d9bd3
rename iter_type to target_type to be consistent with AST
charles-cooper Jan 7, 2024
7dbdee6
remove a dead parameter
charles-cooper Jan 7, 2024
7cbc854
fix a couple small lint bugs
charles-cooper Jan 7, 2024
b4f3c39
fix exception messages for folded nodes
charles-cooper Jan 7, 2024
0d759c7
fix lint
charles-cooper Jan 7, 2024
9dcac6c
allow iterating over subscript
charles-cooper Jan 7, 2024
2351bb2
revert removed tests
charles-cooper Jan 7, 2024
6e04ed5
rename some variables
charles-cooper Jan 7, 2024
afb3215
clarify a comment
charles-cooper Jan 7, 2024
9678f91
rename a function
charles-cooper Jan 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion vyper/ast/annotation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast as python_ast
import tokenize
from decimal import Decimal
from typing import Optional, cast
from typing import Any, Optional, cast
Fixed Show fixed Hide fixed

import asttokens

Expand Down Expand Up @@ -249,6 +249,7 @@
parsed_ast: python_ast.AST,
source_code: str,
modification_offsets: Optional[ModificationOffsets] = None,
loop_var_annotations: Optional[dict[int, python_ast.AST]] = None,
source_id: int = 0,
contract_name: Optional[str] = None,
) -> python_ast.AST:
Expand All @@ -272,5 +273,9 @@
tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast))
visitor = AnnotatingVisitor(source_code, modification_offsets, tokens, source_id, contract_name)
visitor.visit(parsed_ast)
for k, v in loop_var_annotations.items():
tokens = asttokens.ASTTokens(v["source_code"], tree=cast(Optional[python_ast.Module], v["parsed_ast"]))
visitor = AnnotatingVisitor(v["source_code"], {}, tokens, source_id, contract_name)
visitor.visit(v["parsed_ast"])

return parsed_ast
43 changes: 42 additions & 1 deletion vyper/ast/pre_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,16 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]:
result = []
modification_offsets: ModificationOffsets = {}
settings = Settings()
loop_var_annotations = {}

try:
code_bytes = code.encode("utf-8")
token_list = list(tokenize(io.BytesIO(code_bytes).readline))

is_for_loop = False
after_loop_var = False
loop_var_annotation = []

for i in range(len(token_list)):
token = token_list[i]
toks = [token]
Expand Down Expand Up @@ -142,8 +147,44 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]:

if (typ, string) == (OP, ";"):
raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1])

if typ == NAME and string == "for":
is_for_loop = True
#print("for loop!")
#print(token)

if is_for_loop:
if typ == NAME and string == "in":
loop_var_annotations[start[0]] = loop_var_annotation

is_for_loop = False
after_loop_var = False
loop_var_annotation = []

elif (typ, string) == (OP, ":"):
after_loop_var = True
continue

elif after_loop_var and not (typ == NAME and string == "for"):
#print("adding to loop var: ", toks)
loop_var_annotation.extend(toks)
continue

result.extend(toks)
except TokenError as e:
raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e

return settings, modification_offsets, untokenize(result).decode("utf-8")
for k, v in loop_var_annotations.items():


updated_v = untokenize(v)
#print("untokenized v: ", updated_v)
updated_v = updated_v.replace("\\", "")
updated_v = updated_v.replace("\n", "")
import textwrap
#print("updated v: ", textwrap.dedent(updated_v))
loop_var_annotations[k] = {"source_code": textwrap.dedent(updated_v)}

#print("untokenized result: ", type(untokenize(result)))
#print("untokenized result decoded: ", untokenize(result).decode("utf-8"))
return settings, modification_offsets, loop_var_annotations, untokenize(result).decode("utf-8")
19 changes: 15 additions & 4 deletions vyper/ast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def parse_to_ast_with_settings(
source_id: int = 0,
contract_name: Optional[str] = None,
add_fn_node: Optional[str] = None,
) -> tuple[Settings, vy_ast.Module]:
) -> tuple[Settings, vy_ast.Module, dict[int, dict[str, Any]]]:
"""
Parses a Vyper source string and generates basic Vyper AST nodes.

Expand All @@ -39,9 +39,16 @@ def parse_to_ast_with_settings(
"""
if "\x00" in source_code:
raise ParserException("No null bytes (\\x00) allowed in the source code.")
settings, class_types, reformatted_code = pre_parse(source_code)
settings, class_types, loop_var_annotations, reformatted_code = pre_parse(source_code)
try:
py_ast = python_ast.parse(reformatted_code)

print("loop vars: ", loop_var_annotations)
for k, v in loop_var_annotations.items():
print("v: ", v)
parsed_v = python_ast.parse(v["source_code"])
print("parsed v: ", parsed_v.body[0].value)
loop_var_annotations[k]["parsed_ast"] = parsed_v
except SyntaxError as e:
# TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors
raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e
Expand All @@ -53,12 +60,16 @@ def parse_to_ast_with_settings(
fn_node.body = py_ast.body
fn_node.args = python_ast.arguments(defaults=[])
py_ast.body = [fn_node]
annotate_python_ast(py_ast, source_code, class_types, source_id, contract_name)
annotate_python_ast(py_ast, source_code, class_types, loop_var_annotations, source_id, contract_name)

# Convert to Vyper AST.
module = vy_ast.get_node(py_ast)

for k, v in loop_var_annotations.items():
loop_var_annotations[k]["vy_ast"] = vy_ast.get_node(v["parsed_ast"])

assert isinstance(module, vy_ast.Module) # mypy hint
return settings, module
return settings, module, loop_var_annotations


def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]:
Expand Down
13 changes: 7 additions & 6 deletions vyper/compiler/phases.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import warnings
from functools import cached_property
from typing import Optional, Tuple
from typing import Any, Optional, Tuple

from vyper import ast as vy_ast
from vyper.codegen import module
Expand Down Expand Up @@ -90,7 +90,7 @@

@cached_property
def _generate_ast(self):
settings, ast = generate_ast(self.source_code, self.source_id, self.contract_name)
settings, ast, loop_var_annotations = generate_ast(self.source_code, self.source_id, self.contract_name)
# validate the compiler settings
# XXX: this is a bit ugly, clean up later
if settings.evm_version is not None:
Expand All @@ -117,7 +117,7 @@
if self.settings.optimize is None:
self.settings.optimize = OptimizationLevel.default()

return ast
return ast, loop_var_annotations

@cached_property
def vyper_module(self):
Expand All @@ -128,12 +128,12 @@
# This phase is intended to generate an AST for tooling use, and is not
# used in the compilation process.

return generate_unfolded_ast(self.vyper_module, self.interface_codes)
return generate_unfolded_ast(self.vyper_module[0], self.interface_codes)

@cached_property
def _folded_module(self):
return generate_folded_ast(
self.vyper_module, self.interface_codes, self.storage_layout_override
self.vyper_module[0], self.vyper_module[1], self.interface_codes, self.storage_layout_override
)

@property
Expand Down Expand Up @@ -233,13 +233,14 @@
vy_ast.folding.replace_builtin_constants(vyper_module)
vy_ast.folding.replace_builtin_functions(vyper_module)
# note: validate_semantics does type inference on the AST
validate_semantics(vyper_module, interface_codes)

Check failure

Code scanning / CodeQL

Wrong number of arguments in a call Error

Call to
function validate_semantics
with too few arguments; should be no fewer than 3.

return vyper_module


def generate_folded_ast(
vyper_module: vy_ast.Module,
loop_var_annotations: dict[int, dict[str, Any]],
interface_codes: Optional[InterfaceImports],
storage_layout_overrides: StorageLayout = None,
) -> Tuple[vy_ast.Module, StorageLayout]:
Expand All @@ -262,7 +263,7 @@

vyper_module_folded = copy.deepcopy(vyper_module)
vy_ast.folding.fold(vyper_module_folded)
validate_semantics(vyper_module_folded, interface_codes)
validate_semantics(vyper_module_folded, loop_var_annotations, interface_codes)
symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides)

return vyper_module_folded, symbol_tables
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from .utils import _ExprAnalyser


def validate_semantics(vyper_ast, interface_codes):
def validate_semantics(vyper_ast, loop_var_annotations, interface_codes):
# validate semantics and annotate AST with type/semantics information
namespace = get_namespace()

with namespace.enter_scope():
add_module_namespace(vyper_ast, interface_codes)
vy_ast.expansion.expand_annotated_ast(vyper_ast)
validate_functions(vyper_ast)
validate_functions(vyper_ast, loop_var_annotations)
16 changes: 12 additions & 4 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional

Fixed Show fixed Hide fixed
from vyper import ast as vy_ast
from vyper.ast.metadata import NodeMetadata
Expand Down Expand Up @@ -50,15 +50,15 @@
from vyper.semantics.types.utils import type_from_annotation


def validate_functions(vy_module: vy_ast.Module) -> None:
def validate_functions(vy_module: vy_ast.Module, loop_var_annotations: dict[int, dict[str, Any]]) -> None:
"""Analyzes a vyper ast and validates the function-level namespaces."""

err_list = ExceptionList()
namespace = get_namespace()
for node in vy_module.get_children(vy_ast.FunctionDef):
with namespace.enter_scope():
try:
FunctionNodeVisitor(vy_module, node, namespace)
FunctionNodeVisitor(vy_module, loop_var_annotations, node, namespace)
except VyperException as e:
err_list.append(e)

Expand Down Expand Up @@ -165,9 +165,10 @@ class FunctionNodeVisitor(VyperNodeVisitorBase):
scope_name = "function"

def __init__(
self, vyper_module: vy_ast.Module, fn_node: vy_ast.FunctionDef, namespace: dict
self, vyper_module: vy_ast.Module, loop_var_annotations: dict[int, dict[str, Any]], fn_node: vy_ast.FunctionDef, namespace: dict
) -> None:
self.vyper_module = vyper_module
self.loop_var_annotations = loop_var_annotations
self.fn_node = fn_node
self.namespace = namespace
self.func = fn_node._metadata["type"]
Expand Down Expand Up @@ -340,6 +341,11 @@ def visit_For(self, node):
if isinstance(node.iter, vy_ast.Subscript):
raise StructureException("Cannot iterate over a nested list", node.iter)

print("visit_For: ", node.lineno)
iter_type = type_from_annotation(self.loop_var_annotations[node.lineno]["vy_ast"].body[0].value)
print("iter type: ", iter_type)
node.target._metadata["type"] = iter_type

if isinstance(node.iter, vy_ast.Call):
# iteration via range()
if node.iter.get("func.id") != "range":
Expand Down Expand Up @@ -468,6 +474,7 @@ def visit_For(self, node):
if not isinstance(node.target, vy_ast.Name):
raise StructureException("Invalid syntax for loop iterator", node.target)

"""
for_loop_exceptions = []
iter_name = node.target.id
for type_ in type_list:
Expand Down Expand Up @@ -514,6 +521,7 @@ def visit_For(self, node):
for type_, exc in zip(type_list, for_loop_exceptions)
),
)
"""

def visit_Expr(self, node):
if not isinstance(node.value, vy_ast.Call):
Expand Down
Loading