Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tag the original function with the decorator type #2146

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 30 additions & 49 deletions src/quacc/wflow_tools/customizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from copy import deepcopy
from functools import partial
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING

from quacc.utils.dicts import recursive_dict_merge

Expand All @@ -28,57 +28,42 @@ def strip_decorator(func: Callable) -> Callable:
"""
from quacc import SETTINGS

if SETTINGS.WORKFLOW_ENGINE == "covalent":
from covalent._workflow.lattice import Lattice
decorator = getattr(func, "quacc_decorator", None)
if decorator is None:
return func

if hasattr(func, "electron_object"):
func = func.electron_object.function
if SETTINGS.WORKFLOW_ENGINE == "covalent":
if decorator in ("job", "subflow"):
func_ = func.electron_object.function

if isinstance(func, Lattice):
func = func.workflow_function.get_deserialized()
if decorator in ("flow", "subflow"):
func_ = func.workflow_function.get_deserialized()

elif SETTINGS.WORKFLOW_ENGINE == "dask":
from dask.delayed import Delayed

from quacc.wflow_tools.decorators import Delayed_

if isinstance(func, Delayed_):
if decorator == "job":
func = func.func
if isinstance(func, Delayed):
func = func.__wrapped__
if hasattr(func, "__wrapped__"):
# Needed for custom `@subflow` decorator
func = func.__wrapped__
func_ = func.__wrapped__
if decorator == "subflow":
func_ = func.__wrapped__

elif SETTINGS.WORKFLOW_ENGINE == "jobflow":
if hasattr(func, "original"):
func = func.original
func_ = func.original

elif SETTINGS.WORKFLOW_ENGINE == "parsl":
from parsl.app.python import PythonApp

if isinstance(func, PythonApp):
func = func.func
func_ = func.func

elif SETTINGS.WORKFLOW_ENGINE == "prefect":
from prefect import Flow as PrefectFlow
from prefect import Task

if isinstance(func, (Task, PrefectFlow)):
func = func.fn
elif hasattr(func, "__wrapped__"):
if SETTINGS.PREFECT_AUTO_SUBMIT:
func = func.__wrapped__
func_ = func.fn

elif SETTINGS.WORKFLOW_ENGINE == "redun":
from redun import Task

if isinstance(func, Task):
func = func.func
func_ = func.func

return func
return func_


def redecorate(func: Callable, decorator: Callable | None) -> Callable:
def redecorate(func: Callable, decorator: Callable) -> Callable:
"""
Redecorate a pre-decorated function with a custom decorator.

Expand All @@ -87,23 +72,18 @@ def redecorate(func: Callable, decorator: Callable | None) -> Callable:
func
The pre-decorated function.
decorator
The new decorator to apply. If `None`, the function is stripped of its
decorators.
The new decorator to apply.

Returns
-------
Callable
The newly decorated function.
"""
func = strip_decorator(func)
return func if decorator is None else decorator(func)
return decorator(func)


def update_parameters(
func: Callable,
params: dict[str, Any],
decorator: Literal["job", "flow", "subflow"] | None = "job",
) -> Callable:
def update_parameters(func: Callable, params: dict[str, Any]) -> Callable:
"""
Update the parameters of a (potentially decorated) function.

Expand All @@ -113,8 +93,6 @@ def update_parameters(
The function to update.
params
The parameters and associated values to update.
decorator
The decorator associated with `func`.

Returns
-------
Expand All @@ -123,12 +101,15 @@ def update_parameters(
"""
from quacc import SETTINGS, flow, job, subflow

if decorator and SETTINGS.WORKFLOW_ENGINE == "dask":
if decorator == "job":
if (
decorators_type := hasattr(func, "quacc_decorator")
and SETTINGS.WORKFLOW_ENGINE == "dask"
):
if decorators_type == "job":
decorator = job
elif decorator == "flow":
elif decorators_type == "flow":
decorator = flow
elif decorator == "subflow":
elif decorators_type == "subflow":
decorator = subflow

func = strip_decorator(func)
Expand Down
51 changes: 30 additions & 21 deletions src/quacc/wflow_tools/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ def add(a, b):
if _func is None:
return partial(job, **kwargs)

elif SETTINGS.WORKFLOW_ENGINE == "covalent":
if SETTINGS.WORKFLOW_ENGINE == "covalent":
import covalent as ct

return ct.electron(_func, **kwargs)
decorated_func = ct.electron(_func, **kwargs)
elif SETTINGS.WORKFLOW_ENGINE == "dask":
from dask import delayed

Expand All @@ -153,22 +153,22 @@ def add(a, b):
def wrapper(*f_args, **f_kwargs):
return _func(*f_args, **f_kwargs)

return Delayed_(delayed(wrapper, **kwargs))
decorated_func = Delayed_(delayed(wrapper, **kwargs))

elif SETTINGS.WORKFLOW_ENGINE == "jobflow":
from jobflow import job as jf_job

return jf_job(_func, **kwargs)
decorated_func = jf_job(_func, **kwargs)
elif SETTINGS.WORKFLOW_ENGINE == "parsl":
from parsl import python_app

wrapped_fn = _get_parsl_wrapped_func(_func, kwargs)

return python_app(wrapped_fn, **kwargs)
decorated_func = python_app(wrapped_fn, **kwargs)
elif SETTINGS.WORKFLOW_ENGINE == "redun":
from redun import task

return task(_func, namespace=_func.__module__, **kwargs)
decorated_func = task(_func, namespace=_func.__module__, **kwargs)
elif SETTINGS.WORKFLOW_ENGINE == "prefect":
from prefect import task

Expand All @@ -179,11 +179,14 @@ def wrapper(*f_args, **f_kwargs):
decorated = task(_func, **kwargs)
return decorated.submit(*f_args, **f_kwargs)

return wrapper
decorated_func = wrapper
else:
return task(_func, **kwargs)
decorated_func = task(_func, **kwargs)
else:
return _func
decorated_func = _func

decorated_func.quacc_decorator = "job"
return decorated_func


def flow(_func: Callable | None = None, **kwargs) -> Flow:
Expand Down Expand Up @@ -333,20 +336,23 @@ def workflow(a, b, c):
if _func is None:
return partial(flow, **kwargs)

elif SETTINGS.WORKFLOW_ENGINE == "covalent":
if SETTINGS.WORKFLOW_ENGINE == "covalent":
import covalent as ct

return ct.lattice(_func, **kwargs)
decorated_func = ct.lattice(_func, **kwargs)
elif SETTINGS.WORKFLOW_ENGINE == "redun":
from redun import task

return task(_func, namespace=_func.__module__, **kwargs)
decorated_func = task(_func, namespace=_func.__module__, **kwargs)
elif SETTINGS.WORKFLOW_ENGINE == "prefect":
from prefect import flow as prefect_flow

return prefect_flow(_func, validate_parameters=False, **kwargs)
decorated_func = prefect_flow(_func, validate_parameters=False, **kwargs)
else:
return _func
decorated_func = _func

decorated_func.quacc_decorator = "flow"
return decorated_func


def subflow(_func: Callable | None = None, **kwargs) -> Subflow:
Expand Down Expand Up @@ -547,10 +553,10 @@ def workflow(a, b, c):
if _func is None:
return partial(subflow, **kwargs)

elif SETTINGS.WORKFLOW_ENGINE == "covalent":
if SETTINGS.WORKFLOW_ENGINE == "covalent":
import covalent as ct

return ct.electron(ct.lattice(_func), **kwargs)
decorated_func = ct.electron(ct.lattice(_func), **kwargs)
elif SETTINGS.WORKFLOW_ENGINE == "dask":
from dask import delayed
from dask.distributed import worker_client
Expand All @@ -563,23 +569,26 @@ def wrapper(*f_args, **f_kwargs):
futures = client.compute(_func(*f_args, **f_kwargs))
return client.gather(futures)

return delayed(wrapper, **kwargs)
decorated_func = delayed(wrapper, **kwargs)
elif SETTINGS.WORKFLOW_ENGINE == "parsl":
from parsl import join_app

wrapped_fn = _get_parsl_wrapped_func(_func, kwargs)

return join_app(wrapped_fn, **kwargs)
decorated_func = join_app(wrapped_fn, **kwargs)
elif SETTINGS.WORKFLOW_ENGINE == "prefect":
from prefect import flow as prefect_flow

return prefect_flow(_func, validate_parameters=False, **kwargs)
decorated_func = prefect_flow(_func, validate_parameters=False, **kwargs)
elif SETTINGS.WORKFLOW_ENGINE == "redun":
from redun import task

return task(_func, namespace=_func.__module__, **kwargs)
decorated_func = task(_func, namespace=_func.__module__, **kwargs)
else:
return _func
decorated_func = _func

decorated_func.quacc_decorator = "subflow"
return decorated_func


def _get_parsl_wrapped_func(
Expand Down
8 changes: 3 additions & 5 deletions tests/dask/test_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,10 @@ def dynamic_workflow(a, b, c=1):
def test_dynamic_workflow(a, b, c=3):
result1 = add(a, b)
result2 = make_more(result1)
return update_parameters(add_distributed, {"d": 1}, decorator="subflow")(
result2, c
)
return update_parameters(add_distributed, {"d": 1})(result2, c)

add_ = update_parameters(add, {"b": 3}, decorator="job")
dynamic_workflow_ = update_parameters(dynamic_workflow, {"c": 4}, decorator="flow")
add_ = update_parameters(add, {"b": 3})
dynamic_workflow_ = update_parameters(dynamic_workflow, {"c": 4})
assert client.compute(add_(1)).result() == 4
assert client.compute(dynamic_workflow_(1, 2)).result() == [7, 7, 7]
assert client.compute(test_dynamic_workflow(1, 2)).result() == [7, 7, 7]
6 changes: 6 additions & 0 deletions tests/prefect/test_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

prefect = pytest.importorskip("prefect")

from prefect import Flow as PrefectFlow
from prefect import Task

from quacc import change_settings, flow, job, strip_decorator, subflow


Expand Down Expand Up @@ -178,9 +181,12 @@ def add3(a, b):

stripped_add = strip_decorator(add)
assert stripped_add(1, 2) == 3
assert not isinstance(stripped_add, Task)

stripped_add2 = strip_decorator(add2)
assert stripped_add2(1, 2) == 3
assert not isinstance(stripped_add2, PrefectFlow)

stripped_add3 = strip_decorator(add3)
assert stripped_add3(1, 2) == 3
assert not isinstance(stripped_add3, PrefectFlow)
Loading