diff --git a/docs/conf.py b/docs/conf.py index 8861a6e8..58a37bce 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,18 +17,24 @@ "obj", "compwa_policy.check_dev_files.dependabot.DependabotOption", ), + "IO": "typing.IO", + "Iterable": "typing.Iterable", "K": "typing.TypeVar", + "Mapping": "collections.abc.Mapping", "NotRequired": ("obj", "typing.NotRequired"), - "P": "typing.ParamSpec", "P.args": ("attr", "typing.ParamSpec.args"), "P.kwargs": ("attr", "typing.ParamSpec.kwargs"), + "P": "typing.ParamSpec", "Path": "pathlib.Path", - "PythonVersion": "typing.TypeVar", + "ProjectURLs": "list", + "PyprojectTOML": "dict", + "PythonVersion": "str", + "Sequence": "typing.Sequence", "T": "typing.TypeVar", - "TOMLDocument": "tomlkit.TOMLDocument", "Table": "tomlkit.items.Table", - "V": "typing.TypeVar", + "TOMLDocument": "tomlkit.TOMLDocument", "typing_extensions.NotRequired": ("obj", "typing.NotRequired"), + "V": "typing.TypeVar", } author = "Common Partial Wave Analysis" autodoc_member_order = "bysource" @@ -102,6 +108,7 @@ ] nitpick_ignore = [ ("py:class", "CommentedMap"), + ("py:class", "ProjectURLs"), ] nitpick_ignore_regex = [ ("py:class", r"^.*.[A-Z]$"), diff --git a/src/compwa_policy/check_dev_files/black.py b/src/compwa_policy/check_dev_files/black.py index d1e3e073..cd00dc0e 100644 --- a/src/compwa_policy/check_dev_files/black.py +++ b/src/compwa_policy/check_dev_files/black.py @@ -54,8 +54,8 @@ def _remove_outdated_settings(pyproject: ModifiablePyproject) -> None: removed_options = set() for option in forbidden_options: if option in settings: + settings.pop(option) removed_options.add(option) - settings.remove(option) if removed_options: msg = ( f"Removed {', '.join(sorted(removed_options))} option from black" diff --git a/src/compwa_policy/check_dev_files/pytest.py b/src/compwa_policy/check_dev_files/pytest.py index 1be8dee0..96f30874 100644 --- a/src/compwa_policy/check_dev_files/pytest.py +++ b/src/compwa_policy/check_dev_files/pytest.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable import tomlkit from ini2toml.api import Translator @@ -70,7 +70,7 @@ def _update_settings(pyproject: ModifiablePyproject) -> None: pyproject.append_to_changelog(msg) -def __get_expected_addopts(existing: str | Array) -> Array: +def __get_expected_addopts(existing: str | Iterable) -> Array: if isinstance(existing, str): options = {opt.strip() for opt in __split_options(existing)} else: diff --git a/src/compwa_policy/check_dev_files/ruff.py b/src/compwa_policy/check_dev_files/ruff.py index bd80515a..a37edc70 100644 --- a/src/compwa_policy/check_dev_files/ruff.py +++ b/src/compwa_policy/check_dev_files/ruff.py @@ -3,10 +3,10 @@ from __future__ import annotations import os -from typing import Iterable +from collections import abc +from typing import TYPE_CHECKING, Any, Iterable, Mapping from ruamel.yaml import YAML -from tomlkit.items import Array, Table from compwa_policy.utilities import CONFIG_PATH, natural_sorting, remove_configs, vscode from compwa_policy.utilities.executor import Executor @@ -25,6 +25,9 @@ from compwa_policy.utilities.readme import add_badge, remove_badge from compwa_policy.utilities.toml import to_toml_array +if TYPE_CHECKING: + from tomlkit.items import Array + def main(has_notebooks: bool) -> None: with Executor() as do, ModifiablePyproject.load() as pyproject: @@ -98,18 +101,17 @@ def __remove_nbqa_option(pyproject: ModifiablePyproject, option: str) -> None: nbqa_table = pyproject.get_table(table_key) if option not in nbqa_table: return - nbqa_table.remove(option) + nbqa_table.pop(option) msg = f"Removed {option!r} nbQA options from {CONFIG_PATH.pyproject}" pyproject.append_to_changelog(msg) def __remove_tool_table(pyproject: ModifiablePyproject, tool_table: str) -> None: - table_key = f"tool.{tool_table}" - if not pyproject.has_table(table_key): - return - pyproject._document["tool"].remove(tool_table) # type: ignore[union-attr] - msg = f"Removed [tool.{tool_table}] section from {CONFIG_PATH.pyproject}" - pyproject.append_to_changelog(msg) + tools = pyproject._document.get("tool") + if isinstance(tools, dict) and tool_table in tools: + tools.pop(tool_table) + msg = f"Removed [tool.{tool_table}] section from {CONFIG_PATH.pyproject}" + pyproject.append_to_changelog(msg) def _remove_pydocstyle(pyproject: ModifiablePyproject) -> None: @@ -153,11 +155,13 @@ def _move_ruff_lint_config(pyproject: ModifiablePyproject) -> None: } global_settings = pyproject.get_table("tool.ruff", create=True) lint_settings = {k: v for k, v in global_settings.items() if k in lint_option_keys} - lint_arrays = {k: v for k, v in lint_settings.items() if isinstance(v, Array)} + lint_arrays = { + k: v for k, v in lint_settings.items() if isinstance(v, abc.Sequence) + } if lint_arrays: lint_config = pyproject.get_table("tool.ruff.lint", create=True) lint_config.update(lint_arrays) - lint_tables = {k: v for k, v in lint_settings.items() if isinstance(v, Table)} + lint_tables = {k: v for k, v in lint_settings.items() if isinstance(v, abc.Mapping)} for table in lint_tables: lint_config = pyproject.get_table(f"tool.ruff.lint.{table}", create=True) lint_config.update(lint_tables[table]) @@ -322,7 +326,7 @@ def ___get_selected_ruff_rules(pyproject: Pyproject) -> Array: return to_toml_array(sorted(rules)) -def ___get_task_tags(ruff_settings: Table) -> Array: +def ___get_task_tags(ruff_settings: Mapping[str, Any]) -> Array: existing: set[str] = set(ruff_settings.get("task-tags", set())) expected = { "cspell", diff --git a/src/compwa_policy/utilities/pyproject/__init__.py b/src/compwa_policy/utilities/pyproject/__init__.py index a49ad1ce..d8b12911 100644 --- a/src/compwa_policy/utilities/pyproject/__init__.py +++ b/src/compwa_policy/utilities/pyproject/__init__.py @@ -7,7 +7,17 @@ from contextlib import AbstractContextManager from pathlib import Path from textwrap import indent -from typing import IO, TYPE_CHECKING, Iterable, Sequence, TypeVar, overload +from typing import ( + IO, + TYPE_CHECKING, + Any, + Iterable, + Mapping, + MutableMapping, + Sequence, + TypeVar, + overload, +) import tomlkit from attrs import field, frozen @@ -22,6 +32,7 @@ get_sub_table, get_supported_python_versions, has_sub_table, + load_pyproject_toml, ) from compwa_policy.utilities.pyproject.setters import ( add_dependency, @@ -40,8 +51,7 @@ if TYPE_CHECKING: from types import TracebackType - from tomlkit.items import Table - from tomlkit.toml_document import TOMLDocument + from compwa_policy.utilities.pyproject._struct import PyprojectTOML T = TypeVar("T", bound="Pyproject") @@ -50,34 +60,24 @@ class Pyproject: """Read-only representation of a :code:`pyproject.toml` file.""" - _document: TOMLDocument + _document: PyprojectTOML _source: IO | Path | None = field(default=None) @final @classmethod def load(cls: type[T], source: IO | Path | str = CONFIG_PATH.pyproject) -> T: """Load a :code:`pyproject.toml` file from a file, I/O stream, or `str`.""" - if isinstance(source, io.IOBase): - current_position = source.tell() - source.seek(0) - document = tomlkit.load(source) # type:ignore[arg-type] - source.seek(current_position) - return cls(document, source) - if isinstance(source, Path): - with open(source) as stream: - document = tomlkit.load(stream) - return cls(document, source) + document = load_pyproject_toml(source) if isinstance(source, str): - return cls(tomlkit.loads(source)) - msg = f"Source of type {type(source).__name__} is not supported" - raise TypeError(msg) + return cls(document) + return cls(document, source) @final def dumps(self) -> str: src = tomlkit.dumps(self._document, sort_keys=True) return f"{src.strip()}\n" - def get_table(self, dotted_header: str, create: bool = False) -> Table: + def get_table(self, dotted_header: str, create: bool = False) -> Mapping[str, Any]: if create: msg = "Cannot create sub-tables in a read-only pyproject.toml" raise TypeError(msg) @@ -166,11 +166,13 @@ def dump(self, target: IO | Path | str | None = None) -> None: raise TypeError(msg) @override - def get_table(self, dotted_header: str, create: bool = False) -> Table: + def get_table( + self, dotted_header: str, create: bool = False + ) -> MutableMapping[str, Any]: self.__assert_is_in_context() if create: create_sub_table(self._document, dotted_header) - return super().get_table(dotted_header) + return super().get_table(dotted_header) # type:ignore[return-value] def add_dependency( self, package: str, optional_key: str | Sequence[str] | None = None @@ -200,7 +202,7 @@ def append_to_changelog(self, message: str) -> None: self._changelog.append(message) -def complies_with_subset(settings: dict, minimal_settings: dict) -> bool: +def complies_with_subset(settings: Mapping, minimal_settings: Mapping) -> bool: return all(settings.get(key) == value for key, value in minimal_settings.items()) diff --git a/src/compwa_policy/utilities/pyproject/_struct.py b/src/compwa_policy/utilities/pyproject/_struct.py new file mode 100644 index 00000000..bf69e830 --- /dev/null +++ b/src/compwa_policy/utilities/pyproject/_struct.py @@ -0,0 +1,63 @@ +"""This module is hidden Sphinx can't handle `typing.TypedDict` with hyphens. + +See https://github.com/sphinx-doc/sphinx/issues/11039. +""" + +import sys +from typing import Dict, List + +if sys.version_info < (3, 8): + from typing_extensions import TypedDict +else: + from typing import TypedDict +if sys.version_info < (3, 11): + from typing_extensions import NotRequired +else: + from typing import NotRequired + +PyprojectTOML = TypedDict( + "PyprojectTOML", + { + "build-system": NotRequired["BuildSystem"], + "project": "Project", + "tool": NotRequired[Dict[str, Dict[str, str]]], + }, +) +"""Structure of a `pyproject.toml` file. + +See [pyproject.toml +specification](https://packaging.python.org/en/latest/specifications/pyproject-toml). +""" + + +BuildSystem = TypedDict( + "BuildSystem", + { + "requires": List[str], + "build-backend": str, + }, +) + + +Project = TypedDict( + "Project", + { + "name": str, + "version": NotRequired[str], + "dependencies": NotRequired[List[str]], + "optional-dependencies": NotRequired[Dict[str, List[str]]], + "urls": NotRequired["ProjectURLs"], + }, +) + + +class ProjectURLs(TypedDict): + """Project for PyPI.""" + + Changelog: NotRequired[str] + Documentation: NotRequired[str] + Homepage: NotRequired[str] + Issues: NotRequired[str] + Repository: NotRequired[str] + Source: NotRequired[str] + Tracker: NotRequired[str] diff --git a/src/compwa_policy/utilities/pyproject/getters.py b/src/compwa_policy/utilities/pyproject/getters.py index ad4463c1..20816c75 100644 --- a/src/compwa_policy/utilities/pyproject/getters.py +++ b/src/compwa_policy/utilities/pyproject/getters.py @@ -13,28 +13,29 @@ from compwa_policy.errors import PrecommitError from compwa_policy.utilities import CONFIG_PATH +if TYPE_CHECKING: + from collections.abc import Mapping + + from compwa_policy.utilities.pyproject._struct import ProjectURLs, PyprojectTOML + if sys.version_info < (3, 8): from typing_extensions import Literal else: from typing import Literal -if TYPE_CHECKING: - from tomlkit.container import Container - from tomlkit.items import Table - from tomlkit.toml_document import TOMLDocument PythonVersion = Literal["3.6", "3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] @overload -def get_package_name(doc: TOMLDocument) -> str | None: ... +def get_package_name(doc: PyprojectTOML) -> str | None: ... @overload def get_package_name( - doc: TOMLDocument, raise_on_missing: Literal[False] + doc: PyprojectTOML, raise_on_missing: Literal[False] ) -> str | None: ... @overload -def get_package_name(doc: TOMLDocument, raise_on_missing: Literal[True]) -> str: ... -def get_package_name(doc: TOMLDocument, raise_on_missing: bool = False): # type: ignore[no-untyped-def] +def get_package_name(doc: PyprojectTOML, raise_on_missing: Literal[True]) -> str: ... +def get_package_name(doc: PyprojectTOML, raise_on_missing: bool = False): # type: ignore[no-untyped-def] if not has_sub_table(doc, "project"): if raise_on_missing: msg = "Please provide a name for the package under the [project] table in pyproject.toml" @@ -47,7 +48,7 @@ def get_package_name(doc: TOMLDocument, raise_on_missing: bool = False): # type return package_name -def get_project_urls(pyproject: TOMLDocument) -> Table: +def get_project_urls(pyproject: PyprojectTOML) -> ProjectURLs: project_table = get_sub_table(pyproject, "project") urls = project_table.get("urls") if urls is None: @@ -64,7 +65,7 @@ def get_project_urls(pyproject: TOMLDocument) -> Table: return urls -def get_source_url(pyproject: TOMLDocument) -> str: +def get_source_url(pyproject: PyprojectTOML) -> str: urls = get_project_urls(pyproject) source_url = urls.get("Source") if source_url is None: @@ -73,7 +74,7 @@ def get_source_url(pyproject: TOMLDocument) -> str: return source_url -def get_supported_python_versions(pyproject: TOMLDocument) -> list[PythonVersion]: +def get_supported_python_versions(pyproject: PyprojectTOML) -> list[PythonVersion]: """Extract sorted list of supported Python versions from package classifiers. >>> import tomlkit @@ -120,9 +121,9 @@ def _extract_python_versions(classifiers: list[str]) -> list[PythonVersion]: return [s.replace(prefix, "") for s in version_classifiers] # type: ignore[misc] -def get_sub_table(config: Container, dotted_header: str) -> Table: +def get_sub_table(config: Mapping[str, Any], dotted_header: str) -> Mapping[str, Any]: """Get a TOML sub-table through a dotted header key.""" - current_table: Any = config + current_table = config for header in dotted_header.split("."): if header not in current_table: msg = f"TOML data does not contain {dotted_header!r}" @@ -131,7 +132,7 @@ def get_sub_table(config: Container, dotted_header: str) -> Table: return current_table -def has_sub_table(config: Container, dotted_header: str) -> bool: +def has_sub_table(config: Mapping[str, Any], dotted_header: str) -> bool: current_table: Any = config for header in dotted_header.split("."): if header in current_table: @@ -143,7 +144,7 @@ def has_sub_table(config: Container, dotted_header: str) -> bool: def load_pyproject_toml( source: IO | Path | str = CONFIG_PATH.pyproject, -) -> TOMLDocument: +) -> PyprojectTOML: """Load a :code:`pyproject.toml` file from a file, I/O stream, or `str`.""" if isinstance(source, io.IOBase): current_position = source.tell() @@ -153,8 +154,8 @@ def load_pyproject_toml( return document if isinstance(source, Path): with open(source) as stream: - return tomlkit.load(stream) + return tomlkit.load(stream) # type:ignore[return-value] if isinstance(source, str): - return tomlkit.loads(source) + return tomlkit.loads(source) # type:ignore[return-value] msg = f"Source of type {type(source).__name__} is not supported" raise TypeError(msg) diff --git a/src/compwa_policy/utilities/pyproject/setters.py b/src/compwa_policy/utilities/pyproject/setters.py index e304b628..242ed8b9 100644 --- a/src/compwa_policy/utilities/pyproject/setters.py +++ b/src/compwa_policy/utilities/pyproject/setters.py @@ -3,26 +3,28 @@ from __future__ import annotations from collections import abc -from typing import TYPE_CHECKING, Any, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Iterable, Mapping, MutableMapping, Sequence, cast import tomlkit -from tomlkit import TOMLDocument -from compwa_policy.utilities.pyproject.getters import get_package_name, get_sub_table +from compwa_policy.utilities.pyproject.getters import get_package_name +from compwa_policy.utilities.pyproject.getters import ( + get_sub_table as get_immutable_sub_table, +) from compwa_policy.utilities.toml import to_toml_array if TYPE_CHECKING: - from tomlkit.container import Container from tomlkit.items import Table + from compwa_policy.utilities.pyproject._struct import PyprojectTOML + def add_dependency( - pyproject: TOMLDocument, + pyproject: PyprojectTOML, package: str, optional_key: str | Sequence[str] | None = None, ) -> bool: if optional_key is None: - create_sub_table(pyproject, "project") project = get_sub_table(pyproject, "project") existing_dependencies: set[str] = set(project.get("dependencies", [])) if package in existing_dependencies: @@ -32,7 +34,6 @@ def add_dependency( return True if isinstance(optional_key, str): table_key = "project.optional-dependencies" - create_sub_table(pyproject, table_key) optional_dependencies = get_sub_table(pyproject, table_key) existing_dependencies = set(optional_dependencies.get(optional_key, [])) if package in existing_dependencies: @@ -63,7 +64,7 @@ def _sort_taplo(items: Iterable[str]) -> list[str]: return sorted(items, key=lambda s: ('"' in s, s)) -def create_sub_table(config: Container, dotted_header: str) -> Table: +def create_sub_table(config: Mapping[str, Any], dotted_header: str) -> Table: """Create a TOML sub-table through a dotted header key.""" current_table: Any = config for header in dotted_header.split("."): @@ -73,8 +74,16 @@ def create_sub_table(config: Container, dotted_header: str) -> Table: return current_table +def get_sub_table( + config: Mapping[str, Any], dotted_header: str +) -> MutableMapping[str, Any]: + create_sub_table(config, dotted_header) + table = get_immutable_sub_table(config, dotted_header) + return cast(MutableMapping[str, Any], table) + + def remove_dependency( - pyproject: TOMLDocument, + pyproject: PyprojectTOML, package: str, ignored_sections: Iterable[str] | None = None, ) -> bool: diff --git a/tests/utilities/pyproject/test_getters.py b/tests/utilities/pyproject/test_getters.py index dd0726da..cb4014ba 100644 --- a/tests/utilities/pyproject/test_getters.py +++ b/tests/utilities/pyproject/test_getters.py @@ -1,5 +1,4 @@ import pytest -import tomlkit from compwa_policy.errors import PrecommitError from compwa_policy.utilities.pyproject.getters import ( @@ -9,11 +8,12 @@ get_sub_table, get_supported_python_versions, has_sub_table, + load_pyproject_toml, ) def test_get_package_name(): - src = tomlkit.loads(""" + src = load_pyproject_toml(""" [project] name = "my-package" """) @@ -22,7 +22,7 @@ def test_get_package_name(): def test_get_package_name_missing(): - src = tomlkit.loads(""" + src = load_pyproject_toml(""" [server] ip = "192.168.1.1" port = 8000 @@ -43,7 +43,7 @@ def test_get_package_name_missing(): def test_get_project_urls(): - pyproject = tomlkit.loads(""" + pyproject = load_pyproject_toml(""" [project] name = "my-package" @@ -61,7 +61,7 @@ def test_get_project_urls(): def test_get_project_urls_missing(): - pyproject = tomlkit.loads(""" + pyproject = load_pyproject_toml(""" [project] name = "my-package" """) @@ -73,7 +73,7 @@ def test_get_project_urls_missing(): def test_get_source_url_missing(): - pyproject = tomlkit.loads(""" + pyproject = load_pyproject_toml(""" [project.urls] Documentation = "https://ampform.rtfd.io" """) @@ -85,7 +85,7 @@ def test_get_source_url_missing(): def test_get_supported_python_versions(): - pyproject = tomlkit.loads(""" + pyproject = load_pyproject_toml(""" [project] name = "my-package" classifiers = [ @@ -100,7 +100,7 @@ def test_get_supported_python_versions(): def test_get_sub_table(): - document = tomlkit.loads(""" + document = load_pyproject_toml(""" [project] name = "my-package" @@ -124,7 +124,7 @@ def test_get_sub_table(): def test_has_sub_table(): - document = tomlkit.loads(""" + document = load_pyproject_toml(""" [project] name = "my-package" diff --git a/tests/utilities/pyproject/test_setters.py b/tests/utilities/pyproject/test_setters.py index 0623bdbd..4343882b 100644 --- a/tests/utilities/pyproject/test_setters.py +++ b/tests/utilities/pyproject/test_setters.py @@ -2,8 +2,9 @@ import pytest import tomlkit -from tomlkit import TOMLDocument +from compwa_policy.utilities.pyproject._struct import PyprojectTOML +from compwa_policy.utilities.pyproject.getters import load_pyproject_toml from compwa_policy.utilities.pyproject.setters import ( add_dependency, create_sub_table, @@ -12,7 +13,7 @@ def test_add_dependency(): - pyproject = tomlkit.loads(""" + pyproject = load_pyproject_toml(""" [project] name = "my-package" """) @@ -29,7 +30,7 @@ def test_add_dependency(): def test_add_dependency_existing(): - pyproject = tomlkit.loads(""" + pyproject = load_pyproject_toml(""" [project] dependencies = ["attrs"] @@ -49,7 +50,7 @@ def test_add_dependency_nested(): name = "my-package" """) - pyproject = tomlkit.loads(src) + pyproject = load_pyproject_toml(src) add_dependency(pyproject, "ruff", optional_key=["lint", "sty", "dev"]) new_content = tomlkit.dumps(pyproject) @@ -70,7 +71,7 @@ def test_add_dependency_optional(): [project] name = "my-package" """) - pyproject = tomlkit.loads(src) + pyproject = load_pyproject_toml(src) add_dependency(pyproject, "ruff", optional_key="lint") new_content = tomlkit.dumps(pyproject) @@ -85,7 +86,7 @@ def test_add_dependency_optional(): @pytest.fixture(scope="function") -def pyproject_example() -> TOMLDocument: +def pyproject_example() -> PyprojectTOML: src = dedent(""" [project] name = "my-package" @@ -98,10 +99,10 @@ def pyproject_example() -> TOMLDocument: ] sty = ["ruff"] """) - return tomlkit.loads(src) + return load_pyproject_toml(src) -def test_remove_dependency(pyproject_example: TOMLDocument): +def test_remove_dependency(pyproject_example: PyprojectTOML): remove_dependency(pyproject_example, "attrs") expected = dedent(""" [project] @@ -119,7 +120,7 @@ def test_remove_dependency(pyproject_example: TOMLDocument): assert new_content == expected -def test_remove_dependency_nested(pyproject_example: TOMLDocument): +def test_remove_dependency_nested(pyproject_example: PyprojectTOML): remove_dependency(pyproject_example, "ruff", ignored_sections=["sty"]) new_content = tomlkit.dumps(pyproject_example) expected = dedent(""" @@ -138,7 +139,7 @@ def test_remove_dependency_nested(pyproject_example: TOMLDocument): @pytest.mark.parametrize("table_key", ["project", "project.optional-dependencies"]) def test_create_sub_table(table_key: str): - pyproject = tomlkit.loads("") + pyproject = load_pyproject_toml("") dependencies = create_sub_table(pyproject, table_key) new_content = tomlkit.dumps(pyproject) diff --git a/tests/utilities/test_pyproject.py b/tests/utilities/test_pyproject.py index 3e851b1e..86656e74 100644 --- a/tests/utilities/test_pyproject.py +++ b/tests/utilities/test_pyproject.py @@ -4,7 +4,6 @@ from textwrap import dedent import pytest -from tomlkit.items import Table from compwa_policy.utilities.pyproject import ModifiablePyproject, Pyproject from compwa_policy.utilities.toml import to_toml_array @@ -39,8 +38,12 @@ def test_load_from_str(self): name = "my-package" requires-python = ">=3.7" """) - assert isinstance(pyproject._document["build-system"], Table) - assert pyproject._document["project"]["dependencies"] == [ # type: ignore[index] + assert set(pyproject._document) == { + "build-system", + "project", + } + project = pyproject.get_table("project") + assert project.get("dependencies") == [ "attrs", "sympy >=1.10", ]