Skip to content

Commit

Permalink
ENH: automatically format pip install cell (#219)
Browse files Browse the repository at this point in the history
* MAINT: extract `load_notebook()` function
* MAINT: remove dunder from subhook functions
  • Loading branch information
redeboer authored Nov 21, 2023
1 parent 1da9190 commit 7290c1e
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 60 deletions.
2 changes: 2 additions & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@
"ipython",
"mkdir",
"mypy",
"oneline",
"pytest",
"PYTHONHASHSEED",
"repoma",
"sympy",
"toctree",
"Zenodo"
],
Expand Down
9 changes: 3 additions & 6 deletions src/repoma/colab_toc_visible.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from typing import Optional, Sequence

import nbformat
from nbformat import NotebookNode

from repoma.utilities.notebook import load_notebook

from .errors import PrecommitError
from .utilities.executor import Executor
Expand All @@ -30,7 +31,7 @@ def main(argv: Optional[Sequence[str]] = None) -> int:


def _update_metadata(path: str) -> None:
notebook = open_notebook(path)
notebook = load_notebook(path)
metadata = notebook["metadata"]
updated = False
if metadata.get("colab") is None:
Expand All @@ -46,9 +47,5 @@ def _update_metadata(path: str) -> None:
raise PrecommitError(msg)


def open_notebook(path: str) -> NotebookNode:
    """Read the notebook stored at ``path`` without converting its format version."""
    keep_version = nbformat.NO_CONVERT
    return nbformat.read(path, as_version=keep_version)


if __name__ == "__main__":
sys.exit(main())
12 changes: 5 additions & 7 deletions src/repoma/fix_nbformat_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import nbformat

from repoma.utilities.notebook import load_notebook

from .errors import PrecommitError
from .utilities.executor import Executor

Expand All @@ -34,22 +36,22 @@ def main(argv: Optional[Sequence[str]] = None) -> int:


def set_nbformat_version(filename: str) -> None:
notebook = open_notebook(filename)
notebook = load_notebook(filename)
if notebook["nbformat_minor"] != 4: # noqa: PLR2004
notebook["nbformat_minor"] = 4
nbformat.write(notebook, filename)


def remove_cell_ids(filename: str) -> None:
notebook = open_notebook(filename)
notebook = load_notebook(filename)
for cell in notebook["cells"]:
if "id" in cell:
del cell["id"]
nbformat.write(notebook, filename)


def check_svg_output_cells(filename: str) -> None:
notebook = open_notebook(filename)
notebook = load_notebook(filename)
for i, cell in enumerate(notebook["cells"]):
for output in cell.get("outputs", []):
data = output.get("data", {})
Expand All @@ -66,9 +68,5 @@ def check_svg_output_cells(filename: str) -> None:
)


def open_notebook(filename: str) -> dict:
    """Read the notebook file ``filename`` as-is, without nbformat version conversion."""
    notebook = nbformat.read(filename, as_version=nbformat.NO_CONVERT)
    return notebook


if __name__ == "__main__":
sys.exit(main())
140 changes: 96 additions & 44 deletions src/repoma/pin_nb_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,43 @@
"""

import argparse
import re
import sys
from functools import lru_cache
from textwrap import dedent
from typing import List, Optional, Sequence

import nbformat
from nbformat import NotebookNode

from repoma.utilities.executor import Executor
from repoma.utilities.notebook import load_notebook

from .errors import PrecommitError

__PIP_INSTALL_STATEMENT = "%pip install -q "
__EXPECTED_PIP_INSTALL_LINE = "%pip install -q"


def check_pinned_requirements(filename: str) -> None:
notebook = nbformat.read(filename, as_version=nbformat.NO_CONVERT)
notebook = load_notebook(filename)
if not __has_python_kernel(notebook):
return
for cell in notebook["cells"]:
for cell_id, cell in enumerate(notebook["cells"]):
if cell["cell_type"] != "code":
continue
source: str = cell["source"]
src_lines = source.split("\n")
if len(src_lines) == 0:
continue
cell_content = "".join(s.strip("\\") for s in src_lines)
if not cell_content.startswith(__PIP_INSTALL_STATEMENT):
source = __to_oneline(cell["source"])
pip_requirements = extract_pip_requirements(source)
if pip_requirements is None:
continue
executor = Executor()
executor(__check_install_statement, filename, cell_content)
executor(__check_requirements, filename, cell_content)
executor(__update_metadata, filename, cell["metadata"], notebook)
executor(_check_pip_requirements, filename, pip_requirements)
executor(_format_pip_requirements, filename, source, notebook, cell_id)
executor(_update_metadata, filename, cell["metadata"], notebook)
executor.finalize()
return
msg = (
f'Notebook "{filename}" does not contain a pip install cell of the form'
f" {__PIP_INSTALL_STATEMENT}some-package==0.1.0 package2==3.2"
f" {__EXPECTED_PIP_INSTALL_LINE} some-package==0.1.0 package2==3.2"
)
raise PrecommitError(msg)

Expand All @@ -55,50 +56,101 @@ def __has_python_kernel(notebook: dict) -> bool:
return "python" in kernel_language


def __check_install_statement(filename: str, install_statement: str) -> None:
    """Validate the beginning and end of a flattened pip install statement.

    Raises:
        PrecommitError: if the statement does not start with the expected pip
            install prefix, or if it ends with ``/dev/null``.
    """
    has_expected_prefix = install_statement.startswith(__PIP_INSTALL_STATEMENT)
    if not has_expected_prefix:
        msg = (
            f"First shell cell in notebook {filename} does not start with"
            f" {__PIP_INSTALL_STATEMENT}"
        )
        raise PrecommitError(msg)
    redirects_to_dev_null = install_statement.endswith("/dev/null")
    if redirects_to_dev_null:
        msg = (
            "Remove the /dev/null from the pip install statement in notebook"
            f" {filename}"
        )
        raise PrecommitError(msg)
@lru_cache(maxsize=1)
def __to_oneline(source: str) -> str:
    """Flatten a (possibly multi-line) cell source into a single line.

    For every line, trailing whitespace is stripped first and trailing
    backslashes second, so whitespace that precedes a line-continuation
    backslash is kept. The lines are then concatenated without a separator.
    """
    src_lines = source.split("\n")
    return "".join(s.rstrip().rstrip("\\") for s in src_lines)


@lru_cache(maxsize=1)
def extract_pip_requirements(source: str) -> Optional[List[str]]:
    r"""Extract the requirements listed in a pip install statement.

    Returns `None` if ``source`` is not a pip install statement, otherwise the
    (possibly empty) list of requirements that follow ``pip install``.
    Requirements may be separated by any kind of whitespace.

    .. note:: The result is memoized with `functools.lru_cache`, so a repeated
        call with the same ``source`` can return the very same list object.
        Callers should not modify the returned list in place.

    >>> extract_pip_requirements("Not a pip install statement")
    >>> extract_pip_requirements("pip install")
    []
    >>> extract_pip_requirements("pip3 install attrs")
    ['attrs']
    >>> extract_pip_requirements("pip3 install -q attrs")
    ['attrs']
    >>> extract_pip_requirements("pip3 install attrs &> /dev/null")
    ['attrs']
    >>> extract_pip_requirements("%pip install attrs numpy==1.24.4  ")
    ['attrs', 'numpy==1.24.4']
    >>> extract_pip_requirements("pip install attrs\tnumpy")
    ['attrs', 'numpy']
    >>> extract_pip_requirements("!python3 -mpip install sympy")
    ['sympy']
    >>> extract_pip_requirements('''
    ... python3 -m pip install \
    ...   attrs numpy \
    ...   sympy \
    ...   tensorflow
    ... ''')
    ['attrs', 'numpy', 'sympy', 'tensorflow']
    """
    # cspell:ignore mpip
    matches = re.match(
        r"[%\!]?\s*(python3?\s+-m\s*)?pip3?\s+install\s*(-q)?(.*?)(&?>\s*/dev/null)?$",
        __to_oneline(source).strip(),
    )
    if matches is None:
        return None
    # Group 3 holds everything between the optional ``-q`` flag and the optional
    # redirect to /dev/null. Splitting without an explicit separator breaks on
    # any run of whitespace (tabs included) and never yields empty strings.
    return matches.group(3).split()


def __check_requirements(filename: str, install_statement: str) -> None:
package_listing = install_statement.replace(__PIP_INSTALL_STATEMENT, "")
requirements = package_listing.split(" ")
def _check_pip_requirements(filename: str, requirements: List[str]) -> None:
if len(requirements) == 0:
msg = f'At least one dependency required in install cell of "{filename}"'
raise PrecommitError(msg)
for requirement in requirements:
requirement = requirement.strip()
if not requirement:
for req in requirements:
req = req.strip()
if not req:
continue
if "git+" in requirement:
if "git+" in req:
continue
if not any(equal_sign in requirement for equal_sign in ["==", "~="]):
unpinned_requirements = []
for req in requirements:
if req.startswith("git+"):
continue
if any(equal_sign in req for equal_sign in ["==", "~="]):
continue
package = req.split("<")[0].split(">")[0].strip()
unpinned_requirements.append(package)
if unpinned_requirements:
msg = (
f'Install cell in notebook "{filename}" contains a requirement without'
f" == or ~= ({requirement})"
f'Install cell in notebook "{filename}" contains requirements without'
"pinning (== or ~=):"
)
for req in unpinned_requirements:
msg += f"\n - {req}"
msg += dedent(f"""
Get the currently installed versions with:
python3 -m pip freeze | grep -iE '{"|".join(sorted(unpinned_requirements))}'
""")
raise PrecommitError(msg)
requirements_lower = [r.lower() for r in requirements if not r.startswith("git+")]
if sorted(requirements_lower) != requirements_lower:
sorted_requirements = " ".join(sorted(requirements))
msg = (
f'Requirements in notebook "{filename}" are not sorted alphabetically.'
f" Should be:\n\n {sorted_requirements}"
)


def _format_pip_requirements(
    filename: str, install_statement: str, notebook: NotebookNode, cell_id: int
) -> None:
    """Rewrite the pip install cell if it deviates from the normalized form.

    The normalized form is the expected pip install prefix followed by the
    de-duplicated requirements: first the non-``git+`` requirements (lower-cased,
    underscores replaced by dashes, sorted), then the ``git+`` requirements
    (sorted, otherwise unchanged).

    Raises:
        PrecommitError: after the cell source was replaced and the notebook was
            written back to ``filename``.
    """
    requirements = extract_pip_requirements(install_statement)
    if requirements is None:
        # Not a pip install statement, so there is nothing to format.
        return
    from_git = {req for req in requirements if req.startswith("git+")}
    normalized = {
        req.lower().replace("_", "-") for req in requirements if req not in from_git
    }
    sorted_requirements = sorted(normalized) + sorted(from_git)
    expected = f"{__EXPECTED_PIP_INSTALL_LINE} {' '.join(sorted_requirements)}"
    if install_statement == expected:
        return
    notebook["cells"][cell_id]["source"] = expected
    nbformat.write(notebook, filename)
    msg = f'Ordered and formatted pip install cell in "{filename}"'
    raise PrecommitError(msg)


def __update_metadata(filename: str, metadata: dict, notebook: NotebookNode) -> None:
def _update_metadata(filename: str, metadata: dict, notebook: NotebookNode) -> None:
updated_metadata = False
jupyter_metadata = metadata.get("jupyter")
if jupyter_metadata is not None and jupyter_metadata.get("source_hidden"):
Expand Down
7 changes: 4 additions & 3 deletions src/repoma/set_nb_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import nbformat

from repoma.utilities.notebook import load_notebook
from repoma.utilities.project_info import get_pypi_name

__CONFIG_CELL_CONTENT = """
Expand Down Expand Up @@ -132,7 +133,7 @@ def _update_cell(
) -> None:
if _skip_notebook(filename):
return
notebook = nbformat.read(filename, as_version=nbformat.NO_CONVERT)
notebook = load_notebook(filename)
exiting_cell = notebook["cells"][cell_id]
new_cell = nbformat.v4.new_code_cell(
new_content,
Expand All @@ -150,7 +151,7 @@ def _update_cell(
def _insert_autolink_concat(filename: str) -> None:
if _skip_notebook(filename, ignore_statement="<!-- no autolink-concat -->"):
return
notebook = nbformat.read(filename, as_version=nbformat.NO_CONVERT)
notebook = load_notebook(filename)
expected_cell_content = """
```{autolink-concat}
```
Expand All @@ -173,7 +174,7 @@ def _insert_autolink_concat(filename: str) -> None:
def _skip_notebook(
filename: str, ignore_statement: str = "<!-- no-set-nb-cells -->"
) -> bool:
notebook = nbformat.read(filename, as_version=nbformat.NO_CONVERT)
notebook = load_notebook(filename)
for cell in notebook["cells"]:
if cell["cell_type"] != "markdown":
continue
Expand Down
8 changes: 8 additions & 0 deletions src/repoma/utilities/notebook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Helper tools for working with Jupyter Notebooks."""

import nbformat
from nbformat import NotebookNode


def load_notebook(path: str) -> NotebookNode:
    """Load the notebook at ``path``, leaving its nbformat version unconverted."""
    notebook = nbformat.read(path, as_version=nbformat.NO_CONVERT)
    return notebook

0 comments on commit 7290c1e

Please sign in to comment.