Skip to content

Commit

Permalink
detect import cycles
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Nov 11, 2023
1 parent febaa96 commit 19404e9
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 16 deletions.
3 changes: 2 additions & 1 deletion vyper/ast/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ def _visit_docstring(self, node):
return node

def visit_Module(self, node):
node.path = self._module_path
node.name = self._module_name
node.path = self._module_path
node.source_id = self._source_id
return self._visit_docstring(node)

def visit_FunctionDef(self, node):
Expand Down
3 changes: 2 additions & 1 deletion vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,8 @@ def __contains__(self, obj):


class Module(TopLevel):
__slots__ = ()
# metadata
__slots__ = ("path", "source_id")

def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None:
"""
Expand Down
4 changes: 4 additions & 0 deletions vyper/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ class CallViolation(VyperException):
"""Illegal function call."""


class ImportCycle(VyperException):
"""An import cycle"""


class ImmutableViolation(VyperException):
"""Modifying an immutable variable, constant, or definition."""

Expand Down
35 changes: 34 additions & 1 deletion vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import contextlib
import enum
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from vyper import ast as vy_ast
from vyper.exceptions import (
CompilerPanic,
ImmutableViolation,
ImportCycle,
StateAccessViolation,
VyperInternalException,
)
Expand Down Expand Up @@ -145,6 +147,37 @@ def __repr__(self):
return f"<CodeOffset: {self.offset}>"


@dataclass
class ImportGraph:
_graph: dict[vy_ast.Module, list[vy_ast.Module]] = field(default_factory=dict)

# the current path in the import graph traversal
_path: list[vy_ast.Module] = field(default_factory=list)

def push_path(self, module_ast: vy_ast.Module):
if module_ast in self._path:
raise ImportCycle(
msg=" imports ".join(f'"{t.name}" (located at {t.path})' for t in self._path)
)

if len(self._path) > 0:
parent = self._graph.setdefault(self._path[-1], [])
parent.append(module_ast)

self._path.append(module_ast)

def pop_path(self, expected: vy_ast.Module):
assert expected == self._path.pop()

@contextlib.contextmanager
def enter_path(self, module_ast: vy_ast.Module):
self.push_path(module_ast)
try:
yield
finally:
self.pop_path(module_ast)


# base class for things that are the "result" of analysis
class AnalysisResult:
pass
Expand Down
66 changes: 53 additions & 13 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import os
from pathlib import Path, PurePath
from typing import Any
Expand All @@ -18,7 +19,7 @@
VariableDeclarationException,
VyperException,
)
from vyper.semantics.analysis.base import ModuleInfo, VarInfo
from vyper.semantics.analysis.base import ImportGraph, ModuleInfo, VarInfo
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.local import validate_functions
from vyper.semantics.analysis.utils import (
Expand All @@ -34,17 +35,22 @@
from vyper.semantics.types.utils import type_from_annotation


def validate_semantics(module_ast: vy_ast.Module, input_bundle: InputBundle):
def validate_semantics(module_ast, input_bundle):
return validate_semantics_r(module_ast, input_bundle, ImportGraph())


def validate_semantics_r(
module_ast: vy_ast.Module, input_bundle: InputBundle, import_graph: ImportGraph
):
"""
Analyze a Vyper module AST node, add all module-level objects to the
namespace, type-check/validate semantics and annotate with type and analysis info
"""
# validate semantics and annotate AST with type/semantics information
namespace = get_namespace()

with namespace.enter_scope():
namespace = get_namespace()
analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace)
with namespace.enter_scope(), import_graph.enter_path(module_ast):
analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace, import_graph)
analyzer.analyze()

vy_ast.expansion.expand_annotated_ast(module_ast)
Expand Down Expand Up @@ -85,14 +91,32 @@ def _find_cyclic_call(fn_t: ContractFunctionT, path: list = None):
class ModuleAnalyzer(VyperNodeVisitorBase):
scope_name = "module"

# class object
_ast_of: dict[int, vy_ast.Module] = {}

def __init__(
self, module_node: vy_ast.Module, input_bundle: InputBundle, namespace: Namespace
self,
module_node: vy_ast.Module,
input_bundle: InputBundle,
namespace: Namespace,
import_graph: ImportGraph,
) -> None:
self.ast = module_node
self.input_bundle = input_bundle
self.namespace = namespace
self._import_graph = import_graph

self.module_t = None

def analyze(self) -> ModuleT:
# generate a `ModuleT` from the top-level node
# note: also validates unique method ids
if "type" in self.ast._metadata:
assert isinstance(self.ast._metadata["type"], ModuleT)
# we don't need to analyse again, skip out
self.module_t = self.ast._metadata["type"]
return self.module_t

module_nodes = self.ast.body.copy()
while module_nodes:
count = len(module_nodes)
Expand All @@ -113,8 +137,6 @@ def analyze(self) -> ModuleT:
if count == len(module_nodes):
err_list.raise_if_not_empty()

# generate a `ModuleT` from the top-level node
# note: also validates unique method ids
self.module_t = ModuleT(self.ast)
self.ast._metadata["type"] = self.module_t

Expand Down Expand Up @@ -157,6 +179,15 @@ def analyze_call_graph(self):

_compute_reachable_set(fn_t)

@classmethod
def _ast_from_file(cls, file: FileInput, alias: str):
if file.source_id not in cls._ast_of:
cls._ast_of[file.source_id] = vy_ast.parse_to_ast(
file.source_code, module_path=str(file.path), module_name=alias
)

return cls._ast_of[file.source_id]

def visit_ImplementsDecl(self, node):
type_ = type_from_annotation(node.annotation)
if not isinstance(type_, InterfaceT):
Expand Down Expand Up @@ -355,12 +386,13 @@ def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alia
file = self.input_bundle.load_file(path_vy)
assert isinstance(file, FileInput) # mypy hint

# TODO share work if same file is imported
module_ast = vy_ast.parse_to_ast(
file.source_code, module_path=str(path_vy), module_name=alias
)
module_ast = self.__class__._ast_from_file(file, alias)

with override_global_namespace(Namespace()):
validate_semantics(module_ast, self.input_bundle)
with tag_exceptions(node):
validate_semantics_r(
module_ast, self.input_bundle, import_graph=self._import_graph
)
module_t = module_ast._metadata["type"]

return ModuleInfo(module_t, decl_node=node)
Expand All @@ -376,6 +408,14 @@ def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alia
raise ModuleNotFoundError(module_str)


@contextlib.contextmanager
def tag_exceptions(node: vy_ast.VyperNode) -> Any:
try:
yield
except VyperException as e:
raise e.with_annotation(node) from None


# convert an import to a path (without suffix)
def _import_to_path(level: int, module_str: str) -> PurePath:
base_path = ""
Expand Down

0 comments on commit 19404e9

Please sign in to comment.