diff --git a/rechunker/executors/__init__.py b/rechunker/executors/__init__.py
index a4f5364..9d0d29c 100644
--- a/rechunker/executors/__init__.py
+++ b/rechunker/executors/__init__.py
@@ -15,3 +15,11 @@
     __all__.append("PrefectPipelineExecutor")
 except ImportError:
     pass
+
+try:
+    from .beam import BeamExecutor, BeamPipelineExecutor
+
+    __all__.append("BeamExecutor")
+    __all__.append("BeamPipelineExecutor")
+except ImportError:
+    pass
diff --git a/rechunker/executors/beam.py b/rechunker/executors/beam.py
index a2d51b0..d34a001 100644
--- a/rechunker/executors/beam.py
+++ b/rechunker/executors/beam.py
@@ -8,7 +8,48 @@
     chunk_keys,
     split_into_direct_copies,
 )
-from rechunker.types import CopySpec, CopySpecExecutor, ReadableArray, WriteableArray
+from rechunker.types import (
+    CopySpec,
+    CopySpecExecutor,
+    MultiStagePipeline,
+    ParallelPipelines,
+    PipelineExecutor,
+    ReadableArray,
+    StatefulMultiStagePipeline,
+    WriteableArray,
+)
+
+
+class BeamPipelineExecutor(PipelineExecutor[beam.PTransform]):
+    """An execution engine based on Apache Beam.
+
+    Execution plans for BeamPipelineExecutors are beam.PTransform objects.
+    """
+
+    def pipelines_to_plan(self, pipelines: ParallelPipelines) -> beam.PTransform:
+        return (
+            "Create Pipelines" >> beam.Create(pipelines)
+            | "Execute Parallel Pipelines" >> beam.Map(execute_pipeline)
+        )
+
+    def execute_plan(self, plan: beam.PTransform, **kwargs):
+        with beam.Pipeline(**kwargs) as pipeline:
+            pipeline | plan
+
+
+def execute_pipeline(pipeline: MultiStagePipeline) -> None:
+    # A bare iterable of stages carries no shared context to forward.
+    if isinstance(pipeline, StatefulMultiStagePipeline):
+        stages, kwargs = pipeline.stages, {"context": pipeline.context}
+    else:
+        stages, kwargs = pipeline, {}
+    for stage in stages:
+        if stage.map_args is None:
+            stage.func(**kwargs)
+        else:
+            # map_args is mapped over the function: one call per argument.
+            for arg in stage.map_args:
+                stage.func(arg, **kwargs)
 
 
 class BeamExecutor(CopySpecExecutor[beam.PTransform]):
diff --git a/rechunker/types.py b/rechunker/types.py
index aa44431..6b9caa3 100644
--- a/rechunker/types.py
+++ b/rechunker/types.py
@@ -2,12 +2,14 @@
 from typing import (
     Any,
     Callable,
+    Dict,
     Generic,
     Iterable,
     NamedTuple,
     Optional,
     Tuple,
     TypeVar,
+    Union,
 )
 
 # TODO: replace with Protocols, once Python 3.8+ is required
@@ -57,7 +59,7 @@ class CopySpec(NamedTuple):
 
 
 class Stage(NamedTuple):
-    """A Stage is when single function is mapped over multiple imputs.
+    """A Stage is when a single function is mapped over multiple inputs.
 
     Attributes
     ----------
@@ -66,15 +68,28 @@
     map_args: List, Optional
         Arguments which will be mapped to the function
     """
-
     func: Callable
     map_args: Optional[Iterable] = None
     # TODO: figure out how to make optional, like for a dataclass
     # annotations: Dict = {}
 
 
+class StatefulMultiStagePipeline(NamedTuple):
+    """A pipeline where each stage shares the same context.
+
+    Attributes
+    ----------
+    stages: Iterable[Stage]
+        The stages to execute, in sequence.
+    context: Dict, Optional
+        Named shared state for all stages.
+    """
+    stages: Iterable[Stage]
+    context: Optional[Dict] = None
+
+
 # A MultiStagePipeline contains one or more stages, to be executed in sequence
-MultiStagePipeline = Iterable[Stage]
+MultiStagePipeline = Union[Iterable[Stage], StatefulMultiStagePipeline]
 
 # ParallelPipelines contains one or more MultiStagePipelines, to be executed in parallel
 ParallelPipelines = Iterable[MultiStagePipeline]
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index e19da54..e81b696 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -10,6 +10,7 @@
+from rechunker.executors.beam import BeamPipelineExecutor
 from rechunker.executors.dask import DaskPipelineExecutor
 from rechunker.executors.prefect import PrefectPipelineExecutor
 from rechunker.executors.python import PythonPipelineExecutor
 from rechunker.types import Stage
 
 
@@ -37,7 +38,13 @@ def func1(arg):
 
 
 @pytest.mark.parametrize(
-    "Executor", [PythonPipelineExecutor, DaskPipelineExecutor, PrefectPipelineExecutor]
+    "Executor",
+    [
+        PythonPipelineExecutor,
+        DaskPipelineExecutor,
+        PrefectPipelineExecutor,
+        BeamPipelineExecutor,
+    ],
 )
 def test_pipeline(example_pipeline, Executor):
     pipeline, tmpdir = example_pipeline