Skip to content

Commit

Permalink
fix inheritance on JSONInputBundle
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Oct 31, 2023
1 parent bdfce82 commit ed98b0d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 17 deletions.
2 changes: 1 addition & 1 deletion vyper/cli/vyper_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def compile_files(
contract_sources: ContractCodes = dict()
for file_name in input_files:
file_path = Path(file_name)
contract_sources[file_path] = input_bundle.load_file(Path(file_path))
contract_sources[file_path] = input_bundle.load_file(Path(file_path)).source_code

storage_layouts = dict()
if storage_layout:
Expand Down
24 changes: 14 additions & 10 deletions vyper/compiler/input_bundle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import contextlib
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path, PurePath
from typing import Any, Optional

Expand Down Expand Up @@ -30,22 +30,24 @@ class _NotFound(Exception):
pass


@dataclass
class InputBundle:
search_paths: list[Path]
# compilation_targets: dict[str, str] # contract names => contract sources
source_id_counter = 0
source_ids: dict[Path, int] = field(default_factory=dict)

def __init__(self, search_paths):
self.search_paths = search_paths
self._source_id_counter = 0
self._source_ids: dict[Path, int] = {}

def _load_from_path(self, path):
raise NotImplementedError(f"not implemented! {self.__class__}._load_from_path()")

def _generate_source_id(self, path: Path) -> int:
if path not in self.source_ids:
self.source_ids[path] = self.source_id_counter
self.source_id_counter += 1
if path not in self._source_ids:
self._source_ids[path] = self._source_id_counter
self._source_id_counter += 1

return self.source_ids[path]
return self._source_ids[path]

def load_file(self, path: Path) -> str:
for p in self.search_paths:
Expand Down Expand Up @@ -86,7 +88,6 @@ def search_path(self, path: Optional[Path]) -> None:

# regular input. takes a search path(s), and `load_file()` will search all
# search paths for the file and read it from the filesystem
@dataclass
class FilesystemInputBundle(InputBundle):
def _load_from_path(self, path: Path) -> CompilerInput:
try:
Expand All @@ -102,10 +103,13 @@ def _load_from_path(self, path: Path) -> CompilerInput:
# fake filesystem for JSON inputs. takes a base path, and `load_file()`
# "reads" the file from the JSON input. Note that this input bundle type
# never actually interacts with the filesystem -- it is guaranteed to be pure!
@dataclass
class JSONInputBundle(InputBundle):
input_json: dict[PurePath, Any]

def __init__(self, search_paths, input_json):
super().__init__(search_paths)
self.input_json = input_json

def _load_from_path(self, path: PurePath) -> CompilerInput:
try:
value = self.input_json[path]
Expand Down
14 changes: 8 additions & 6 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import vyper.builtins.interfaces
from vyper import ast as vy_ast
from vyper.compiler.input_bundle import InputBundle
from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle
from vyper.evm.opcodes import version_check
from vyper.exceptions import (
CallViolation,
Expand Down Expand Up @@ -317,7 +317,6 @@ def visit_StructDef(self, node):
def _add_import(
self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str
) -> None:

type_ = self._load_import(level, qualified_module_name)

try:
Expand Down Expand Up @@ -352,12 +351,15 @@ def _import_to_path(level: int, module_str: str) -> PurePath:
base_path = "./"
return PurePath(f"{base_path}{module_str.replace('.','/')}/")


# can add more, e.g. "vyper.builtins.interfaces", etc.
BUILTIN_PREFIXES = ["vyper.interfaces"]


def _is_builtin(module_str):
return any(module_str.startswith(prefix) for prefix in BUILTIN_PREFIXES)


def _load_builtin_import(level: int, module_str: str):
if not _is_builtin(module_str):
raise ModuleNotFoundError(f"Not a builtin: {module_str}")
Expand All @@ -367,10 +369,10 @@ def _load_builtin_import(level: int, module_str: str):

# remap builtins directory --
# vyper/interfaces => vyper/builtins/interfaces
module_str = (
vyper.builtins.interfaces.__package__
+ qualified_module_name.removeprefix(INTERFACES_PATH)
)
if module_str.startswith("vyper.interfaces"):
module_str = vyper.builtins.interfaces.__package__ + module_str.removeprefix(
"vyper.interfaces"
)
path = _import_to_path(level, module_str).with_suffix(".vy")

try:
Expand Down

0 comments on commit ed98b0d

Please sign in to comment.