[Do not merge] Demo of serialization issues in Beam executor with current pipeline model. #99

Draft · wants to merge 2 commits into base: master
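What this draft demonstrates: in the BeamPipelineExecutor added below, beam.Create(pipelines) followed by beam.Map(execute_pipeline) forces Beam to serialize every pipeline element, including each Stage.func and the shared context. A minimal sketch of that concern, assuming a stage built from a closure (much as a test fixture might define); make_stage is a hypothetical helper, and the standard-library pickle is used for illustration even though Beam's own pickler (dill) handles more cases:

import pickle

from rechunker.types import Stage


def make_stage(path):
    # Hypothetical helper: the inner function closes over `path`,
    # so the standard pickle module cannot serialize it.
    def write_marker():
        with open(path, mode="w") as f:
            f.write("ok")

    return Stage(write_marker)


try:
    pickle.dumps(make_stage("/tmp/marker.txt"))
except (pickle.PicklingError, AttributeError) as err:
    print("stage did not survive pickling:", err)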
14 changes: 14 additions & 0 deletions rechunker/executors/__init__.py
@@ -15,3 +15,17 @@
    __all__.append("PrefectPipelineExecutor")
except ImportError:
    pass

try:
    from .beam import BeamExecutor

    __all__.append("BeamExecutor")
except ImportError:
    pass

try:
    from .beam import BeamPipelineExecutor

    __all__.append("BeamPipelineExecutor")
except ImportError:
    pass
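With the guarded imports above, the Beam executors appear in rechunker.executors only when apache_beam itself imports. A quick sanity check, assuming apache-beam is installed:

import rechunker.executors as executors

# Both names are appended to __all__ only if the "from .beam import ..."
# statements succeed; without apache-beam the ImportError is swallowed.
print("BeamExecutor" in executors.__all__)
print("BeamPipelineExecutor" in executors.__all__)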
65 changes: 49 additions & 16 deletions rechunker/executors/beam.py
@@ -8,7 +8,40 @@
    chunk_keys,
    split_into_direct_copies,
)
from rechunker.types import (
    CopySpec,
    CopySpecExecutor,
    ReadableArray,
    WriteableArray,
    PipelineExecutor,
    ParallelPipelines,
    StatefulMultiStagePipeline,
)


class BeamPipelineExecutor(PipelineExecutor[beam.PTransform]):
    """An execution engine based on Apache Beam.

    Execution plans for BeamPipelineExecutors are beam.PTransform objects.
    """

    def pipelines_to_plan(self, pipelines: ParallelPipelines) -> beam.PTransform:
        return (
            "Create Pipelines" >> beam.Create(pipelines)
            | "Execute Parallel Pipelines" >> beam.Map(execute_pipeline)
        )

    def execute_plan(self, plan: beam.PTransform, **kwargs):
        with beam.Pipeline(**kwargs) as pipeline:
            pipeline | plan


def execute_pipeline(pipeline: StatefulMultiStagePipeline) -> None:
    # Run stages in order; a mapped stage applies its function to each
    # element of map_args, passing the shared context every time.
    for stage in pipeline.stages:
        if stage.map_args is None:
            stage.func(context=pipeline.context)
        else:
            for arg in stage.map_args:
                stage.func(arg, context=pipeline.context)


class BeamExecutor(CopySpecExecutor[beam.PTransform]):
@@ -63,38 +63,38 @@ def __init__(self, specs_by_target: Mapping[str, DirectCopySpec]):

    def expand(self, pcoll):
        return (
            pcoll
            | "Start" >> beam.FlatMap(_start_stage, self.specs_by_target)
            | "CreateTasks" >> beam.FlatMapTuple(_copy_tasks)
            # prevent undesirable fusion
            # https://stackoverflow.com/a/54131856/809705
            | "Reshuffle" >> beam.Reshuffle()
            | "CopyChunks" >> beam.MapTuple(_copy_chunk)
            # prepare inputs for the next stage (if any)
            | "Finish" >> beam.Distinct()
        )


def _start_stage(
    target_id: str, specs_by_target: Mapping[str, DirectCopySpec],
) -> Iterator[Tuple[str, DirectCopySpec]]:
    spec = specs_by_target.get(target_id)
    if spec is not None:
        yield target_id, spec


def _copy_tasks(
    target_id: str, spec: DirectCopySpec
) -> Iterator[Tuple[str, Tuple[slice, ...], ReadableArray, WriteableArray]]:
    for key in chunk_keys(spec.source.shape, spec.chunks):
        yield target_id, key, spec.source, spec.target


def _copy_chunk(
    target_id: str,
    key: Tuple[slice, ...],
    source: ReadableArray,
    target: WriteableArray,
) -> str:
    target[key] = source[key]
    return target_id
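For reference, a sketch of how the new executor is meant to be driven end to end. The touch function here is hypothetical and deliberately module-level, so it pickles cleanly and sidesteps the closure problem this draft demonstrates:

import os
import tempfile

from rechunker.executors.beam import BeamPipelineExecutor
from rechunker.types import Stage, StatefulMultiStagePipeline

TMPDIR = tempfile.mkdtemp()


def touch(name, context):
    # Module-level function: picklable, unlike closures defined inside a test.
    with open(os.path.join(context["root"], name), mode="w") as f:
        f.write("1")


pipeline = StatefulMultiStagePipeline(
    stages=[Stage(touch, map_args=["a", "b", "c"])],
    context={"root": TMPDIR},
)

executor = BeamPipelineExecutor()
plan = executor.pipelines_to_plan([pipeline])
executor.execute_plan(plan)  # runs on Beam's default DirectRunner
assert all(os.path.exists(os.path.join(TMPDIR, n)) for n in "abc")

Note that a distributed runner would hand each worker its own deserialized copy of the context dict, so any mutation of shared state (unlike the read-only use here) would be lost; that is the statefulness problem at issue.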
21 changes: 18 additions & 3 deletions rechunker/types.py
@@ -2,12 +2,14 @@
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Iterable,
    NamedTuple,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

# TODO: replace with Protocols, once Python 3.8+ is required
@@ -57,7 +59,7 @@ class CopySpec(NamedTuple):


class Stage(NamedTuple):
    """A Stage is when a single function is mapped over multiple inputs.

    Attributes
    ----------
@@ -66,15 +68,28 @@ class Stage(NamedTuple):
    map_args: Iterable, Optional
        Arguments which will be mapped to the function
    """

    func: Callable
    map_args: Optional[Iterable] = None
    # TODO: figure out how to make optional, like for a dataclass
    # annotations: Dict = {}


class StatefulMultiStagePipeline(NamedTuple):
    """A pipeline where each stage shares the same context.

    Attributes
    ----------
    stages: Iterable[Stage]
        The stages to execute, in order.
    context: Dict, Optional
        Named shared state for all stages.
    """

    stages: Iterable[Stage]
    context: Optional[Dict] = None


# A MultiStagePipeline contains one or more stages, to be executed in sequence,
# either as a bare iterable of stages or as a StatefulMultiStagePipeline
MultiStagePipeline = Union[Iterable[Stage], StatefulMultiStagePipeline]

# ParallelPipelines contains one or more MultiStagePipelines, to be executed in parallel
ParallelPipelines = Iterable[MultiStagePipeline]
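To make the shared-context model concrete, a minimal sketch (the stage functions are illustrative, not part of this PR): the first stage seeds the context and the mapped second stage accumulates into it:

from rechunker.types import Stage, StatefulMultiStagePipeline


def prepare(context):
    # Stage with no map_args: called once, with only the shared context.
    context["total"] = 0


def accumulate(arg, context):
    # Mapped stage: called once per element of map_args.
    context["total"] += arg


pipeline = StatefulMultiStagePipeline(
    stages=[Stage(prepare), Stage(accumulate, map_args=[1, 2, 3])],
    context={},
)

Because stages communicate by mutating a plain dict, any executor that ships stages to other processes must serialize that dict and cannot propagate mutations back, which is the limitation this draft sets out to demonstrate.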
3 changes: 2 additions & 1 deletion tests/test_pipelines.py
@@ -10,6 +10,7 @@
from rechunker.executors.dask import DaskPipelineExecutor
from rechunker.executors.prefect import PrefectPipelineExecutor
from rechunker.executors.python import PythonPipelineExecutor
from rechunker.executors.beam import BeamPipelineExecutor
from rechunker.types import Stage


@@ -37,7 +38,7 @@ def func1(arg):


@pytest.mark.parametrize(
    "Executor",
    [PythonPipelineExecutor, DaskPipelineExecutor, PrefectPipelineExecutor, BeamPipelineExecutor],
)
def test_pipeline(example_pipeline, Executor):
    pipeline, tmpdir = example_pipeline