Skip to content

Commit

Permalink
Change CompileAPI to ecosystem based compilers
Browse files Browse the repository at this point in the history
  • Loading branch information
bilbeyt committed Aug 25, 2023
1 parent bc52af9 commit dcf62c0
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 36 deletions.
5 changes: 5 additions & 0 deletions src/ape/api/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ class CompilerAPI(BaseInterfaceModel):
def name(self) -> str:
...

@property
@abstractmethod
def extension(self) -> str:
...

@abstractmethod
def get_versions(self, all_paths: List[Path]) -> Set[str]:
"""
Expand Down
3 changes: 2 additions & 1 deletion src/ape/api/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def contracts(self) -> Dict[str, ContractType]:

@property
def _cache_folder(self) -> Path:
folder = self.contracts_folder.parent / ".build"
current_ecosystem = self.network_manager.network.ecosystem.name
folder = self.contracts_folder.parent / ".build" / current_ecosystem
# NOTE: If we use the cache folder, we expect it to exist
folder.mkdir(exist_ok=True, parents=True)
return folder
Expand Down
54 changes: 35 additions & 19 deletions src/ape/managers/compilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def __getattr__(self, name: str) -> Any:

raise ApeAttributeError(f"No attribute or compiler named '{name}'.")

@property
def supported_extensions(self) -> Set[str]:
extensions = set()
for compiler in self.registered_compilers.values():
extensions.add(compiler.extension)
return extensions

@property
def registered_compilers(self) -> Dict[str, CompilerAPI]:
"""
Expand All @@ -52,30 +59,35 @@ def registered_compilers(self) -> Dict[str, CompilerAPI]:
Dict[str, :class:`~ape.api.compiler.CompilerAPI`]: The mapping of file-extensions
to compiler API classes.
"""
current_ecosystem = self.network_manager.network.ecosystem.name
ecosystem_config = self.config_manager.get_config(current_ecosystem)
try:
supported_compilers = ecosystem_config.compilers
except AttributeError:
raise CompilerError(f"No compilers defined for ecosystem={current_ecosystem}.")

cache_key = self.config_manager.PROJECT_FOLDER
if cache_key in self._registered_compilers_cache:
return self._registered_compilers_cache[cache_key]

registered_compilers = {}

for plugin_name, (extensions, compiler_class) in self.plugin_manager.register_compiler:
for plugin_name, compiler_class in self.plugin_manager.register_compiler:
# TODO: Investigate side effects of loading compiler plugins.
# See if this needs to be refactored.
self.config_manager.get_config(plugin_name=plugin_name)

compiler = compiler_class()

for extension in extensions:
if extension not in registered_compilers:
registered_compilers[extension] = compiler
if compiler.name in supported_compilers:
registered_compilers[compiler.name] = compiler

self._registered_compilers_cache[cache_key] = registered_compilers
return registered_compilers

def get_compiler(self, name: str) -> Optional[CompilerAPI]:
def get_compiler(self, identifier: str) -> Optional[CompilerAPI]:
for compiler in self.registered_compilers.values():
if compiler.name == name:
if compiler.name == identifier or compiler.extension == identifier:
return compiler

return None
Expand Down Expand Up @@ -124,10 +136,9 @@ def compile(self, contract_filepaths: List[Path]) -> Dict[str, ContractType]:
for path in paths_to_compile:
source_id = get_relative_path(path, contracts_folder)
logger.info(f"Compiling '{source_id}'.")

compiled_contracts = self.registered_compilers[extension].compile(
paths_to_compile, base_path=contracts_folder
)
compiler = self.get_compiler(extension)
assert compiler is not None
compiled_contracts = compiler.compile(paths_to_compile, base_path=contracts_folder)
for contract_type in compiled_contracts:
contract_name = contract_type.name
if not contract_name:
Expand Down Expand Up @@ -176,9 +187,11 @@ def get_imports(
imports_dict: Dict[str, List[str]] = {}
base_path = base_path or self.project_manager.contracts_folder

for ext, compiler in self.registered_compilers.items():
for compiler in self.registered_compilers.values():
try:
sources = [p for p in contract_filepaths if p.suffix == ext and p.is_file()]
sources = [
p for p in contract_filepaths if p.suffix == compiler.extension and p.is_file()
]
imports = compiler.get_imports(contract_filepaths=sources, base_path=base_path)
except NotImplementedError:
imports = None
Expand Down Expand Up @@ -214,7 +227,7 @@ def get_references(self, imports_dict: Dict[str, List[str]]) -> Dict[str, List[s

def _get_contract_extensions(self, contract_filepaths: List[Path]) -> Set[str]:
extensions = {path.suffix for path in contract_filepaths}
unhandled_extensions = {s for s in extensions - set(self.registered_compilers) if s}
unhandled_extensions = {s for s in extensions - self.supported_extensions if s}
if len(unhandled_extensions) > 0:
unhandled_extensions_str = ", ".join(unhandled_extensions)
raise CompilerError(f"No compiler found for extensions [{unhandled_extensions_str}].")
Expand Down Expand Up @@ -249,11 +262,12 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError:
return err

ext = Path(contract.source_id).suffix
if ext not in self.registered_compilers:
if ext not in self.supported_extensions:
# Compiler not found.
return err

compiler = self.registered_compilers[ext]
compiler = self.get_compiler(ext)
assert compiler is not None
return compiler.enrich_error(err)

def flatten_contract(self, path: Path) -> Content:
Expand All @@ -268,12 +282,13 @@ def flatten_contract(self, path: Path) -> Content:
``ethpm_types.source.Content``: The flattened contract content.
"""

if path.suffix not in self.registered_compilers:
if path.suffix not in self.supported_extensions:
raise CompilerError(
f"Unable to flatten contract. Missing compiler for '{path.suffix}'."
)

compiler = self.registered_compilers[path.suffix]
compiler = self.get_compiler(path.suffix)
assert compiler is not None
return compiler.flatten_contract(path)

def can_trace_source(self, filename: str) -> bool:
Expand All @@ -293,8 +308,9 @@ def can_trace_source(self, filename: str) -> bool:
return False

extension = path.suffix
if extension in self.registered_compilers:
compiler = self.registered_compilers[extension]
if extension in self.supported_extensions:
compiler = self.get_compiler(extension)
assert compiler is not None
if compiler.supports_source_tracing:
return True

Expand Down
16 changes: 10 additions & 6 deletions src/ape/managers/project/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def source_paths(self) -> List[Path]:
if not self.contracts_folder.is_dir():
return files

for extension in self.compiler_manager.registered_compilers:
for extension in self.compiler_manager.supported_extensions:
files.extend((x for x in self.contracts_folder.rglob(f"*{extension}") if x.is_file()))

return files
Expand Down Expand Up @@ -169,8 +169,10 @@ def _get_compiler_data(self, compile_if_needed: bool = True):
)
compiler_list: List[Compiler] = []
contracts_folder = self.config_manager.contracts_folder
for ext, compiler in self.compiler_manager.registered_compilers.items():
sources = [x for x in self.source_paths if x.is_file() and x.suffix == ext]
for compiler in self.compiler_manager.registered_compilers.values():
sources = [
x for x in self.source_paths if x.is_file() and x.suffix == compiler.extension
]
if not sources:
continue

Expand All @@ -183,7 +185,9 @@ def _get_compiler_data(self, compile_if_needed: bool = True):
# These are unlikely to be part of the published manifest
continue
elif len(versions) > 1:
raise (ProjectError(f"Unable to create version map for '{ext}'."))
raise (
ProjectError(f"Unable to create version map for '{compiler.extension}'.")
)

version = versions[0]
version_map = {version: sources}
Expand Down Expand Up @@ -336,7 +340,7 @@ def get_project(
else path / "contracts"
)
if not contracts_folder.is_dir():
extensions = list(self.compiler_manager.registered_compilers.keys())
extensions = list(self.compiler_manager.supported_extensions)
path_patterns_to_ignore = self.config_manager.compiler.ignore_files

def find_contracts_folder(sub_dir: Path) -> Optional[Path]:
Expand Down Expand Up @@ -586,7 +590,7 @@ def _append_extensions_in_dir(directory: Path):
elif (
file.suffix
and file.suffix not in extensions_found
and file.suffix not in self.compiler_manager.registered_compilers
and file.suffix not in self.compiler_manager.supported_extensions
):
extensions_found.append(file.suffix)

Expand Down
4 changes: 2 additions & 2 deletions src/ape/managers/project/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def source_paths(self) -> List[Path]:
return files

compilers = self.compiler_manager.registered_compilers
for extension in compilers:
ext = extension.replace(".", "\\.")
for compiler in compilers.values():
ext = compiler.extension.replace(".", "\\.")
pattern = rf"[\w|-]+{ext}"
ext_files = get_all_files_in_directory(self.contracts_folder, pattern=pattern)
files.extend(ext_files)
Expand Down
8 changes: 4 additions & 4 deletions src/ape/plugins/compiler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Type
from typing import Type

from ape.api import CompilerAPI

Expand All @@ -13,7 +13,7 @@ class CompilerPlugin(PluginType):
"""

@hookspec
def register_compiler(self) -> Tuple[Tuple[str], Type[CompilerAPI]]: # type: ignore[empty-body]
def register_compiler(self) -> Type[CompilerAPI]: # type: ignore[empty-body]
"""
A hook for returning the set of file extensions the plugin handles
and the compiler class that can be used to compile them.
Expand All @@ -22,8 +22,8 @@ def register_compiler(self) -> Tuple[Tuple[str], Type[CompilerAPI]]: # type: ig
@plugins.register(plugins.CompilerPlugin)
def register_compiler():
return (".json",), InterfaceCompiler
return InterfaceCompiler
Returns:
Tuple[Tuple[str], Type[:class:`~ape.api.CompilerAPI`]]
Type[:class:`~ape.api.CompilerAPI`]
"""
5 changes: 3 additions & 2 deletions src/ape/pytest/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ def _init_coverage_profile(
for src in self.sources:
source_cov = project_coverage.include(src)
ext = Path(src.source_id).suffix
if ext not in self.compiler_manager.registered_compilers:
if ext not in self.compiler_manager.supported_extensions:
continue

compiler = self.compiler_manager.registered_compilers[ext]
compiler = self.compiler_manager.get_compiler(ext)
assert compiler is not None
try:
compiler.init_coverage_profile(source_cov, src)
except NotImplementedError:
Expand Down
5 changes: 3 additions & 2 deletions src/ape/types/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,10 +503,11 @@ def create(cls, contract_type: ContractType, trace: Iterator[TraceFrame], data:
return cls.parse_obj([])

ext = f".{source_id.split('.')[-1]}"
if ext not in accessor.compiler_manager.registered_compilers:
if ext not in accessor.compiler_manager.supported_extensions:
return cls.parse_obj([])

compiler = accessor.compiler_manager.registered_compilers[ext]
compiler = accessor.compiler_manager.get_compiler(ext)
assert compiler is not None
try:
return compiler.trace_source(contract_type, trace, data)
except NotImplementedError:
Expand Down

0 comments on commit dcf62c0

Please sign in to comment.