From 80af221cee1dad4e295b03097bc4b83d1926ca39 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 21 Nov 2023 18:04:27 +0100 Subject: [PATCH] ENH: automatically format `pip install` cell --- .cspell.json | 2 + src/repoma/pin_nb_requirements.py | 128 +++++++++++++++++++----------- 2 files changed, 83 insertions(+), 47 deletions(-) diff --git a/.cspell.json b/.cspell.json index 5c95eb2c..b89a0bb9 100644 --- a/.cspell.json +++ b/.cspell.json @@ -47,9 +47,11 @@ "ipython", "mkdir", "mypy", + "oneline", "pytest", "PYTHONHASHSEED", "repoma", + "sympy", "toctree", "Zenodo" ], diff --git a/src/repoma/pin_nb_requirements.py b/src/repoma/pin_nb_requirements.py index ef871fd5..f639b604 100644 --- a/src/repoma/pin_nb_requirements.py +++ b/src/repoma/pin_nb_requirements.py @@ -7,8 +7,10 @@ """ import argparse +import re import sys from functools import lru_cache +from textwrap import dedent from typing import List, Optional, Sequence import nbformat @@ -19,7 +21,7 @@ from .errors import PrecommitError -__PIP_INSTALL_STATEMENT = "%pip install -q" +__EXPECTED_PIP_INSTALL_LINE = "%pip install -q" def check_pinned_requirements(filename: str) -> None: @@ -29,23 +31,19 @@ def check_pinned_requirements(filename: str) -> None: for cell_id, cell in enumerate(notebook["cells"]): if cell["cell_type"] != "code": continue - source: str = cell["source"] - src_lines = source.split("\n") - if len(src_lines) == 0: - continue - cell_content = "".join(s.strip("\\") for s in src_lines) - if not cell_content.startswith(__PIP_INSTALL_STATEMENT): + source = __to_oneline(cell["source"]) + pip_requirements = extract_pip_requirements(source) + if pip_requirements is None: continue executor = Executor() - executor(_check_install_statement, filename, cell_content) - executor(_check_requirements, filename, cell_content) - executor(_sort_requirements, filename, cell_content, notebook, cell_id) + executor(_check_pip_requirements, filename, pip_requirements) + executor(_format_pip_requirements, filename, source, notebook, cell_id) executor(_update_metadata, filename, cell["metadata"], notebook) executor.finalize() return msg = ( f'Notebook "{filename}" does not contain a pip install cell of the form' - f" {__PIP_INSTALL_STATEMENT} some-package==0.1.0 package2==3.2" + f" {__EXPECTED_PIP_INSTALL_LINE} some-package==0.1.0 package2==3.2" ) raise PrecommitError(msg) @@ -58,64 +56,100 @@ def __has_python_kernel(notebook: dict) -> bool: return "python" in kernel_language -def _check_install_statement(filename: str, install_statement: str) -> None: - if not install_statement.startswith(__PIP_INSTALL_STATEMENT): - msg = ( - f"First shell cell in notebook {filename} does not start with" - f" {__PIP_INSTALL_STATEMENT}" - ) - raise PrecommitError(msg) - if install_statement.endswith("/dev/null"): - msg = ( - "Remove the /dev/null from the pip install statement in notebook" - f" {filename}" - ) - raise PrecommitError(msg) +@lru_cache(maxsize=1) +def __to_oneline(source: str) -> str: + src_lines = source.split("\n") + return "".join(s.rstrip().rstrip("\\") for s in src_lines) -def _check_requirements(filename: str, install_statement: str) -> None: - requirements = __extract_requirements(install_statement) +@lru_cache(maxsize=1) +def extract_pip_requirements(source: str) -> Optional[List[str]]: + r"""Check if the source in a cell is a pip install statement. + + >>> extract_pip_requirements("Not a pip install statement") + >>> extract_pip_requirements("pip install") + [] + >>> extract_pip_requirements("pip3 install attrs") + ['attrs'] + >>> extract_pip_requirements("pip3 install -q attrs") + ['attrs'] + >>> extract_pip_requirements("pip3 install attrs &> /dev/null") + ['attrs'] + >>> extract_pip_requirements("%pip install attrs numpy==1.24.4 ") + ['attrs', 'numpy==1.24.4'] + >>> extract_pip_requirements("!python3 -mpip install sympy") + ['sympy'] + >>> extract_pip_requirements(''' + ... python3 -m pip install \ + ... attrs numpy \ + ... sympy \ + ... tensorflow + ... ''') + ['attrs', 'numpy', 'sympy', 'tensorflow'] + """ + # cspell:ignore mpip + matches = re.match( + r"[%\!]?\s*(python3?\s+-m\s*)?pip3?\s+install\s*(-q)?(.*?)(&?>\s*/dev/null)?$", + __to_oneline(source).strip(), + ) + if matches is None: + return None + packages = matches.group(3).split(" ") + packages = [p.strip() for p in packages] + return [p for p in packages if p] + + +def _check_pip_requirements(filename: str, requirements: List[str]) -> None: if len(requirements) == 0: msg = f'At least one dependency required in install cell of "{filename}"' raise PrecommitError(msg) - for requirement in requirements: - requirement = requirement.strip() - if not requirement: + for req in requirements: + req = req.strip() + if not req: continue - if "git+" in requirement: + if "git+" in req: continue - if not any(equal_sign in requirement for equal_sign in ["==", "~="]): + unpinned_requirements = [] + for req in requirements: + if req.startswith("git+"): + continue + if any(equal_sign in req for equal_sign in ["==", "~="]): + continue + package = req.split("<")[0].split(">")[0].strip() + unpinned_requirements.append(package) + if unpinned_requirements: msg = ( - f'Install cell in notebook "{filename}" contains a requirement without' - f" == or ~= ({requirement})" + f'Install cell in notebook "{filename}" contains requirements without' + "pinning (== or ~=):" ) + for req in unpinned_requirements: + msg += f"\n - {req}" + msg += dedent(f""" + Get the currently installed versions with: + + python3 -m pip freeze | grep -iE '{"|".join(sorted(unpinned_requirements))}' + """) raise PrecommitError(msg) -def _sort_requirements( +def _format_pip_requirements( filename: str, install_statement: str, notebook: NotebookNode, cell_id: int ) -> None: - requirements = __extract_requirements(install_statement) + requirements = extract_pip_requirements(install_statement) + if requirements is None: + return git_requirements = {r for r in requirements if r.startswith("git+")} pip_requirements = set(requirements) - git_requirements pip_requirements = {r.lower().replace("_", "-") for r in pip_requirements} sorted_requirements = sorted(pip_requirements) + sorted(git_requirements) - if sorted_requirements != requirements: - new_source = f"{__PIP_INSTALL_STATEMENT} {' '.join(sorted_requirements)}" - notebook["cells"][cell_id]["source"] = new_source + expected = f"{__EXPECTED_PIP_INSTALL_LINE} {' '.join(sorted_requirements)}" + if install_statement != expected: + notebook["cells"][cell_id]["source"] = expected nbformat.write(notebook, filename) - msg = f'Ordered and formatted pip install cell in "{filename}"' + msg = f'Ordered and formatted pip install cell in "{filename}"' raise PrecommitError(msg) -@lru_cache(maxsize=1) -def __extract_requirements(install_statement: str) -> List[str]: - package_listing = install_statement.replace(__PIP_INSTALL_STATEMENT, "") - requirements = package_listing.split(" ") - requirements = [r.strip() for r in requirements] - return [r for r in requirements if r] - - def _update_metadata(filename: str, metadata: dict, notebook: NotebookNode) -> None: updated_metadata = False jupyter_metadata = metadata.get("jupyter")