diff --git a/vyper/ast/annotation.py b/vyper/ast/annotation.py index 90baaab59d..e21f3ae5a2 100644 --- a/vyper/ast/annotation.py +++ b/vyper/ast/annotation.py @@ -85,8 +85,9 @@ def _visit_docstring(self, node): return node def visit_Module(self, node): - node.path = self._module_path node.name = self._module_name + node.path = self._module_path + node.source_id = self._source_id return self._visit_docstring(node) def visit_FunctionDef(self, node): diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 2497928035..604ccd36e9 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -589,7 +589,8 @@ def __contains__(self, obj): class Module(TopLevel): - __slots__ = () + # metadata + __slots__ = ("path", "source_id") def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: """ diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 3bde20356e..2d9073e90c 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -229,6 +229,10 @@ class CallViolation(VyperException): """Illegal function call.""" +class ImportCycle(VyperException): + """An import cycle""" + + class ImmutableViolation(VyperException): """Modifying an immutable variable, constant, or definition.""" diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 6d1d50d688..c39342384b 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,11 +1,13 @@ +import contextlib import enum -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Dict, List, Optional from vyper import ast as vy_ast from vyper.exceptions import ( CompilerPanic, ImmutableViolation, + ImportCycle, StateAccessViolation, VyperInternalException, ) @@ -145,6 +147,37 @@ def __repr__(self): return f"" +@dataclass +class ImportGraph: + _graph: dict[vy_ast.Module, list[vy_ast.Module]] = field(default_factory=dict) + + # the current path in the import graph traversal + _path: list[vy_ast.Module] = field(default_factory=list) + + def push_path(self, module_ast: vy_ast.Module): + if module_ast in self._path: + raise ImportCycle( + msg=" imports ".join(f'"{t.name}" (located at {t.path})' for t in self._path) + ) + + if len(self._path) > 0: + parent = self._graph.setdefault(self._path[-1], []) + parent.append(module_ast) + + self._path.append(module_ast) + + def pop_path(self, expected: vy_ast.Module): + assert expected == self._path.pop() + + @contextlib.contextmanager + def enter_path(self, module_ast: vy_ast.Module): + self.push_path(module_ast) + try: + yield + finally: + self.pop_path(module_ast) + + # base class for things that are the "result" of analysis class AnalysisResult: pass diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index edfb6bde8d..8999084373 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -1,3 +1,4 @@ +import contextlib import os from pathlib import Path, PurePath from typing import Any @@ -18,7 +19,7 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ImportGraph, ModuleInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.local import validate_functions from vyper.semantics.analysis.utils import ( @@ -34,7 +35,13 @@ from vyper.semantics.types.utils import type_from_annotation -def validate_semantics(module_ast: vy_ast.Module, input_bundle: InputBundle): +def validate_semantics(module_ast, input_bundle): + return validate_semantics_r(module_ast, input_bundle, ImportGraph()) + + +def validate_semantics_r( + module_ast: vy_ast.Module, input_bundle: InputBundle, import_graph: ImportGraph +): """ Analyze a Vyper module AST node, add all module-level objects to the namespace, type-check/validate semantics and annotate with type and analysis info @@ -42,9 +49,8 @@ def validate_semantics(module_ast: vy_ast.Module, input_bundle: InputBundle): # validate semantics and annotate AST with type/semantics information namespace = get_namespace() - with namespace.enter_scope(): - namespace = get_namespace() - analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace) + with namespace.enter_scope(), import_graph.enter_path(module_ast): + analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace, import_graph) analyzer.analyze() vy_ast.expansion.expand_annotated_ast(module_ast) @@ -85,14 +91,32 @@ def _find_cyclic_call(fn_t: ContractFunctionT, path: list = None): class ModuleAnalyzer(VyperNodeVisitorBase): scope_name = "module" + # class object + _ast_of: dict[int, vy_ast.Module] = {} + def __init__( - self, module_node: vy_ast.Module, input_bundle: InputBundle, namespace: Namespace + self, + module_node: vy_ast.Module, + input_bundle: InputBundle, + namespace: Namespace, + import_graph: ImportGraph, ) -> None: self.ast = module_node self.input_bundle = input_bundle self.namespace = namespace + self._import_graph = import_graph + + self.module_t = None def analyze(self) -> ModuleT: + # generate a `ModuleT` from the top-level node + # note: also validates unique method ids + if "type" in self.ast._metadata: + assert isinstance(self.ast._metadata["type"], ModuleT) + # we don't need to analyse again, skip out + self.module_t = self.ast._metadata["type"] + return self.module_t + module_nodes = self.ast.body.copy() while module_nodes: count = len(module_nodes) @@ -113,8 +137,6 @@ def analyze(self) -> ModuleT: if count == len(module_nodes): err_list.raise_if_not_empty() - # generate a `ModuleT` from the top-level node - # note: also validates unique method ids self.module_t = ModuleT(self.ast) self.ast._metadata["type"] = self.module_t @@ -157,6 +179,15 @@ def analyze_call_graph(self): _compute_reachable_set(fn_t) + @classmethod + def _ast_from_file(cls, file: FileInput, alias: str): + if file.source_id not in cls._ast_of: + cls._ast_of[file.source_id] = vy_ast.parse_to_ast( + file.source_code, module_path=str(file.path), module_name=alias + ) + + return cls._ast_of[file.source_id] + def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) if not isinstance(type_, InterfaceT): @@ -355,12 +386,13 @@ def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alia file = self.input_bundle.load_file(path_vy) assert isinstance(file, FileInput) # mypy hint - # TODO share work if same file is imported - module_ast = vy_ast.parse_to_ast( - file.source_code, module_path=str(path_vy), module_name=alias - ) + module_ast = self.__class__._ast_from_file(file, alias) + with override_global_namespace(Namespace()): - validate_semantics(module_ast, self.input_bundle) + with tag_exceptions(node): + validate_semantics_r( + module_ast, self.input_bundle, import_graph=self._import_graph + ) module_t = module_ast._metadata["type"] return ModuleInfo(module_t, decl_node=node) @@ -376,6 +408,14 @@ def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alia raise ModuleNotFoundError(module_str) +@contextlib.contextmanager +def tag_exceptions(node: vy_ast.VyperNode) -> Any: + try: + yield + except VyperException as e: + raise e.with_annotation(node) from None + + # convert an import to a path (without suffix) def _import_to_path(level: int, module_str: str) -> PurePath: base_path = ""