From b8fa45327ccd95653ed33205b6df4093f7bbe0fe Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Tue, 14 May 2024 10:45:45 -0700 Subject: [PATCH 1/3] Tag the original function with the decorator --- src/quacc/wflow_tools/customizers.py | 71 ++++++++++------------------ src/quacc/wflow_tools/decorators.py | 51 +++++++++++--------- tests/dask/test_syntax.py | 6 +-- tests/prefect/test_syntax.py | 6 +++ 4 files changed, 64 insertions(+), 70 deletions(-) diff --git a/src/quacc/wflow_tools/customizers.py b/src/quacc/wflow_tools/customizers.py index 03b14d4ba9..9f034832e6 100644 --- a/src/quacc/wflow_tools/customizers.py +++ b/src/quacc/wflow_tools/customizers.py @@ -4,7 +4,7 @@ from copy import deepcopy from functools import partial -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING from quacc.utils.dicts import recursive_dict_merge @@ -28,57 +28,42 @@ def strip_decorator(func: Callable) -> Callable: """ from quacc import SETTINGS - if SETTINGS.WORKFLOW_ENGINE == "covalent": - from covalent._workflow.lattice import Lattice + decorator = getattr(func, "quacc_decorator", None) + if decorator is None: + return func - if hasattr(func, "electron_object"): + if SETTINGS.WORKFLOW_ENGINE == "covalent": + if decorator in ("job", "subflow"): func = func.electron_object.function - if isinstance(func, Lattice): + if decorator in ("flow", "subflow"): func = func.workflow_function.get_deserialized() elif SETTINGS.WORKFLOW_ENGINE == "dask": - from dask.delayed import Delayed - - from quacc.wflow_tools.decorators import Delayed_ - - if isinstance(func, Delayed_): + if decorator == "job": func = func.func - if isinstance(func, Delayed): + func = func.__wrapped__ + if decorator == "subflow": func = func.__wrapped__ - if hasattr(func, "__wrapped__"): - # Needed for custom `@subflow` decorator - func = func.__wrapped__ elif SETTINGS.WORKFLOW_ENGINE == "jobflow": - if hasattr(func, "original"): - func = func.original + func = func.original elif SETTINGS.WORKFLOW_ENGINE == "parsl": - from parsl.app.python import PythonApp - - if isinstance(func, PythonApp): - func = func.func + func = func.func elif SETTINGS.WORKFLOW_ENGINE == "prefect": - from prefect import Flow as PrefectFlow - from prefect import Task - - if isinstance(func, (Task, PrefectFlow)): - func = func.fn - elif hasattr(func, "__wrapped__"): + if SETTINGS.PREFECT_AUTO_SUBMIT: func = func.__wrapped__ + func = func.fn elif SETTINGS.WORKFLOW_ENGINE == "redun": - from redun import Task - - if isinstance(func, Task): - func = func.func + func = func.func return func -def redecorate(func: Callable, decorator: Callable | None) -> Callable: +def redecorate(func: Callable, decorator: Callable) -> Callable: """ Redecorate a pre-decorated function with a custom decorator. @@ -87,8 +72,7 @@ def redecorate(func: Callable, decorator: Callable | None) -> Callable: func The pre-decorated function. decorator - The new decorator to apply. If `None`, the function is stripped of its - decorators. + The new decorator to apply. Returns ------- @@ -96,14 +80,10 @@ def redecorate(func: Callable, decorator: Callable | None) -> Callable: The newly decorated function. """ func = strip_decorator(func) - return func if decorator is None else decorator(func) + return decorator(func) -def update_parameters( - func: Callable, - params: dict[str, Any], - decorator: Literal["job", "flow", "subflow"] | None = "job", -) -> Callable: +def update_parameters(func: Callable, params: dict[str, Any]) -> Callable: """ Update the parameters of a (potentially decorated) function. @@ -113,8 +93,6 @@ def update_parameters( The function to update. params The parameters and associated values to update. - decorator - The decorator associated with `func`. Returns ------- @@ -123,12 +101,15 @@ def update_parameters( """ from quacc import SETTINGS, flow, job, subflow - if decorator and SETTINGS.WORKFLOW_ENGINE == "dask": - if decorator == "job": + if ( + decorators_type := hasattr(func, "quacc_decorator") + and SETTINGS.WORKFLOW_ENGINE == "dask" + ): + if decorators_type == "job": decorator = job - elif decorator == "flow": + elif decorators_type == "flow": decorator = flow - elif decorator == "subflow": + elif decorators_type == "subflow": decorator = subflow func = strip_decorator(func) diff --git a/src/quacc/wflow_tools/decorators.py b/src/quacc/wflow_tools/decorators.py index 3794622480..c277cc6870 100644 --- a/src/quacc/wflow_tools/decorators.py +++ b/src/quacc/wflow_tools/decorators.py @@ -140,10 +140,10 @@ def add(a, b): if _func is None: return partial(job, **kwargs) - elif SETTINGS.WORKFLOW_ENGINE == "covalent": + if SETTINGS.WORKFLOW_ENGINE == "covalent": import covalent as ct - return ct.electron(_func, **kwargs) + decorated_func= ct.electron(_func, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "dask": from dask import delayed @@ -153,22 +153,22 @@ def add(a, b): def wrapper(*f_args, **f_kwargs): return _func(*f_args, **f_kwargs) - return Delayed_(delayed(wrapper, **kwargs)) + decorated_func= Delayed_(delayed(wrapper, **kwargs)) elif SETTINGS.WORKFLOW_ENGINE == "jobflow": from jobflow import job as jf_job - return jf_job(_func, **kwargs) + decorated_func= jf_job(_func, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "parsl": from parsl import python_app wrapped_fn = _get_parsl_wrapped_func(_func, kwargs) - return python_app(wrapped_fn, **kwargs) + decorated_func= python_app(wrapped_fn, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "redun": from redun import task - return task(_func, namespace=_func.__module__, **kwargs) + decorated_func= task(_func, namespace=_func.__module__, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "prefect": from prefect import task @@ -178,12 +178,14 @@ def wrapper(*f_args, **f_kwargs): def wrapper(*f_args, **f_kwargs): decorated = task(_func, **kwargs) return decorated.submit(*f_args, **f_kwargs) - - return wrapper + decorated_func = wrapper else: - return task(_func, **kwargs) + decorated_func= task(_func, **kwargs) else: - return _func + decorated_func= _func + + decorated_func.quacc_decorator = "job" + return decorated_func def flow(_func: Callable | None = None, **kwargs) -> Flow: @@ -333,20 +335,23 @@ def workflow(a, b, c): if _func is None: return partial(flow, **kwargs) - elif SETTINGS.WORKFLOW_ENGINE == "covalent": + if SETTINGS.WORKFLOW_ENGINE == "covalent": import covalent as ct - return ct.lattice(_func, **kwargs) + decorated_func= ct.lattice(_func, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "redun": from redun import task - return task(_func, namespace=_func.__module__, **kwargs) + decorated_func= task(_func, namespace=_func.__module__, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "prefect": from prefect import flow as prefect_flow - return prefect_flow(_func, validate_parameters=False, **kwargs) + decorated_func= prefect_flow(_func, validate_parameters=False, **kwargs) else: - return _func + decorated_func= _func + + decorated_func.quacc_decorator = "flow" + return decorated_func def subflow(_func: Callable | None = None, **kwargs) -> Subflow: @@ -547,10 +552,10 @@ def workflow(a, b, c): if _func is None: return partial(subflow, **kwargs) - elif SETTINGS.WORKFLOW_ENGINE == "covalent": + if SETTINGS.WORKFLOW_ENGINE == "covalent": import covalent as ct - return ct.electron(ct.lattice(_func), **kwargs) + decorated_func= ct.electron(ct.lattice(_func), **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "dask": from dask import delayed from dask.distributed import worker_client @@ -563,24 +568,26 @@ def wrapper(*f_args, **f_kwargs): futures = client.compute(_func(*f_args, **f_kwargs)) return client.gather(futures) - return delayed(wrapper, **kwargs) + decorated_func= delayed(wrapper, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "parsl": from parsl import join_app wrapped_fn = _get_parsl_wrapped_func(_func, kwargs) - return join_app(wrapped_fn, **kwargs) + decorated_func= join_app(wrapped_fn, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "prefect": from prefect import flow as prefect_flow - return prefect_flow(_func, validate_parameters=False, **kwargs) + decorated_func= prefect_flow(_func, validate_parameters=False, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "redun": from redun import task - return task(_func, namespace=_func.__module__, **kwargs) + decorated_func= task(_func, namespace=_func.__module__, **kwargs) else: - return _func + decorated_func= _func + decorated_func.quacc_decorator = "subflow" + return decorated_func def _get_parsl_wrapped_func( func: Callable, decorator_kwargs: dict[str, Any] diff --git a/tests/dask/test_syntax.py b/tests/dask/test_syntax.py index 3923212c22..8dcd567318 100644 --- a/tests/dask/test_syntax.py +++ b/tests/dask/test_syntax.py @@ -162,12 +162,12 @@ def dynamic_workflow(a, b, c=1): def test_dynamic_workflow(a, b, c=3): result1 = add(a, b) result2 = make_more(result1) - return update_parameters(add_distributed, {"d": 1}, decorator="subflow")( + return update_parameters(add_distributed, {"d": 1})( result2, c ) - add_ = update_parameters(add, {"b": 3}, decorator="job") - dynamic_workflow_ = update_parameters(dynamic_workflow, {"c": 4}, decorator="flow") + add_ = update_parameters(add, {"b": 3}) + dynamic_workflow_ = update_parameters(dynamic_workflow, {"c": 4}) assert client.compute(add_(1)).result() == 4 assert client.compute(dynamic_workflow_(1, 2)).result() == [7, 7, 7] assert client.compute(test_dynamic_workflow(1, 2)).result() == [7, 7, 7] diff --git a/tests/prefect/test_syntax.py b/tests/prefect/test_syntax.py index cae89c60b1..d8005dc940 100644 --- a/tests/prefect/test_syntax.py +++ b/tests/prefect/test_syntax.py @@ -4,6 +4,9 @@ prefect = pytest.importorskip("prefect") +from prefect import Flow as PrefectFlow +from prefect import Task + from quacc import change_settings, flow, job, strip_decorator, subflow @@ -178,9 +181,12 @@ def add3(a, b): stripped_add = strip_decorator(add) assert stripped_add(1, 2) == 3 + assert not isinstance(stripped_add, Task) stripped_add2 = strip_decorator(add2) assert stripped_add2(1, 2) == 3 + assert not isinstance(stripped_add2,PrefectFlow) stripped_add3 = strip_decorator(add3) assert stripped_add3(1, 2) == 3 + assert not isinstance(stripped_add3,PrefectFlow) From 271e361798a820d06917b0cab2c5acd807e35a03 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 May 2024 17:50:40 +0000 Subject: [PATCH 2/3] pre-commit auto-fixes --- src/quacc/wflow_tools/decorators.py | 36 +++++++++++++++-------------- tests/dask/test_syntax.py | 4 +--- tests/prefect/test_syntax.py | 4 ++-- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/quacc/wflow_tools/decorators.py b/src/quacc/wflow_tools/decorators.py index c277cc6870..27e61d9822 100644 --- a/src/quacc/wflow_tools/decorators.py +++ b/src/quacc/wflow_tools/decorators.py @@ -143,7 +143,7 @@ def add(a, b): if SETTINGS.WORKFLOW_ENGINE == "covalent": import covalent as ct - decorated_func= ct.electron(_func, **kwargs) + decorated_func = ct.electron(_func, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "dask": from dask import delayed @@ -153,22 +153,22 @@ def add(a, b): def wrapper(*f_args, **f_kwargs): return _func(*f_args, **f_kwargs) - decorated_func= Delayed_(delayed(wrapper, **kwargs)) + decorated_func = Delayed_(delayed(wrapper, **kwargs)) elif SETTINGS.WORKFLOW_ENGINE == "jobflow": from jobflow import job as jf_job - decorated_func= jf_job(_func, **kwargs) + decorated_func = jf_job(_func, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "parsl": from parsl import python_app wrapped_fn = _get_parsl_wrapped_func(_func, kwargs) - decorated_func= python_app(wrapped_fn, **kwargs) + decorated_func = python_app(wrapped_fn, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "redun": from redun import task - decorated_func= task(_func, namespace=_func.__module__, **kwargs) + decorated_func = task(_func, namespace=_func.__module__, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "prefect": from prefect import task @@ -178,11 +178,12 @@ def wrapper(*f_args, **f_kwargs): def wrapper(*f_args, **f_kwargs): decorated = task(_func, **kwargs) return decorated.submit(*f_args, **f_kwargs) + decorated_func = wrapper else: - decorated_func= task(_func, **kwargs) + decorated_func = task(_func, **kwargs) else: - decorated_func= _func + decorated_func = _func decorated_func.quacc_decorator = "job" return decorated_func @@ -338,17 +339,17 @@ def workflow(a, b, c): if SETTINGS.WORKFLOW_ENGINE == "covalent": import covalent as ct - decorated_func= ct.lattice(_func, **kwargs) + decorated_func = ct.lattice(_func, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "redun": from redun import task - decorated_func= task(_func, namespace=_func.__module__, **kwargs) + decorated_func = task(_func, namespace=_func.__module__, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "prefect": from prefect import flow as prefect_flow - decorated_func= prefect_flow(_func, validate_parameters=False, **kwargs) + decorated_func = prefect_flow(_func, validate_parameters=False, **kwargs) else: - decorated_func= _func + decorated_func = _func decorated_func.quacc_decorator = "flow" return decorated_func @@ -555,7 +556,7 @@ def workflow(a, b, c): if SETTINGS.WORKFLOW_ENGINE == "covalent": import covalent as ct - decorated_func= ct.electron(ct.lattice(_func), **kwargs) + decorated_func = ct.electron(ct.lattice(_func), **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "dask": from dask import delayed from dask.distributed import worker_client @@ -568,27 +569,28 @@ def wrapper(*f_args, **f_kwargs): futures = client.compute(_func(*f_args, **f_kwargs)) return client.gather(futures) - decorated_func= delayed(wrapper, **kwargs) + decorated_func = delayed(wrapper, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "parsl": from parsl import join_app wrapped_fn = _get_parsl_wrapped_func(_func, kwargs) - decorated_func= join_app(wrapped_fn, **kwargs) + decorated_func = join_app(wrapped_fn, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "prefect": from prefect import flow as prefect_flow - decorated_func= prefect_flow(_func, validate_parameters=False, **kwargs) + decorated_func = prefect_flow(_func, validate_parameters=False, **kwargs) elif SETTINGS.WORKFLOW_ENGINE == "redun": from redun import task - decorated_func= task(_func, namespace=_func.__module__, **kwargs) + decorated_func = task(_func, namespace=_func.__module__, **kwargs) else: - decorated_func= _func + decorated_func = _func decorated_func.quacc_decorator = "subflow" return decorated_func + def _get_parsl_wrapped_func( func: Callable, decorator_kwargs: dict[str, Any] ) -> Callable: diff --git a/tests/dask/test_syntax.py b/tests/dask/test_syntax.py index 8dcd567318..0f63c210cb 100644 --- a/tests/dask/test_syntax.py +++ b/tests/dask/test_syntax.py @@ -162,9 +162,7 @@ def dynamic_workflow(a, b, c=1): def test_dynamic_workflow(a, b, c=3): result1 = add(a, b) result2 = make_more(result1) - return update_parameters(add_distributed, {"d": 1})( - result2, c - ) + return update_parameters(add_distributed, {"d": 1})(result2, c) add_ = update_parameters(add, {"b": 3}) dynamic_workflow_ = update_parameters(dynamic_workflow, {"c": 4}) diff --git a/tests/prefect/test_syntax.py b/tests/prefect/test_syntax.py index d8005dc940..0397cf2b79 100644 --- a/tests/prefect/test_syntax.py +++ b/tests/prefect/test_syntax.py @@ -185,8 +185,8 @@ def add3(a, b): stripped_add2 = strip_decorator(add2) assert stripped_add2(1, 2) == 3 - assert not isinstance(stripped_add2,PrefectFlow) + assert not isinstance(stripped_add2, PrefectFlow) stripped_add3 = strip_decorator(add3) assert stripped_add3(1, 2) == 3 - assert not isinstance(stripped_add3,PrefectFlow) + assert not isinstance(stripped_add3, PrefectFlow) From 34dbfab3f83b18ff7d03225b3aa27eb95dcf4cb6 Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Tue, 14 May 2024 10:52:19 -0700 Subject: [PATCH 3/3] fix --- src/quacc/wflow_tools/customizers.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/quacc/wflow_tools/customizers.py b/src/quacc/wflow_tools/customizers.py index 9f034832e6..8957372163 100644 --- a/src/quacc/wflow_tools/customizers.py +++ b/src/quacc/wflow_tools/customizers.py @@ -34,33 +34,33 @@ def strip_decorator(func: Callable) -> Callable: if SETTINGS.WORKFLOW_ENGINE == "covalent": if decorator in ("job", "subflow"): - func = func.electron_object.function + func_ = func.electron_object.function if decorator in ("flow", "subflow"): - func = func.workflow_function.get_deserialized() + func_ = func.workflow_function.get_deserialized() elif SETTINGS.WORKFLOW_ENGINE == "dask": if decorator == "job": func = func.func - func = func.__wrapped__ + func_ = func.__wrapped__ if decorator == "subflow": - func = func.__wrapped__ + func_ = func.__wrapped__ elif SETTINGS.WORKFLOW_ENGINE == "jobflow": - func = func.original + func_ = func.original elif SETTINGS.WORKFLOW_ENGINE == "parsl": - func = func.func + func_ = func.func elif SETTINGS.WORKFLOW_ENGINE == "prefect": if SETTINGS.PREFECT_AUTO_SUBMIT: func = func.__wrapped__ - func = func.fn + func_ = func.fn elif SETTINGS.WORKFLOW_ENGINE == "redun": - func = func.func + func_ = func.func - return func + return func_ def redecorate(func: Callable, decorator: Callable) -> Callable: