From f0d599456f45b7586b496e1ece25550f019ce870 Mon Sep 17 00:00:00 2001
From: Tobias Raabe
Date: Wed, 13 Dec 2023 00:03:24 +0100
Subject: [PATCH] Add error message for not collected tasks with @task decorator. (#521)

---
 docs/source/changes.md |  2 ++
 src/_pytask/collect.py | 45 ++++++++++++++++++++++++++++++++++++++++++
 src/_pytask/console.py |  5 -----
 src/_pytask/task.py    |  3 ++-
 tests/test_task.py     | 27 +++++++++++++++++++++++++
 5 files changed, 76 insertions(+), 6 deletions(-)

diff --git a/docs/source/changes.md b/docs/source/changes.md
index bf94f0a6..3767bb2a 100644
--- a/docs/source/changes.md
+++ b/docs/source/changes.md
@@ -12,6 +12,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
   (fixes #514). It also warns if the path is configured as a string and not a list of
   strings.
 - {pull}`519` raises an error when builtin functions are wrapped with
   {func}`~pytask.task`. Closes {issue}`512`.
+- {pull}`521` raises an error message when imported functions are wrapped with
+  {func}`@task <pytask.task>` in a task module. Fixes {issue}`513`.
 - {pull}`522` improves the issue templates.
 - {pull}`523` refactors `_pytask.console._get_file`.
diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py
index c8169de0..04ac49c4 100644
--- a/src/_pytask/collect.py
+++ b/src/_pytask/collect.py
@@ -21,6 +21,7 @@
 from _pytask.console import create_summary_panel
 from _pytask.console import get_file
 from _pytask.exceptions import CollectionError
+from _pytask.exceptions import NodeNotCollectedError
 from _pytask.mark_utils import get_all_marks
 from _pytask.mark_utils import has_mark
 from _pytask.node_protocols import PNode
@@ -37,6 +38,7 @@
 from _pytask.path import shorten_path
 from _pytask.reports import CollectionReport
 from _pytask.shared import find_duplicates
+from _pytask.task_utils import COLLECTED_TASKS
 from _pytask.task_utils import task as task_decorator
 from _pytask.typing import is_task_function
 from rich.text import Text
@@ -61,6 +63,7 @@ def pytask_collect(session: Session) -> bool:
 
     _collect_from_paths(session)
     _collect_from_tasks(session)
+    _collect_not_collected_tasks(session)
 
     session.tasks.extend(
         i.node
@@ -108,6 +111,9 @@ def _collect_from_tasks(session: Session) -> None:
             path = get_file(raw_task)
             name = raw_task.pytask_meta.name
 
+            if has_mark(raw_task, "task"):
+                COLLECTED_TASKS[path].remove(raw_task)
+
         # When a task is not a callable, it can be anything or a PTask. Set arbitrary
         # values and it will pass without errors and not collected.
         else:
@@ -126,6 +132,45 @@ def _collect_from_tasks(session: Session) -> None:
             session.collection_reports.append(report)
 
 
+_FAILED_COLLECTING_TASK = """\
+Failed to collect task '{name}'{path_desc}.
+
+This can happen when the task function is defined in another module, imported to a \
+task module and wrapped with the '@task' decorator.
+
+To collect this task correctly, wrap the imported function in a lambda expression like
+
+task(...)(lambda **x: imported_function(**x)).
+"""
+
+
+def _collect_not_collected_tasks(session: Session) -> None:
+    """Collect tasks that are not collected yet and create failed reports."""
+    for path in list(COLLECTED_TASKS):
+        tasks = COLLECTED_TASKS.pop(path)
+        for task in tasks:
+            name = task.pytask_meta.name  # type: ignore[attr-defined]
+            node: PTask
+            if path:
+                node = Task(base_name=name, path=path, function=task)
+                path_desc = f" in '{path}'"
+            else:
+                node = TaskWithoutPath(name=name, function=task)
+                path_desc = ""
+            report = CollectionReport(
+                outcome=CollectionOutcome.FAIL,
+                node=node,
+                exc_info=(
+                    NodeNotCollectedError,
+                    NodeNotCollectedError(
+                        _FAILED_COLLECTING_TASK.format(name=name, path_desc=path_desc)
+                    ),
+                    None,
+                ),
+            )
+            session.collection_reports.append(report)
+
+
 @hookimpl
 def pytask_ignore_collect(path: Path, config: dict[str, Any]) -> bool:
     """Ignore a path during the collection."""
diff --git a/src/_pytask/console.py b/src/_pytask/console.py
index a0bf5b01..e3c2def5 100644
--- a/src/_pytask/console.py
+++ b/src/_pytask/console.py
@@ -225,18 +225,13 @@ def get_file(  # noqa: PLR0911
         return get_file(function.__wrapped__)
     source_file = inspect.getsourcefile(function)
     if source_file:
-        # Handle functions defined in the REPL.
         if "<stdin>" in source_file:
             return None
-        # Handle lambda functions.
         if "<string>" in source_file:
             try:
                 return Path(function.__globals__["__file__"]).absolute().resolve()
             except KeyError:
                 return None
-        # Handle functions defined in Jupyter notebooks.
-        if "ipykernel" in source_file or "ipython-input" in source_file:
-            return None
         return Path(source_file).absolute().resolve()
     return None
 
diff --git a/src/_pytask/task.py b/src/_pytask/task.py
index 0f25fa0f..2309695b 100644
--- a/src/_pytask/task.py
+++ b/src/_pytask/task.py
@@ -76,6 +76,7 @@ def _raise_error_when_task_functions_are_duplicated(
     msg = (
         "There are some duplicates among the repeated tasks. It happens when you define"
         "the task function outside the loop body and merely wrap in the loop body with "
-        f"the '@task(...)' decorator.\n\n{flat_tree}"
+        "the 'task(...)(func)' decorator. As a workaround, wrap the task function in "
+        f"a lambda expression like 'task(...)(lambda **x: func(**x))'.\n\n{flat_tree}"
     )
     raise ValueError(msg)
diff --git a/tests/test_task.py b/tests/test_task.py
index b95194ec..83357627 100644
--- a/tests/test_task.py
+++ b/tests/test_task.py
@@ -679,3 +679,30 @@ def test_raise_error_with_builtin_function_as_task(runner, tmp_path):
     result = runner.invoke(cli, [tmp_path.as_posix()])
     assert result.exit_code == ExitCode.COLLECTION_FAILED
     assert "Builtin functions cannot be wrapped" in result.output
+
+
+def test_task_function_in_another_module(runner, tmp_path):
+    source = """
+    def func():
+        return "Hello, World!"
+    """
+    tmp_path.joinpath("module.py").write_text(textwrap.dedent(source))
+
+    source = """
+    from pytask import task
+    from pathlib import Path
+    from _pytask.path import import_path
+    import inspect
+
+    _ROOT_PATH = Path(__file__).parent
+
+    module = import_path(_ROOT_PATH / "module.py", _ROOT_PATH)
+    name_to_obj = dict(inspect.getmembers(module))
+
+    task(produces=Path("out.txt"))(name_to_obj["func"])
+    """
+    tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
+
+    result = runner.invoke(cli, [tmp_path.as_posix()])
+    assert result.exit_code == ExitCode.COLLECTION_FAILED
+    assert "1 Failed" in result.output