diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 3bde20356e6..2d9073e90ce 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -229,6 +229,10 @@ class CallViolation(VyperException): """Illegal function call.""" +class ImportCycle(VyperException): + """An import cycle""" + + class ImmutableViolation(VyperException): """Modifying an immutable variable, constant, or definition.""" diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index edfb6bde8d4..4cb2e64c43d 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -34,7 +34,11 @@ from vyper.semantics.types.utils import type_from_annotation -def validate_semantics(module_ast: vy_ast.Module, input_bundle: InputBundle): +def validate_semantics(module_ast, input_bundle): + return validate_semantics_r(module_ast, input_bundle, []) + + +def validate_semantics_r(module_ast: vy_ast.Module, input_bundle: InputBundle, import_graph: list = None): """ Analyze a Vyper module AST node, add all module-level objects to the namespace, type-check/validate semantics and annotate with type and analysis info @@ -44,7 +48,7 @@ def validate_semantics(module_ast: vy_ast.Module, input_bundle: InputBundle): with namespace.enter_scope(): namespace = get_namespace() - analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace) + analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace, import_graph) analyzer.analyze() vy_ast.expansion.expand_annotated_ast(module_ast) @@ -85,14 +89,37 @@ def _find_cyclic_call(fn_t: ContractFunctionT, path: list = None): class ModuleAnalyzer(VyperNodeVisitorBase): scope_name = "module" + # class object + _ast_of: dict[FileInput, vy_ast.Module] + def __init__( - self, module_node: vy_ast.Module, input_bundle: InputBundle, namespace: Namespace + self, module_node: vy_ast.Module, input_bundle: InputBundle, namespace: Namespace, import_graph: list ) -> None: self.ast = module_node self.input_bundle = input_bundle self.namespace = namespace + self.import_graph = import_graph + + self.module_t = None def analyze(self) -> ModuleT: + # generate a `ModuleT` from the top-level node + # note: also validates unique method ids + if "type" in self.ast._metadata: + assert isinstance(self.ast._metadata["type"], ModuleT) + # we don't need to analyse again, skip out + self.module_t = self.ast._metadata["type"] + return self.module_t + + else: + self.module_t = ModuleT(self.ast) + self.ast._metadata["type"] = self.module_t + + if self.module_t in self.import_graph: + raise ImportCycle(" -> ".join(str(t) for t in self.import_graph + [self.module_t])) + + self.import_graph.append(self.module_t) + module_nodes = self.ast.body.copy() while module_nodes: count = len(module_nodes) @@ -113,11 +140,6 @@ def analyze(self) -> ModuleT: if count == len(module_nodes): err_list.raise_if_not_empty() - # generate a `ModuleT` from the top-level node - # note: also validates unique method ids - self.module_t = ModuleT(self.ast) - self.ast._metadata["type"] = self.module_t - # attach namespace to the module for downstream use. _ns = Namespace() # note that we don't just copy the namespace because @@ -127,6 +149,8 @@ def analyze(self) -> ModuleT: self.analyze_call_graph() + assert self.module_t == self.import_graph.pop() + def analyze_call_graph(self): # get list of internal function calls made by each function function_defs = self.module_t.functions @@ -157,6 +181,15 @@ def analyze_call_graph(self): _compute_reachable_set(fn_t) + @classmethod + def _ast_from_file(cls, file: FileInput): + if file not in cls._ast_of: + cls._ast_of[file] = vy_ast.parse_to_ast( + file.source_code, module_path=str(path_vy), module_name=alias + ) + + return cls._ast_of[file] + def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) if not isinstance(type_, InterfaceT): @@ -355,12 +388,10 @@ def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alia file = self.input_bundle.load_file(path_vy) assert isinstance(file, FileInput) # mypy hint - # TODO share work if same file is imported - module_ast = vy_ast.parse_to_ast( - file.source_code, module_path=str(path_vy), module_name=alias - ) + module_ast = self.__class__._ast_from_file(file) + with override_global_namespace(Namespace()): - validate_semantics(module_ast, self.input_bundle) + validate_semantics_r(module_ast, self.input_bundle, import_graph=self._import_graph) module_t = module_ast._metadata["type"] return ModuleInfo(module_t, decl_node=node)