Skip to content

Commit

Permalink
Allow tasks to depend on other tasks.
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasraabe committed Nov 13, 2023
1 parent 580f415 commit 701b46d
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
when a product annotation is used with the argument name `produces`. It also allows
`produces` to accept any node.
- {pull}`490` refactors the parsing of dependencies and improves its tests.
- {pull}`491` allows tasks to depend on other tasks.

## 0.4.2 - 2023-11-8

Expand Down
6 changes: 5 additions & 1 deletion src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ def pytask_collect_task(
)

markers = obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []
collection_id = obj.pytask_meta._id if hasattr(obj, "pytask_meta") else None
after = obj.pytask_meta.after if hasattr(obj, "pytask_meta") else []

# Get the underlying function to avoid having different states of the function,
# e.g. due to pytask_meta, in different layers of the wrapping.
Expand All @@ -267,6 +269,7 @@ def pytask_collect_task(
depends_on=dependencies,
produces=products,
markers=markers,
attributes={"collection_id": collection_id, "after": after},
)
return Task(
base_name=name,
Expand All @@ -275,6 +278,7 @@ def pytask_collect_task(
depends_on=dependencies,
produces=products,
markers=markers,
attributes={"collection_id": collection_id, "after": after},
)
if isinstance(obj, PTask):
return obj
Expand All @@ -295,7 +299,7 @@ def pytask_collect_task(
Please, align the names to ensure reproducibility on case-sensitive file systems \
(often Linux or macOS) or disable this error with 'check_casing_of_paths = false' in \
your pytask configuration file.
the pyproject.toml file.
Hint: If parts of the path preceding your project directory are not properly \
formatted, check whether you need to call `.resolve()` on `SRC`, `BLD` or other paths \
Expand Down
23 changes: 23 additions & 0 deletions src/_pytask/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from _pytask.database_utils import State
from _pytask.exceptions import ResolvingDependenciesError
from _pytask.mark import Mark
from _pytask.mark import select_by_after_keyword
from _pytask.node_protocols import PNode
from _pytask.node_protocols import PTask
from _pytask.nodes import PythonNode
Expand Down Expand Up @@ -101,6 +102,28 @@ def _add_product(dag: nx.DiGraph, task: PTask, node: PNode) -> None:
return dag


@hookimpl
def pytask_dag_modify_dag(session: Session, dag: nx.DiGraph) -> None:
    """Add task-to-task edges to the DAG for tasks using ``@task(after=...)``.

    Every task in ``session.tasks`` is inspected for an ``"after"`` entry in its
    ``attributes``:

    - A list is treated as a list of collection ids. Each id is looked up among the
      tasks that carry a ``"collection_id"`` attribute and an edge from the found
      task to the current task is added.
    - A string is treated as an expression and handed to
      :func:`select_by_after_keyword`. An edge is added from every selected
      signature to the current task, except from the task to itself.

    Any other value (including a missing entry) is ignored. The graph is modified
    in place.

    """
    # Map the id assigned during collection to the task that carries it.
    tasks_by_collection_id = {}
    for candidate in session.tasks:
        if "collection_id" in candidate.attributes:
            tasks_by_collection_id[candidate.attributes["collection_id"]] = candidate

    for task in session.tasks:
        after = task.attributes.get("after")

        if isinstance(after, str):
            own_signature = task.signature
            for signature in select_by_after_keyword(session, after):
                # A task may match its own expression; never add a self-loop.
                if signature != own_signature:
                    dag.add_edge(signature, own_signature)
            continue

        if isinstance(after, list):
            for collection_id in after:
                predecessor = tasks_by_collection_id[collection_id]
                dag.add_edge(predecessor.signature, task.signature)


@hookimpl
def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
"""Select the tasks which need to be executed."""
Expand Down
6 changes: 3 additions & 3 deletions src/_pytask/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None:
"""
for dependency in session.dag.predecessors(task.signature):
node = session.dag.nodes[dependency]["node"]
if not node.state():
node = session.dag.nodes[dependency].get("node")
if isinstance(node, PNode) and not node.state():
msg = f"{task.name!r} requires missing node {node.name!r}."
if IS_FILE_SYSTEM_CASE_SENSITIVE:
msg += (
Expand All @@ -138,7 +138,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None:
# Create directory for product if it does not exist. Maybe this should be a `setup`
# method for the node classes.
for product in session.dag.successors(task.signature):
node = session.dag.nodes[product]["node"]
node = session.dag.nodes[product].get("node")
if isinstance(node, PPathNode):
node.path.parent.mkdir(parents=True, exist_ok=True)

Expand Down
17 changes: 17 additions & 0 deletions src/_pytask/mark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"MarkDecorator",
"MarkGenerator",
"ParseError",
"select_by_after_keyword",
"select_by_keyword",
"select_by_mark",
]
Expand Down Expand Up @@ -168,6 +169,22 @@ def select_by_keyword(session: Session, dag: nx.DiGraph) -> set[str]:
return remaining


def select_by_after_keyword(session: Session, after: str) -> set[str]:
    """Collect the signatures of all tasks matched by the ``after`` expression.

    Parameters
    ----------
    session
        The session whose ``tasks`` are matched against the expression.
    after
        The expression string passed to the ``after`` keyword.

    Returns
    -------
    set[str]
        The signatures of all matching tasks. Empty if ``after`` is empty.

    Raises
    ------
    ValueError
        If ``after`` cannot be compiled into an expression.

    """
    try:
        compiled = Expression.compile_(after)
    except ParseError as e:
        msg = f"Wrong expression passed to 'after': {after}: {e}"
        raise ValueError(msg) from None

    return {
        task.signature
        for task in session.tasks
        if after and compiled.evaluate(KeywordMatcher.from_task(task))
    }


@define(slots=True)
class MarkMatcher:
"""A matcher for markers which are present.
Expand Down
2 changes: 2 additions & 0 deletions src/_pytask/mark/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ from _pytask.tree_util import PyTree
from _pytask.session import Session
import networkx as nx

# Select the signatures of the tasks matched by the expression given to ``after``.
def select_by_after_keyword(session: Session, after: str) -> set[str]: ...
def select_by_keyword(session: Session, dag: nx.DiGraph) -> set[str]: ...
def select_by_mark(session: Session, dag: nx.DiGraph) -> set[str]: ...

Expand Down Expand Up @@ -54,4 +55,5 @@ __all__ = [
"ParseError",
"select_by_keyword",
"select_by_mark",
"select_by_after_keyword",
]
38 changes: 31 additions & 7 deletions src/_pytask/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
from __future__ import annotations

from typing import Any
from typing import Callable
from typing import NamedTuple
from typing import TYPE_CHECKING
from uuid import UUID
from uuid import uuid4

from attrs import define
from attrs import field
Expand All @@ -16,18 +19,39 @@

@define
class CollectionMetadata:
    """A class for carrying metadata from functions to tasks.

    Attributes
    ----------
    after
        An expression or a task function or a list of task functions that need to be
        executed before this task can. When set through the ``task`` decorator, task
        functions are stored as the ``_id`` values of their metadata instead of the
        functions themselves.
    id_
        An id for the task if it is part of a parametrization. Otherwise, an automatic
        id will be generated. See
        :doc:`this tutorial <../tutorials/repeating_tasks_with_different_inputs>` for
        more information.
    kwargs
        A dictionary containing keyword arguments which are passed to the task when it
        is executed.
    markers
        A list of markers that are attached to the task.
    name
        Use it to override the name of the task that is, by default, the name of the
        callable.
    produces
        Definition of products to parse the function returns and store them. See
        :doc:`this how-to guide <../how_to_guides/using_task_returns>` for more
        information.
    _id
        A unique identifier created with :func:`uuid.uuid4` for every instance. Other
        tasks refer to this id when they declare that they run after this one.

    """

    after: str | list[Callable[..., Any]] = field(factory=list)  # type: ignore[assignment]
    id_: str | None = None
    kwargs: dict[str, Any] = field(factory=dict)
    markers: list[Mark] = field(factory=list)
    name: str | None = None
    produces: PyTree[Any] | None = None
    _id: UUID = field(factory=uuid4)


class NodeInfo(NamedTuple):
Expand Down
31 changes: 31 additions & 0 deletions src/_pytask/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
def task(
name: str | None = None,
*,
after: str | Callable[..., Any] | list[Callable[..., Any]] | None = None,
id: str | None = None, # noqa: A002
kwargs: dict[Any, Any] | None = None,
produces: PyTree[Any] | None = None,
Expand All @@ -55,6 +56,9 @@ def task(
name
Use it to override the name of the task that is, by default, the name of the
callable.
after
An expression or a task function or a list of task functions that need to be
executed before this task can.
id
An id for the task if it is part of a parametrization. Otherwise, an automatic
id will be generated. See
Expand Down Expand Up @@ -102,20 +106,23 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:

parsed_kwargs = {} if kwargs is None else kwargs
parsed_name = name if isinstance(name, str) else func.__name__
parsed_after = _parse_after(after)

if hasattr(unwrapped, "pytask_meta"):
unwrapped.pytask_meta.name = parsed_name
unwrapped.pytask_meta.kwargs = parsed_kwargs
unwrapped.pytask_meta.markers.append(Mark("task", (), {}))
unwrapped.pytask_meta.id_ = id
unwrapped.pytask_meta.produces = produces
unwrapped.pytask_meta.after = parsed_after
else:
unwrapped.pytask_meta = CollectionMetadata(
name=parsed_name,
kwargs=parsed_kwargs,
markers=[Mark("task", (), {})],
id_=id,
produces=produces,
after=parsed_after,
)

# Store it in the global variable ``COLLECTED_TASKS`` to avoid garbage
Expand All @@ -131,6 +138,30 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
return wrapper


def _ensure_collection_id(func: Callable[..., Any]) -> Any:
    """Return the collection id of ``func``, attaching fresh metadata if missing."""
    if not hasattr(func, "pytask_meta"):
        func.pytask_meta = CollectionMetadata()  # type: ignore[attr-defined]
    return func.pytask_meta._id  # type: ignore[attr-defined]


def _parse_after(
    after: str | Callable[..., Any] | list[Callable[..., Any]] | None
) -> str | list[Callable[..., Any]]:
    """Normalize the ``after`` argument of the ``task`` decorator.

    Parameters
    ----------
    after
        ``None`` (or any other falsy value), an expression string, a single task
        function, or a list of task functions.

    Returns
    -------
    str | list
        The expression string unchanged, or a list with the ``pytask_meta._id`` of
        every referenced function. Functions that do not carry ``pytask_meta`` yet
        receive a new :class:`CollectionMetadata` so that an id exists. Falsy input
        yields an empty list.

    Raises
    ------
    TypeError
        If ``after`` is neither a string, a callable, nor a list of callables.

    """
    if not after:
        return []
    if isinstance(after, str):
        return after
    if callable(after):
        return [_ensure_collection_id(after)]
    # The list must be returned here; falling through would reject every list with
    # the TypeError below.
    if isinstance(after, list) and all(callable(func) for func in after):
        return [_ensure_collection_id(func) for func in after]
    msg = (
        "'after' should be an expression string, a task, or a list of tasks. Got "
        f"{after}, instead."
    )
    raise TypeError(msg)


def parse_collected_tasks_with_task_marker(
tasks: list[Callable[..., Any]],
) -> dict[str, Callable[..., Any]]:
Expand Down
6 changes: 3 additions & 3 deletions src/_pytask/traceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ class Traceback:

show_locals: ClassVar[bool] = False
suppress: ClassVar[tuple[Path, ...]] = (
_PLUGGY_DIRECTORY,
TREE_UTIL_LIB_DIRECTORY,
_PYTASK_DIRECTORY,
# _PLUGGY_DIRECTORY,
# TREE_UTIL_LIB_DIRECTORY,
# _PYTASK_DIRECTORY,
)

def __rich_console__(
Expand Down
39 changes: 39 additions & 0 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,42 @@ def func(path: Annotated[Path, Product]):
assert result.exit_code == ExitCode.COLLECTION_FAILED
assert "Duplicated tasks" in result.output
assert "id=b.txt" in result.output


def test_task_will_be_executed_after_another_one_with_string(runner, tmp_path):
    """Check that ``@task(after="<expression>")`` orders tasks via the CLI.

    ``task_second`` is defined first in the module but asserts that the product of
    ``task_first`` exists, so it only passes if it is executed afterwards.
    """
    # The function bodies must stay indented relative to their ``def`` lines;
    # ``textwrap.dedent`` only strips the common leading whitespace.
    source = """
    from pytask import task
    from pathlib import Path
    from typing_extensions import Annotated

    @task(after="task_first")
    def task_second():
        assert Path("out.txt").exists()

    def task_first() -> Annotated[str, Path("out.txt")]:
        return "Hello, World!"
    """
    tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))

    result = runner.invoke(cli, [tmp_path.as_posix()])
    assert result.exit_code == ExitCode.OK
    assert "2 Succeeded" in result.output


def test_task_will_be_executed_after_another_one_with_function(tmp_path):
    """Check that ``@task(after=<function>)`` orders tasks via the programmatic API.

    ``task_second`` references ``task_first`` directly and asserts that its product
    exists, so the session only succeeds if ``task_first`` ran before it.
    """
    # The function bodies must stay indented relative to their ``def`` lines;
    # ``textwrap.dedent`` only strips the common leading whitespace.
    source = """
    from pytask import task
    from pathlib import Path
    from typing_extensions import Annotated

    def task_first() -> Annotated[str, Path("out.txt")]:
        return "Hello, World!"

    @task(after=task_first)
    def task_second():
        assert Path("out.txt").exists()
    """
    tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))

    session = build(paths=tmp_path)
    assert session.exit_code == ExitCode.OK

0 comments on commit 701b46d

Please sign in to comment.