From 541103579126f6860a5b1e23a87f1b195422b4cd Mon Sep 17 00:00:00 2001 From: Yun Wu Date: Mon, 19 Aug 2024 13:25:41 -0700 Subject: [PATCH] Add skip decorator; A few clean ups --- metaflow/plugins/__init__.py | 7 ++- metaflow/plugins/aip/aip.py | 6 +-- .../plugins/aip/interruptible_decorator.py | 2 +- metaflow/plugins/aip/s3_sensor_decorator.py | 1 - metaflow/plugins/aip/skip_decorator.py | 52 +++++++++++++++++++ .../plugins/aip/tests/flows/resources_flow.py | 1 - metaflow/plugins/aip/tests/flows/skip_flow.py | 41 +++++++++++++++ 7 files changed, 102 insertions(+), 8 deletions(-) create mode 100644 metaflow/plugins/aip/skip_decorator.py create mode 100644 metaflow/plugins/aip/tests/flows/skip_flow.py diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index 2a3923e92f2..00a5bf2d106 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -116,7 +116,8 @@ def get_plugin_cli(): from .frameworks.pytorch import PytorchParallelDecorator from .aip.aip_decorator import AIPInternalDecorator from .aip.accelerator_decorator import AcceleratorDecorator -from .aip.interruptible_decorator import interruptibleDecorator +from .aip.interruptible_decorator import InterruptibleDecorator +from .aip.skip_decorator import SkipDecorator STEP_DECORATORS = [ @@ -134,8 +135,9 @@ def get_plugin_cli(): PytorchParallelDecorator, InternalTestUnboundedForeachDecorator, AcceleratorDecorator, - interruptibleDecorator, + InterruptibleDecorator, AIPInternalDecorator, + SkipDecorator, ] _merge_lists(STEP_DECORATORS, _ext_plugins["STEP_DECORATORS"], "name") @@ -159,6 +161,7 @@ def get_plugin_cli(): from .aws.step_functions.schedule_decorator import ScheduleDecorator from .project_decorator import ProjectDecorator from .aip.s3_sensor_decorator import S3SensorDecorator + from .aip.exit_handler_decorator import ExitHandlerDecorator FLOW_DECORATORS = [ diff --git a/metaflow/plugins/aip/aip.py b/metaflow/plugins/aip/aip.py index 1d81a697e65..f7db99012a5 100644 --- a/metaflow/plugins/aip/aip.py +++ b/metaflow/plugins/aip/aip.py @@ -58,7 +58,7 @@ from metaflow.plugins.aip.aip_decorator import AIPException from .accelerator_decorator import AcceleratorDecorator from .argo_client import ArgoClient -from .interruptible_decorator import interruptibleDecorator +from .interruptible_decorator import InterruptibleDecorator from .aip_foreach_splits import graph_to_task_ids from ..aws.batch.batch_decorator import BatchDecorator from ..aws.step_functions.schedule_decorator import ScheduleDecorator @@ -106,7 +106,7 @@ def __init__( resource_requirements: Dict[str, str], aip_decorator: AIPInternalDecorator, accelerator_decorator: AcceleratorDecorator, - interruptible_decorator: interruptibleDecorator, + interruptible_decorator: InterruptibleDecorator, environment_decorator: EnvironmentDecorator, total_retries: int, minutes_between_retries: str, @@ -741,7 +741,7 @@ def build_aip_component(node: DAGNode, task_id: str) -> AIPComponent: ( deco for deco in node.decorators - if isinstance(deco, interruptibleDecorator) + if isinstance(deco, InterruptibleDecorator) ), None, # default ), diff --git a/metaflow/plugins/aip/interruptible_decorator.py b/metaflow/plugins/aip/interruptible_decorator.py index d316793f8c3..4596d14ae67 100644 --- a/metaflow/plugins/aip/interruptible_decorator.py +++ b/metaflow/plugins/aip/interruptible_decorator.py @@ -11,7 +11,7 @@ def _get_ec2_metadata(path: str) -> Optional[str]: return response.text -class interruptibleDecorator(StepDecorator): +class InterruptibleDecorator(StepDecorator): """ For AIP orchestrator plugin only. diff --git a/metaflow/plugins/aip/s3_sensor_decorator.py b/metaflow/plugins/aip/s3_sensor_decorator.py index 8fbdbf029bd..8e0ab1008b5 100644 --- a/metaflow/plugins/aip/s3_sensor_decorator.py +++ b/metaflow/plugins/aip/s3_sensor_decorator.py @@ -1,5 +1,4 @@ from types import FunctionType -from typing import Tuple from urllib.parse import urlparse from metaflow.decorators import FlowDecorator diff --git a/metaflow/plugins/aip/skip_decorator.py b/metaflow/plugins/aip/skip_decorator.py new file mode 100644 index 00000000000..378b0d30d41 --- /dev/null +++ b/metaflow/plugins/aip/skip_decorator.py @@ -0,0 +1,52 @@ +# Skip decorator is a workaround solution to implement conditional branching in metaflow. +# When condition variable is_skipping is evaluated to True, +# it will skip current step and execute the supplied next step. + +from functools import wraps +from metaflow.decorators import StepDecorator + + +class SkipDecorator(StepDecorator): + """ + The @skip decorator is a workaround for conditional branching. The @skip decorator checks an artifact + and if it is false, skips the evaluation of the step function and jumps to the supplied next step. + + **The `start` and `end` steps are always expected and should not be skipped.** + + Usage: + class SkipFlow(FlowSpec): + + condition = Parameter("condition", default=False) + + @step + def start(self): + print("Should skip:", self.condition) + self.next(self.middle) + + @skip(check='condition', next='end') + @step + def middle(self): + print("Running the middle step - not skipping") + self.next(self.end) + + @step + def end(self): + pass + """ + + name = "skip" + + def __init__(self, check="", next=""): + super().__init__() + self.check = check + self.next = next + + def __call__(self, f): + @wraps(f) + def func(step): + if getattr(step, self.check): + step.next(getattr(step, self.next)) + else: + return f(step) + + return func diff --git a/metaflow/plugins/aip/tests/flows/resources_flow.py b/metaflow/plugins/aip/tests/flows/resources_flow.py index f8b5cbe234a..90e576e0435 100644 --- a/metaflow/plugins/aip/tests/flows/resources_flow.py +++ b/metaflow/plugins/aip/tests/flows/resources_flow.py @@ -1,7 +1,6 @@ import os import pprint import subprocess -import time from typing import Dict, List from multiprocessing.shared_memory import SharedMemory diff --git a/metaflow/plugins/aip/tests/flows/skip_flow.py b/metaflow/plugins/aip/tests/flows/skip_flow.py new file mode 100644 index 00000000000..23ae936f55a --- /dev/null +++ b/metaflow/plugins/aip/tests/flows/skip_flow.py @@ -0,0 +1,41 @@ +from metaflow import Parameter, FlowSpec, step, skip + + +class SkipFlow(FlowSpec): + + condition_true = Parameter("condition-true", default=True) + + @step + def start(self): + print("Should skip:", self.condition) + self.desired_step_executed = False + self.condition_false = False + self.next(self.skipped_step) + + @skip(check="condition_true", next="desired_step") + @step + def skipped_step(self): + raise Exception( + "Unexpectedly ran the skipped_step step. This step should have been skipped." + ) + self.next(self.unreachable) + + def unreachable(self): + raise Exception( + "Unexpectedly ran the unreachable step. This step should have been skipped." + ) + self.next(self.end) + + @skip(check="condition_false", next="end") + @step + def desired_step(self): + self.desired_step_executed = True + self.next(self.end) + + @step + def end(self): + assert self.desired_step_executed, "Desired step was not executed" + + +if __name__ == "__main__": + SkipFlow()