Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: require type annotations for loop variables #3596

Merged
merged 56 commits into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
0dcb40e
wip
tserg Sep 7, 2023
6aacc2c
Merge branch 'master' of https://github.com/vyperlang/vyper into feat…
tserg Jan 5, 2024
dabe7e7
apply bts suggestion
tserg Jan 6, 2024
899699e
fix for visit
tserg Jan 6, 2024
8d9b2ef
clean up prints
tserg Jan 6, 2024
bc5422a
update examples
tserg Jan 6, 2024
deb860f
delete py ast key
tserg Jan 6, 2024
2c2792f
remove prints
tserg Jan 6, 2024
daef0b7
fix exc in for
tserg Jan 6, 2024
f644f8a
update grammar
tserg Jan 6, 2024
49ad2cd
update tests
tserg Jan 6, 2024
635e0c0
fix tests
tserg Jan 6, 2024
42e06f5
fix lint
tserg Jan 6, 2024
e7a4612
revert a change
charles-cooper Jan 6, 2024
6f6acea
add visit_For in py ast parse
tserg Jan 6, 2024
578d471
remove typechecker speculation
tserg Jan 6, 2024
07fcde2
remove prints
tserg Jan 6, 2024
76c3d2d
fix sqrt
tserg Jan 6, 2024
9db7a36
fix visit_Num
tserg Jan 6, 2024
e055ee9
update comments
tserg Jan 6, 2024
c7acdaa
fix tests
tserg Jan 6, 2024
0dd86fd
fix lint
tserg Jan 6, 2024
f64e6f1
fix mypy
tserg Jan 6, 2024
b913d45
Merge branch 'feat/loop_var_annotation2' of https://github.com/tserg/…
tserg Jan 6, 2024
5bc54c0
simplify visit_For
tserg Jan 6, 2024
b951b47
rewrite for loop slurper with a small state machine
charles-cooper Jan 6, 2024
3c5c0cb
rewrite visit_For, use AnnAssign for the target
charles-cooper Jan 6, 2024
fe6721c
add a comment
charles-cooper Jan 6, 2024
f74fe50
fix lint
charles-cooper Jan 6, 2024
c5bcb9b
fix comment in for parser
tserg Jan 7, 2024
01cc34b
fix For codegen
tserg Jan 7, 2024
2948dc7
call ASTTokens ctor for side effects
tserg Jan 7, 2024
54c31b8
fix visit_For semantics
tserg Jan 7, 2024
ef7841f
replace ctor with mark_tokens
tserg Jan 7, 2024
1caba88
fix more codegen
tserg Jan 7, 2024
62208d5
fix tests
tserg Jan 7, 2024
c140574
improve diagnostics for For
tserg Jan 7, 2024
58c1356
update tests
tserg Jan 7, 2024
d068ebb
revert var name
tserg Jan 7, 2024
296ea7c
minor refactor
tserg Jan 7, 2024
7e45133
fix test
tserg Jan 7, 2024
2722b10
fix grammar test
tserg Jan 7, 2024
5bbbb96
remove TODO
tserg Jan 7, 2024
6b72c38
revert empty line
tserg Jan 7, 2024
8fe23c9
clean up For visit
tserg Jan 7, 2024
f329bb4
add test
tserg Jan 7, 2024
77d9bd3
rename iter_type to target_type to be consistent with AST
charles-cooper Jan 7, 2024
7dbdee6
remove a dead parameter
charles-cooper Jan 7, 2024
7cbc854
fix a couple small lint bugs
charles-cooper Jan 7, 2024
b4f3c39
fix exception messages for folded nodes
charles-cooper Jan 7, 2024
0d759c7
fix lint
charles-cooper Jan 7, 2024
9dcac6c
allow iterating over subscript
charles-cooper Jan 7, 2024
2351bb2
revert removed tests
charles-cooper Jan 7, 2024
6e04ed5
rename some variables
charles-cooper Jan 7, 2024
afb3215
clarify a comment
charles-cooper Jan 7, 2024
9678f91
rename a function
charles-cooper Jan 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions vyper/ast/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from vyper.typing import ModificationOffsets


def parse_to_ast(*args: Any, **kwargs: Any) -> vy_ast.Module:
_settings, ast = parse_to_ast_with_settings(*args, **kwargs)
return ast
def parse_to_ast(*args: Any, **kwargs: Any) -> tuple[vy_ast.Module, dict[int, dict[str, Any]]]:
_settings, ast, loop_var_annotations = parse_to_ast_with_settings(*args, **kwargs)
return ast, loop_var_annotations


def parse_to_ast_with_settings(
Expand All @@ -23,7 +23,7 @@ def parse_to_ast_with_settings(
module_path: Optional[str] = None,
resolved_path: Optional[str] = None,
add_fn_node: Optional[str] = None,
) -> tuple[Settings, vy_ast.Module]:
) -> tuple[Settings, vy_ast.Module, dict[int, dict[str, Any]]]:
"""
Parses a Vyper source string and generates basic Vyper AST nodes.

Expand Down Expand Up @@ -54,7 +54,7 @@ def parse_to_ast_with_settings(
"""
if "\x00" in source_code:
raise ParserException("No null bytes (\\x00) allowed in the source code.")
settings, class_types, reformatted_code = pre_parse(source_code)
settings, class_types, loop_var_annotations, reformatted_code = pre_parse(source_code)
try:
py_ast = python_ast.parse(reformatted_code)
except SyntaxError as e:
Expand All @@ -73,6 +73,7 @@ def parse_to_ast_with_settings(
py_ast,
source_code,
class_types,
loop_var_annotations,
source_id,
module_path=module_path,
resolved_path=resolved_path,
Expand All @@ -82,7 +83,7 @@ def parse_to_ast_with_settings(
module = vy_ast.get_node(py_ast)
assert isinstance(module, vy_ast.Module) # mypy hint

return settings, module
return settings, module, loop_var_annotations
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that, instead of passing loop_var_annotations around everywhere, we should just tag it onto the AST during the annotate_python_ast portion of AST parsing.



def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]:
Expand Down Expand Up @@ -356,6 +357,7 @@ def annotate_python_ast(
parsed_ast: python_ast.AST,
source_code: str,
modification_offsets: Optional[ModificationOffsets] = None,
loop_var_annotations: Optional[dict[int, python_ast.AST]] = None,
source_id: int = 0,
module_path: Optional[str] = None,
resolved_path: Optional[str] = None,
Expand Down Expand Up @@ -387,5 +389,9 @@ def annotate_python_ast(
resolved_path=resolved_path,
)
visitor.visit(parsed_ast)
for k, v in loop_var_annotations.items():
tokens = asttokens.ASTTokens(v["source_code"], tree=cast(Optional[python_ast.Module], v["parsed_ast"]))
visitor = AnnotatingVisitor(v["source_code"], {}, tokens, source_id, contract_name)
visitor.visit(v["parsed_ast"])

return parsed_ast
43 changes: 42 additions & 1 deletion vyper/ast/pre_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,16 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]:
result = []
modification_offsets: ModificationOffsets = {}
settings = Settings()
loop_var_annotations = {}

try:
code_bytes = code.encode("utf-8")
token_list = list(tokenize(io.BytesIO(code_bytes).readline))

is_for_loop = False
after_loop_var = False
loop_var_annotation = []

for i in range(len(token_list)):
token = token_list[i]
toks = [token]
Expand Down Expand Up @@ -146,8 +151,44 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]:

if (typ, string) == (OP, ";"):
raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1])

if typ == NAME and string == "for":
is_for_loop = True
#print("for loop!")
#print(token)

if is_for_loop:
if typ == NAME and string == "in":
loop_var_annotations[start[0]] = loop_var_annotation

is_for_loop = False
after_loop_var = False
loop_var_annotation = []

elif (typ, string) == (OP, ":"):
after_loop_var = True
continue

elif after_loop_var and not (typ == NAME and string == "for"):
#print("adding to loop var: ", toks)
loop_var_annotation.extend(toks)
continue

result.extend(toks)
except TokenError as e:
raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e

return settings, modification_offsets, untokenize(result).decode("utf-8")
for k, v in loop_var_annotations.items():


updated_v = untokenize(v)
#print("untokenized v: ", updated_v)
updated_v = updated_v.replace("\\", "")
updated_v = updated_v.replace("\n", "")
import textwrap
#print("updated v: ", textwrap.dedent(updated_v))
loop_var_annotations[k] = {"source_code": textwrap.dedent(updated_v)}

#print("untokenized result: ", type(untokenize(result)))
#print("untokenized result decoded: ", untokenize(result).decode("utf-8"))
return settings, modification_offsets, loop_var_annotations, untokenize(result).decode("utf-8")
12 changes: 7 additions & 5 deletions vyper/compiler/phases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from functools import cached_property
from pathlib import Path, PurePath
from typing import Optional
from typing import Any, Optional
Fixed Show fixed Hide fixed

from vyper import ast as vy_ast
from vyper.codegen import module
Expand Down Expand Up @@ -129,7 +129,7 @@ def contract_path(self):

@cached_property
def _generate_ast(self):
settings, ast = vy_ast.parse_to_ast_with_settings(
settings, ast, loop_var_annotations = vy_ast.parse_to_ast_with_settings(
self.source_code,
self.source_id,
module_path=str(self.contract_path),
Expand All @@ -145,16 +145,17 @@ def _generate_ast(self):

# note self.settings.compiler_version is erased here as it is
# not used after pre-parsing
return ast
return ast, loop_var_annotations

@cached_property
def vyper_module(self):
return self._generate_ast

@cached_property
def _annotated_module(self):
ast, loop_var_annotations = self.vyper_module
return generate_annotated_ast(
self.vyper_module, self.input_bundle, self.storage_layout_override
ast, loop_var_annotations, self.input_bundle, self.storage_layout_override
)

@property
Expand Down Expand Up @@ -244,6 +245,7 @@ def blueprint_bytecode(self) -> bytes:

def generate_annotated_ast(
vyper_module: vy_ast.Module,
loop_var_annotations: dict[int, dict[str, Any]],
input_bundle: InputBundle,
storage_layout_overrides: StorageLayout = None,
) -> tuple[vy_ast.Module, StorageLayout]:
Expand All @@ -265,7 +267,7 @@ def generate_annotated_ast(
vyper_module = copy.deepcopy(vyper_module)
with input_bundle.search_path(Path(vyper_module.resolved_path).parent):
# note: validate_semantics does type inference on the AST
validate_semantics(vyper_module, input_bundle)
validate_semantics(vyper_module, loop_var_annotations, input_bundle)

symbol_tables = set_data_positions(vyper_module, storage_layout_overrides)

Expand Down
126 changes: 66 additions & 60 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional
Fixed Show fixed Hide fixed

from vyper import ast as vy_ast
from vyper.ast.metadata import NodeMetadata
Expand Down Expand Up @@ -53,15 +53,15 @@
from vyper.semantics.types.utils import type_from_annotation


def validate_functions(vy_module: vy_ast.Module) -> None:
def validate_functions(vy_module: vy_ast.Module, loop_var_annotations: dict[int, dict[str, Any]]) -> None:
"""Analyzes a vyper ast and validates the function bodies"""

err_list = ExceptionList()
namespace = get_namespace()
for node in vy_module.get_children(vy_ast.FunctionDef):
with namespace.enter_scope():
try:
analyzer = FunctionNodeVisitor(vy_module, node, namespace)
analyzer = FunctionNodeVisitor(vy_module, loop_var_annotations, node, namespace)
analyzer.analyze()
except VyperException as e:
err_list.append(e)
Expand Down Expand Up @@ -180,9 +180,10 @@
scope_name = "function"

def __init__(
self, vyper_module: vy_ast.Module, fn_node: vy_ast.FunctionDef, namespace: dict
self, vyper_module: vy_ast.Module, loop_var_annotations: dict[int, dict[str, Any]], fn_node: vy_ast.FunctionDef, namespace: dict
) -> None:
self.vyper_module = vyper_module
self.loop_var_annotations = loop_var_annotations
self.fn_node = fn_node
self.namespace = namespace
self.func = fn_node._metadata["func_type"]
Expand Down Expand Up @@ -350,6 +351,11 @@
if isinstance(node.iter, vy_ast.Subscript):
raise StructureException("Cannot iterate over a nested list", node.iter)

print("visit_For: ", node.lineno)
iter_type = type_from_annotation(self.loop_var_annotations[node.lineno]["vy_ast"].body[0].value)
print("iter type: ", iter_type)
node.target._metadata["type"] = iter_type

if isinstance(node.iter, vy_ast.Call):
# iteration via range()
if node.iter.get("func.id") != "range":
Expand Down Expand Up @@ -418,62 +424,62 @@
if not isinstance(node.target, vy_ast.Name):
raise StructureException("Invalid syntax for loop iterator", node.target)

for_loop_exceptions = []
iter_name = node.target.id
for possible_target_type in type_list:
# type check the for loop body using each possible type for iterator value

with self.namespace.enter_scope():
self.namespace[iter_name] = VarInfo(
possible_target_type, modifiability=Modifiability.RUNTIME_CONSTANT
)

try:
with NodeMetadata.enter_typechecker_speculation():
for stmt in node.body:
self.visit(stmt)

self.expr_visitor.visit(node.target, possible_target_type)

if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)):
iter_type = get_exact_type_from_node(node.iter)
# note CMC 2023-10-23: slightly redundant with how type_list is computed
validate_expected_type(node.target, iter_type.value_type)
self.expr_visitor.visit(node.iter, iter_type)
if isinstance(node.iter, vy_ast.List):
len_ = len(node.iter.elements)
self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_))
if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range":
for a in node.iter.args:
self.expr_visitor.visit(a, possible_target_type)
for a in node.iter.keywords:
if a.arg == "bound":
self.expr_visitor.visit(a.value, possible_target_type)

except (TypeMismatch, InvalidOperation) as exc:
for_loop_exceptions.append(exc)
else:
# success -- do not enter error handling section
return

# failed to find a good type. bail out
if len(set(str(i) for i in for_loop_exceptions)) == 1:
# if every attempt at type checking raised the same exception
raise for_loop_exceptions[0]

# return an aggregate TypeMismatch that shows all possible exceptions
# depending on which type is used
types_str = [str(i) for i in type_list]
given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}"
raise TypeMismatch(
f"Iterator value '{iter_name}' may be cast as {given_str}, "
"but type checking fails with all possible types:",
node,
*(
(f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0])
for typ, exc in zip(type_list, for_loop_exceptions)
),
)
# for_loop_exceptions = []
# iter_name = node.target.id
# for possible_target_type in type_list:
# # type check the for loop body using each possible type for iterator value

# with self.namespace.enter_scope():
# self.namespace[iter_name] = VarInfo(
# possible_target_type, modifiability=Modifiability.RUNTIME_CONSTANT
Fixed Show fixed Hide fixed
# )

# try:
# with NodeMetadata.enter_typechecker_speculation():
# for stmt in node.body:
# self.visit(stmt)

# self.expr_visitor.visit(node.target, possible_target_type)

# if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)):
# iter_type = get_exact_type_from_node(node.iter)
# # note CMC 2023-10-23: slightly redundant with how type_list is computed
# validate_expected_type(node.target, iter_type.value_type)
# self.expr_visitor.visit(node.iter, iter_type)
# if isinstance(node.iter, vy_ast.List):
# len_ = len(node.iter.elements)
# self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_))
# if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range":
# for a in node.iter.args:
# self.expr_visitor.visit(a, possible_target_type)
# for a in node.iter.keywords:
# if a.arg == "bound":
# self.expr_visitor.visit(a.value, possible_target_type)
Fixed Show fixed Hide fixed

# except (TypeMismatch, InvalidOperation) as exc:
# for_loop_exceptions.append(exc)
# else:
# # success -- do not enter error handling section
# return
Fixed Show fixed Hide fixed

# # failed to find a good type. bail out
# if len(set(str(i) for i in for_loop_exceptions)) == 1:
# # if every attempt at type checking raised the same exception
# raise for_loop_exceptions[0]

# # return an aggregate TypeMismatch that shows all possible exceptions
# # depending on which type is used
Fixed Show fixed Hide fixed
# types_str = [str(i) for i in type_list]
# given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}"
# raise TypeMismatch(
# f"Iterator value '{iter_name}' may be cast as {given_str}, "
# "but type checking fails with all possible types:",
# node,
# *(
# (f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0])
# for typ, exc in zip(type_list, for_loop_exceptions)
# ),
# )

def visit_If(self, node):
validate_expected_type(node.test, BoolT())
Expand Down
Loading
Loading