Generalizing jobs argument implementation in Namespace (#367)
thirtytwobits authored Jan 9, 2025
1 parent 453f667 · commit 91bb087
Showing 1 changed file with 93 additions and 52 deletions.
145 changes: 93 additions & 52 deletions src/nunavut/_namespace.py
@@ -48,6 +48,7 @@
"""

import collections
import itertools
import multiprocessing
import multiprocessing.pool
import sys
@@ -65,7 +66,9 @@
KeysView,
List,
Optional,
Protocol,
Tuple,
TypeVar,
Union,
cast,
)
@@ -84,6 +87,86 @@ def _register(self, cls, method=None): # type: ignore

singledispatchmethod.register = _register # type: ignore

# +--------------------------------------------------------------------------------------------------------------------+


class AsyncResultProtocol(Protocol):
"""
Defines the protocol for a duck-type compatible with multiprocessing.pool.AsyncResult.
"""

def get(self, timeout: Optional[Any] = None) -> Any:
"""
See multiprocessing.pool.AsyncResult.get
"""


class NotAsyncResult:
"""
Duck-type compatible with multiprocessing.pool.AsyncResult that is not actually asynchronous. All work is performed
synchronously when the get method is called.
"""

    def __init__(self, read_method: Callable[..., Any], args: Iterable[Any]) -> None:
self.read_method = read_method
self.args = args

    def get(self, timeout: Optional[Any] = None) -> Any:
        """
        Perform the work synchronously. The timeout argument is accepted for signature
        compatibility with multiprocessing.pool.AsyncResult.get but is ignored.
        """
        return self.read_method(*self.args)

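# Illustrative sketch, not part of this commit: a NotAsyncResult and a real
# multiprocessing AsyncResult can be driven through the same ``get`` contract,
# which is what AsyncResultProtocol captures. The ``_double`` worker below is
# assumed purely for demonstration.
def _double(x: int) -> int:
    return x * 2


if __name__ == "__main__":
    sync_result: AsyncResultProtocol = NotAsyncResult(_double, (21,))
    assert sync_result.get() == 42  # the work runs here, in the calling process
    with multiprocessing.pool.Pool(processes=1) as pool:
        async_result: AsyncResultProtocol = pool.apply_async(_double, (21,))
        assert async_result.get(timeout=5) == 42  # the work ran in a worker process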

ApplyMethodT = TypeVar("ApplyMethodT", bound=Callable[..., AsyncResultProtocol])


def _read_files_strategy(
index: "Namespace",
apply_method: ApplyMethodT,
dsdl_files: Union[Path, str, Iterable[Union[Path, str]]],
job_timeout_seconds: float,
omit_dependencies: bool,
args: Iterable[Any],
) -> "Namespace":
"""
Strategy for reading a set of dsdl files and building a namespace tree. This strategy is compatible with both
synchronous and asynchronous invocation of the pydsdl.read_files method.
"""
if isinstance(dsdl_files, (str, Path)):
fileset = {Path(dsdl_files)}
else:
fileset = {Path(file) for file in dsdl_files}
resolve_cache: dict[Path, Path] = {}

def _resolve_file(file: Path) -> Path:
        # Limit filesystem access by caching resolved paths. This assumes a given
        # non-canonical path form is used consistently; if it is not, the cache saves
        # no filesystem access but the result is still correct.
if file not in resolve_cache:
resolve_cache[file] = file.resolve()
return resolve_cache[file]

running_lookups: list[AsyncResultProtocol] = []
already_read: set[Path] = set()
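    # Breadth-first batches: dispatch a lookup for every file currently queued; once the
    # queue drains, harvest the batch's results and queue any newly discovered
    # dependencies that have not been read yet.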
while fileset:
next_file = fileset.pop()
running_lookups.append(apply_method(pydsdl.read_files, args=itertools.chain([next_file], args)))
already_read.add(_resolve_file(next_file))
if not fileset:
for lookup in running_lookups:
if job_timeout_seconds <= 0:
target_type, dependent_types = lookup.get()
else:
target_type, dependent_types = lookup.get(timeout=job_timeout_seconds)
Namespace.add_types(index, (target_type[0], dependent_types))
if not omit_dependencies:
for dependent_type in dependent_types:
if _resolve_file(dependent_type.source_file_path) not in already_read:
fileset.add(dependent_type.source_file_path)
running_lookups.clear()

return index


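# Hypothetical sketch, not part of this commit: read_files (below) maps its ``jobs``
# argument onto the apply method handed to _read_files_strategy. The helper name
# ``_apply_method_for`` is assumed for this illustration only.
def _apply_method_for(
    jobs: int, pool: Optional[multiprocessing.pool.Pool]
) -> Callable[..., AsyncResultProtocol]:
    if jobs == 1:
        # Synchronous path: NotAsyncResult defers the read until .get() is called.
        return NotAsyncResult
    # Asynchronous path: pool.apply_async returns a real AsyncResult. Callers create
    # the pool with processes=None when jobs == 0 (one worker per CPU).
    assert pool is not None
    return pool.apply_async
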
# +--------------------------------------------------------------------------------------------------------------------+
class Generatable(type(Path())): # type: ignore
@@ -480,62 +563,20 @@ def read_files(
if not index.is_index:
raise ValueError("Namespace passed in as index argument is not an index namespace.")

if isinstance(dsdl_files, (str, Path)):
fileset = {Path(dsdl_files)}
else:
fileset = {Path(file) for file in dsdl_files}

already_read: set[Path] = set()

args = (
root_namespace_directories_or_names,
lookup_directories,
print_output_handler,
allow_unregulated_fixed_port_id,
)
if jobs == 1:
# Don't use multiprocessing when jobs is 1.
while fileset:
next_file = fileset.pop()
target_type, dependent_types = pydsdl.read_files(
next_file,
root_namespace_directories_or_names,
lookup_directories,
print_output_handler,
allow_unregulated_fixed_port_id,
)
already_read.add(next_file) # TODO: canonical paths for keying here?
Namespace.add_types(index, (target_type[0], dependent_types))
if not omit_dependencies:
for dependent_type in dependent_types:
if dependent_type.source_file_path not in already_read:
fileset.add(dependent_type.source_file_path)
return _read_files_strategy(index, NotAsyncResult, dsdl_files, job_timeout_seconds, omit_dependencies, args)
else:
running_lookups: list[multiprocessing.pool.AsyncResult] = []
with multiprocessing.pool.Pool(processes=None if jobs == 0 else jobs) as pool:
while fileset:
next_file = fileset.pop()
running_lookups.append(
pool.apply_async(
pydsdl.read_files,
args=(
next_file,
root_namespace_directories_or_names,
lookup_directories,
print_output_handler,
allow_unregulated_fixed_port_id,
),
)
)
already_read.add(next_file) # TODO: canonical paths for keying here?
if not fileset:
for lookup in running_lookups:
if job_timeout_seconds <= 0:
target_type, dependent_types = lookup.get()
else:
target_type, dependent_types = lookup.get(timeout=job_timeout_seconds)
Namespace.add_types(index, (target_type[0], dependent_types))
if not omit_dependencies:
for dependent_type in dependent_types:
if dependent_type.source_file_path not in already_read:
fileset.add(dependent_type.source_file_path)
running_lookups.clear()

return index
return _read_files_strategy(
index, pool.apply_async, dsdl_files, job_timeout_seconds, omit_dependencies, args
)

@read_files.register
@classmethod
