Skip to content

Commit

Permalink
add global initialization constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Feb 3, 2024
1 parent de98cbc commit d82eb3e
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 16 deletions.
8 changes: 5 additions & 3 deletions vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def typ(self):
def set_ownership(self, module_ownership: ModuleOwnership, node: Optional[vy_ast.VyperNode]):
if self.ownership != ModuleOwnership.NO_OWNERSHIP:
raise StructureException(
f"ownership already set to {self.module_ownership}", node, self.ownership_decl
f"ownership already set to {self.ownership}", node, self.ownership_decl
)
self.ownership = module_ownership

Expand All @@ -226,14 +226,16 @@ class ImportInfo(AnalysisResult):
# analysis result of InitializesDecl
@dataclass
class InitializesInfo(AnalysisResult):
module_t: "ModuleT"
module_info: ModuleInfo
dependencies: list["ModuleT"]
node: Optional[vy_ast.VyperNode] = None


# analysis result of UsesDecl
@dataclass
class UsesInfo(AnalysisResult):
used_modules: list["ModuleT"]
used_modules: list[ModuleInfo]
node: Optional[vy_ast.VyperNode] = None


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/analysis/data_positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def set_storage_slots_r(

for node in vyper_module.body:
if isinstance(node, vy_ast.InitializesDecl):
module_t = node._metadata["initializes_info"].module_t
module_t = node._metadata["initializes_info"].module_info.module_t
set_storage_slots_r(module_t._module, allocator)
continue

Expand Down
6 changes: 4 additions & 2 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,10 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None:
self.visit(node.func, call_type)

# check mutability level of the function
expr_info = get_expr_info(node.func.value)
expr_info.validate_modification(node, self.func.mutability)
if isinstance(node.func, vy_ast.Attribute) and self.func is not None:
expr_info = get_expr_info(node.func.value)
# TODO: have mutability property on `self` (FunctionAnalyzer)
expr_info.validate_modification(node, self.func.mutability)

if isinstance(call_type, ContractFunctionT):
# function calls
Expand Down
76 changes: 68 additions & 8 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from collections import defaultdict
from pathlib import Path, PurePath
from typing import Any, Optional

Expand Down Expand Up @@ -48,7 +49,11 @@


def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT:
return validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface)
ret = validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface)

_validate_global_initializes_constraint(ret)

return ret


def validate_semantics_r(
Expand Down Expand Up @@ -109,6 +114,59 @@ def _compute_reachable_set(fn_t: ContractFunctionT, path: list[ContractFunctionT
path.pop()


def _collect_used_modules_r(module_t):
ret: defaultdict[ModuleT, list[UsesInfo]] = defaultdict(list)

for uses_decl in module_t.uses_decls:
for used_module in uses_decl._metadata["uses_info"].used_modules:
ret[used_module.module_t].append(uses_decl)

for m_info in module_t.used_modules:
used_modules = _collect_used_modules_r(m_info.module_t)
for k, v in used_modules.items():
ret[k].extend(v)

return ret


def _collect_initialized_modules_r(module_t, seen=None):
seen: dict[ModuleT, InitializesInfo] = seen or {}

# list of InitializedInfo
initialized_infos = module_t.initialized_modules

for i in initialized_infos:
if (other := seen.get(i.module_info.module_t)) is not None:
raise StructureException("{i.module_info.alias} initialized twice!", i.node, other)
seen[i.module_info.module_t] = i

for d in i.dependencies:
_collect_initialized_modules_r(d.module_t, seen)

return seen


# validate that each module which is `used` in the import graph is
# `initialized`.
def _validate_global_initializes_constraint(module_t: ModuleT):
all_used_modules = _collect_used_modules_r(module_t)
all_initialized_modules = _collect_initialized_modules_r(module_t)

err_list = ExceptionList()

for u, uses in all_used_modules.items():
if u not in all_initialized_modules:
err_list.append(
StructureException(
f"module {u} is used but never initialized!\n "
f"(hint: add `initializes: module_name` to your main contract",
*uses,
)
)

err_list.raise_if_not_empty()


class ModuleAnalyzer(VyperNodeVisitorBase):
scope_name = "module"

Expand Down Expand Up @@ -193,6 +251,7 @@ def analyze(self) -> ModuleT:

def analyze_call_graph(self):
# get list of internal function calls made by each function
# CMC 2024-02-03 note: this could be cleaner in analysis/local.py
function_defs = self.module_t.function_defs

for func in function_defs:
Expand Down Expand Up @@ -258,7 +317,7 @@ def visit_UsesDecl(self, node):

used_modules.append(module_info)

node._metadata["uses_info"] = UsesInfo(used_modules)
node._metadata["uses_info"] = UsesInfo(used_modules, node)

def visit_InitializesDecl(self, node):
module_ref = node.annotation
Expand All @@ -274,14 +333,15 @@ def visit_InitializesDecl(self, node):
if module_info is None:
raise StructureException("Not a module!", module_ref)

used_modules = module_info.module_t.used_modules.copy()
used_modules = [i.module_t for i in module_info.module_t.used_modules]

dependencies = []
for named_expr in dependencies_ast:
assert isinstance(named_expr, vy_ast.NamedExpr)

with override_global_namespace(module_info.module_node._metadata["namespace"]):
# lhs of the named_expr is evaluated in the namespace of the initialized module
with module_info.module_node.namespace():
# lhs of the named_expr is evaluated in the namespace of the
# initialized module!
lhs_module = get_expr_info(named_expr.target).module_info
rhs_module = get_expr_info(named_expr.value).module_info

Expand All @@ -291,14 +351,14 @@ def visit_InitializesDecl(self, node):
)
dependencies.append(lhs_module)

if lhs_module not in used_modules:
if lhs_module.module_t not in used_modules:
raise StructureException(
f"`{module_info.alias}` is initialized with `{lhs_module.alias}`, "
f"but `{module_info.alias}` does not use `{lhs_module.alias}`!",
named_expr,
)

used_modules.remove(lhs_module)
used_modules.remove(lhs_module.module_t)

if len(used_modules) > 0:
item = used_modules[0] # just pick one
Expand All @@ -312,7 +372,7 @@ def visit_InitializesDecl(self, node):
# note: try to refactor. not a huge fan of mutating the
# ModuleInfo after it's constructed
module_info.set_ownership(ModuleOwnership.INITIALIZES, node)
node._metadata["initializes_info"] = InitializesInfo(module_info.module_t, dependencies)
node._metadata["initializes_info"] = InitializesInfo(module_info, dependencies, node)

def visit_VariableDecl(self, node):
name = node.get("target.id")
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def uses_decls(self):
def initializes_decls(self):
return self._module.get_children(vy_ast.InitializesDecl)

@property
@cached_property
def used_modules(self):
# modules which are written to
ret = []
Expand All @@ -376,7 +376,7 @@ def initialized_modules(self):
ret = []
for node in self.initializes_decls:
info = node._metadata["initializes_info"]
ret.append(info.module_t)
ret.append(info)
return ret

@cached_property
Expand Down

0 comments on commit d82eb3e

Please sign in to comment.