diff --git a/.cspell.json b/.cspell.json index 5c95eb2c..b89a0bb9 100644 --- a/.cspell.json +++ b/.cspell.json @@ -47,9 +47,11 @@ "ipython", "mkdir", "mypy", + "oneline", "pytest", "PYTHONHASHSEED", "repoma", + "sympy", "toctree", "Zenodo" ], diff --git a/src/repoma/colab_toc_visible.py b/src/repoma/colab_toc_visible.py index 3c35dd3b..f48d831a 100644 --- a/src/repoma/colab_toc_visible.py +++ b/src/repoma/colab_toc_visible.py @@ -9,7 +9,8 @@ from typing import Optional, Sequence import nbformat -from nbformat import NotebookNode + +from repoma.utilities.notebook import load_notebook from .errors import PrecommitError from .utilities.executor import Executor @@ -30,7 +31,7 @@ def main(argv: Optional[Sequence[str]] = None) -> int: def _update_metadata(path: str) -> None: - notebook = open_notebook(path) + notebook = load_notebook(path) metadata = notebook["metadata"] updated = False if metadata.get("colab") is None: @@ -46,9 +47,5 @@ def _update_metadata(path: str) -> None: raise PrecommitError(msg) -def open_notebook(path: str) -> NotebookNode: - return nbformat.read(path, as_version=nbformat.NO_CONVERT) - - if __name__ == "__main__": sys.exit(main()) diff --git a/src/repoma/fix_nbformat_version.py b/src/repoma/fix_nbformat_version.py index e4464fed..26559479 100644 --- a/src/repoma/fix_nbformat_version.py +++ b/src/repoma/fix_nbformat_version.py @@ -11,6 +11,8 @@ import nbformat +from repoma.utilities.notebook import load_notebook + from .errors import PrecommitError from .utilities.executor import Executor @@ -34,14 +36,14 @@ def main(argv: Optional[Sequence[str]] = None) -> int: def set_nbformat_version(filename: str) -> None: - notebook = open_notebook(filename) + notebook = load_notebook(filename) if notebook["nbformat_minor"] != 4: # noqa: PLR2004 notebook["nbformat_minor"] = 4 nbformat.write(notebook, filename) def remove_cell_ids(filename: str) -> None: - notebook = open_notebook(filename) + notebook = load_notebook(filename) for cell in notebook["cells"]: if "id" in cell: del cell["id"] @@ -49,7 +51,7 @@ def remove_cell_ids(filename: str) -> None: def check_svg_output_cells(filename: str) -> None: - notebook = open_notebook(filename) + notebook = load_notebook(filename) for i, cell in enumerate(notebook["cells"]): for output in cell.get("outputs", []): data = output.get("data", {}) @@ -66,9 +68,5 @@ def check_svg_output_cells(filename: str) -> None: ) -def open_notebook(filename: str) -> dict: - return nbformat.read(filename, as_version=nbformat.NO_CONVERT) - - if __name__ == "__main__": sys.exit(main()) diff --git a/src/repoma/pin_nb_requirements.py b/src/repoma/pin_nb_requirements.py index 5cfa4e10..f639b604 100644 --- a/src/repoma/pin_nb_requirements.py +++ b/src/repoma/pin_nb_requirements.py @@ -7,42 +7,43 @@ """ import argparse +import re import sys +from functools import lru_cache +from textwrap import dedent from typing import List, Optional, Sequence import nbformat from nbformat import NotebookNode from repoma.utilities.executor import Executor +from repoma.utilities.notebook import load_notebook from .errors import PrecommitError -__PIP_INSTALL_STATEMENT = "%pip install -q " +__EXPECTED_PIP_INSTALL_LINE = "%pip install -q" def check_pinned_requirements(filename: str) -> None: - notebook = nbformat.read(filename, as_version=nbformat.NO_CONVERT) + notebook = load_notebook(filename) if not __has_python_kernel(notebook): return - for cell in notebook["cells"]: + for cell_id, cell in enumerate(notebook["cells"]): if cell["cell_type"] != "code": continue - source: str = cell["source"] - src_lines = source.split("\n") - if len(src_lines) == 0: - continue - cell_content = "".join(s.strip("\\") for s in src_lines) - if not cell_content.startswith(__PIP_INSTALL_STATEMENT): + source = __to_oneline(cell["source"]) + pip_requirements = extract_pip_requirements(source) + if pip_requirements is None: continue executor = Executor() - executor(__check_install_statement, filename, cell_content) - executor(__check_requirements, filename, cell_content) - executor(__update_metadata, filename, cell["metadata"], notebook) + executor(_check_pip_requirements, filename, pip_requirements) + executor(_format_pip_requirements, filename, source, notebook, cell_id) + executor(_update_metadata, filename, cell["metadata"], notebook) executor.finalize() return msg = ( f'Notebook "{filename}" does not contain a pip install cell of the form' - f" {__PIP_INSTALL_STATEMENT}some-package==0.1.0 package2==3.2" + f" {__EXPECTED_PIP_INSTALL_LINE} some-package==0.1.0 package2==3.2" ) raise PrecommitError(msg) @@ -55,50 +56,101 @@ def __has_python_kernel(notebook: dict) -> bool: return "python" in kernel_language -def __check_install_statement(filename: str, install_statement: str) -> None: - if not install_statement.startswith(__PIP_INSTALL_STATEMENT): - msg = ( - f"First shell cell in notebook {filename} does not start with" - f" {__PIP_INSTALL_STATEMENT}" - ) - raise PrecommitError(msg) - if install_statement.endswith("/dev/null"): - msg = ( - "Remove the /dev/null from the pip install statement in notebook" - f" {filename}" - ) - raise PrecommitError(msg) +@lru_cache(maxsize=1) +def __to_oneline(source: str) -> str: + src_lines = source.split("\n") + return "".join(s.rstrip().rstrip("\\") for s in src_lines) + + +@lru_cache(maxsize=1) +def extract_pip_requirements(source: str) -> Optional[List[str]]: + r"""Check if the source in a cell is a pip install statement. + + >>> extract_pip_requirements("Not a pip install statement") + >>> extract_pip_requirements("pip install") + [] + >>> extract_pip_requirements("pip3 install attrs") + ['attrs'] + >>> extract_pip_requirements("pip3 install -q attrs") + ['attrs'] + >>> extract_pip_requirements("pip3 install attrs &> /dev/null") + ['attrs'] + >>> extract_pip_requirements("%pip install attrs numpy==1.24.4 ") + ['attrs', 'numpy==1.24.4'] + >>> extract_pip_requirements("!python3 -mpip install sympy") + ['sympy'] + >>> extract_pip_requirements(''' + ... python3 -m pip install \ + ... attrs numpy \ + ... sympy \ + ... tensorflow + ... ''') + ['attrs', 'numpy', 'sympy', 'tensorflow'] + """ + # cspell:ignore mpip + matches = re.match( + r"[%\!]?\s*(python3?\s+-m\s*)?pip3?\s+install\s*(-q)?(.*?)(&?>\s*/dev/null)?$", + __to_oneline(source).strip(), + ) + if matches is None: + return None + packages = matches.group(3).split(" ") + packages = [p.strip() for p in packages] + return [p for p in packages if p] -def __check_requirements(filename: str, install_statement: str) -> None: - package_listing = install_statement.replace(__PIP_INSTALL_STATEMENT, "") - requirements = package_listing.split(" ") +def _check_pip_requirements(filename: str, requirements: List[str]) -> None: if len(requirements) == 0: msg = f'At least one dependency required in install cell of "{filename}"' raise PrecommitError(msg) - for requirement in requirements: - requirement = requirement.strip() - if not requirement: + for req in requirements: + req = req.strip() + if not req: continue - if "git+" in requirement: + if "git+" in req: continue - if not any(equal_sign in requirement for equal_sign in ["==", "~="]): + unpinned_requirements = [] + for req in requirements: + if req.startswith("git+"): + continue + if any(equal_sign in req for equal_sign in ["==", "~="]): + continue + package = req.split("<")[0].split(">")[0].strip() + unpinned_requirements.append(package) + if unpinned_requirements: msg = ( - f'Install cell in notebook "{filename}" contains a requirement without' - f" == or ~= ({requirement})" + f'Install cell in notebook "{filename}" contains requirements without' + "pinning (== or ~=):" ) + for req in unpinned_requirements: + msg += f"\n - {req}" + msg += dedent(f""" + Get the currently installed versions with: + + python3 -m pip freeze | grep -iE '{"|".join(sorted(unpinned_requirements))}' + """) raise PrecommitError(msg) - requirements_lower = [r.lower() for r in requirements if not r.startswith("git+")] - if sorted(requirements_lower) != requirements_lower: - sorted_requirements = " ".join(sorted(requirements)) - msg = ( - f'Requirements in notebook "{filename}" are not sorted alphabetically.' - f" Should be:\n\n {sorted_requirements}" - ) + + +def _format_pip_requirements( + filename: str, install_statement: str, notebook: NotebookNode, cell_id: int +) -> None: + requirements = extract_pip_requirements(install_statement) + if requirements is None: + return + git_requirements = {r for r in requirements if r.startswith("git+")} + pip_requirements = set(requirements) - git_requirements + pip_requirements = {r.lower().replace("_", "-") for r in pip_requirements} + sorted_requirements = sorted(pip_requirements) + sorted(git_requirements) + expected = f"{__EXPECTED_PIP_INSTALL_LINE} {' '.join(sorted_requirements)}" + if install_statement != expected: + notebook["cells"][cell_id]["source"] = expected + nbformat.write(notebook, filename) + msg = f'Ordered and formatted pip install cell in "{filename}"' raise PrecommitError(msg) -def __update_metadata(filename: str, metadata: dict, notebook: NotebookNode) -> None: +def _update_metadata(filename: str, metadata: dict, notebook: NotebookNode) -> None: updated_metadata = False jupyter_metadata = metadata.get("jupyter") if jupyter_metadata is not None and jupyter_metadata.get("source_hidden"): diff --git a/src/repoma/set_nb_cells.py b/src/repoma/set_nb_cells.py index bb91982a..64942f7f 100644 --- a/src/repoma/set_nb_cells.py +++ b/src/repoma/set_nb_cells.py @@ -26,6 +26,7 @@ import nbformat +from repoma.utilities.notebook import load_notebook from repoma.utilities.project_info import get_pypi_name __CONFIG_CELL_CONTENT = """ @@ -132,7 +133,7 @@ def _update_cell( ) -> None: if _skip_notebook(filename): return - notebook = nbformat.read(filename, as_version=nbformat.NO_CONVERT) + notebook = load_notebook(filename) exiting_cell = notebook["cells"][cell_id] new_cell = nbformat.v4.new_code_cell( new_content, @@ -150,7 +151,7 @@ def _update_cell( def _insert_autolink_concat(filename: str) -> None: if _skip_notebook(filename, ignore_statement=""): return - notebook = nbformat.read(filename, as_version=nbformat.NO_CONVERT) + notebook = load_notebook(filename) expected_cell_content = """ ```{autolink-concat} ``` @@ -173,7 +174,7 @@ def _insert_autolink_concat(filename: str) -> None: def _skip_notebook( filename: str, ignore_statement: str = "" ) -> bool: - notebook = nbformat.read(filename, as_version=nbformat.NO_CONVERT) + notebook = load_notebook(filename) for cell in notebook["cells"]: if cell["cell_type"] != "markdown": continue diff --git a/src/repoma/utilities/notebook.py b/src/repoma/utilities/notebook.py new file mode 100644 index 00000000..e586e52a --- /dev/null +++ b/src/repoma/utilities/notebook.py @@ -0,0 +1,8 @@ +"""Helper tools for working with Jupyter Notebooks.""" + +import nbformat +from nbformat import NotebookNode + + +def load_notebook(path: str) -> NotebookNode: + return nbformat.read(path, as_version=nbformat.NO_CONVERT)