Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Display notice when copyrighted file is copied or renamed #52

Merged
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ jobs:
fetch-depth: 0
- uses: actions/setup-python@v5
with:
python-version: '3.9'
python-version: '3.10'
- name: Build & Test
run: ./ci/build-test.sh
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
]
requires-python = ">=3.9"
requires-python = ">=3.10"
dependencies = [
"PyYAML",
"bashlex",
Expand Down
94 changes: 70 additions & 24 deletions src/rapids_pre_commit_hooks/copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import git

from .lint import Linter, LintMain
from .lint import Linter, LintMain, LintWarning

COPYRIGHT_RE: re.Pattern = re.compile(
r"Copyright *(?:\(c\))? *(?P<years>(?P<first_year>\d{4})(-(?P<last_year>\d{4}))?),?"
Expand Down Expand Up @@ -59,30 +59,69 @@ def append_stripped(start: int, item: re.Match):
return lines


def add_copy_rename_note(
linter: Linter,
warning: LintWarning,
change_type: str,
old_filename: Optional[Union[str, os.PathLike[str]]],
):
CHANGE_VERBS = {
"C": "copied",
"R": "renamed",
}
try:
change_verb = CHANGE_VERBS[change_type]
except KeyError:
pass
else:
warning.add_note(
(0, len(linter.content)), f"file was {change_verb} from '{old_filename}'"
)


def apply_copyright_revert(
linter: Linter, old_match: re.Match, new_match: re.Match
linter: Linter,
change_type: str,
old_filename: Optional[Union[str, os.PathLike[str]]],
old_match: re.Match,
new_match: re.Match,
) -> None:
if old_match.group("years") == new_match.group("years"):
warning_pos = new_match.span()
else:
warning_pos = new_match.span("years")
linter.add_warning(
w = linter.add_warning(
warning_pos,
"copyright is not out of date and should not be updated",
).add_replacement(new_match.span(), old_match.group())
)
w.add_replacement(new_match.span(), old_match.group())
add_copy_rename_note(linter, w, change_type, old_filename)


def apply_copyright_update(linter: Linter, match: re.Match, year: int) -> None:
linter.add_warning(match.span("years"), "copyright is out of date").add_replacement(
def apply_copyright_update(
linter: Linter,
change_type: str,
old_filename: Optional[Union[str, os.PathLike[str]]],
match: re.Match,
year: int,
) -> None:
w = linter.add_warning(match.span("years"), "copyright is out of date")
w.add_replacement(
match.span(),
COPYRIGHT_REPLACEMENT.format(
first_year=match.group("first_year"),
last_year=year,
),
)
add_copy_rename_note(linter, w, change_type, old_filename)


def apply_copyright_check(linter: Linter, old_content: Optional[str]) -> None:
def apply_copyright_check(
linter: Linter,
change_type: str,
old_filename: Optional[Union[str, os.PathLike[str]]],
old_content: Optional[str],
) -> None:
if linter.content != old_content:
current_year = datetime.datetime.now().year
new_copyright_matches = match_copyright(linter.content)
Expand All @@ -97,14 +136,18 @@ def apply_copyright_check(linter: Linter, old_content: Optional[str]) -> None:
old_copyright_matches, new_copyright_matches
):
if old_match.group() != new_match.group():
apply_copyright_revert(linter, old_match, new_match)
apply_copyright_revert(
linter, change_type, old_filename, old_match, new_match
)
elif new_copyright_matches:
for match in new_copyright_matches:
if (
int(match.group("last_year") or match.group("first_year"))
< current_year
):
apply_copyright_update(linter, match, current_year)
apply_copyright_update(
linter, change_type, old_filename, match, current_year
)
else:
linter.add_warning((0, 0), "no copyright notice found")

Expand Down Expand Up @@ -233,22 +276,24 @@ def try_get_ref(remote: "git.Remote") -> Optional["git.Reference"]:

def get_changed_files(
args: argparse.Namespace,
) -> dict[Union[str, os.PathLike[str]], Optional["git.Blob"]]:
) -> dict[Union[str, os.PathLike[str]], tuple[str, Optional["git.Blob"]]]:
try:
repo = git.Repo()
except git.InvalidGitRepositoryError:
return {
os.path.relpath(os.path.join(dirpath, filename), "."): None
os.path.relpath(os.path.join(dirpath, filename), "."): ("A", None)
for dirpath, dirnames, filenames in os.walk(".")
for filename in filenames
}

changed_files: dict[Union[str, os.PathLike[str]], Optional["git.Blob"]] = {
f: None for f in repo.untracked_files
}
changed_files: dict[
Union[str, os.PathLike[str]], tuple[str, Optional["git.Blob"]]
] = {f: ("A", None) for f in repo.untracked_files}
target_branch_upstream_commit = get_target_branch_upstream_commit(repo, args)
if target_branch_upstream_commit is None:
changed_files.update({blob.path: None for _, blob in repo.index.iter_blobs()})
changed_files.update(
{blob.path: ("A", None) for _, blob in repo.index.iter_blobs()}
)
return changed_files

for merge_base in repo.merge_base(
Expand All @@ -262,9 +307,9 @@ def get_changed_files(
)
for diff in diffs:
if diff.change_type == "A":
changed_files[diff.b_path] = None
changed_files[diff.b_path] = (diff.change_type, None)
elif diff.change_type != "D":
changed_files[diff.b_path] = diff.a_blob
changed_files[diff.b_path] = (diff.change_type, diff.a_blob)

return changed_files

Expand Down Expand Up @@ -313,16 +358,17 @@ def the_check(linter: Linter, args: argparse.Namespace):
return

try:
changed_file = changed_files[git_filename]
change_type, changed_file = changed_files[git_filename]
except KeyError:
return

old_content = (
changed_file.data_stream.read().decode()
if changed_file is not None
else None
)
apply_copyright_check(linter, old_content)
if changed_file is None:
old_filename = None
old_content = None
else:
old_filename = changed_file.path
old_content = changed_file.data_stream.read().decode()
apply_copyright_check(linter, change_type, old_filename, old_content)

return the_check

Expand Down
22 changes: 5 additions & 17 deletions src/rapids_pre_commit_hooks/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,13 @@
import functools
import re
import warnings
from collections.abc import Callable, Generator, Iterable
from collections.abc import Callable
from itertools import pairwise
from typing import Optional

from rich.console import Console
from rich.markup import escape


# Taken from Python docs
# (https://docs.python.org/3.12/library/itertools.html#itertools.pairwise)
# Replace with itertools.pairwise after dropping Python 3.9 support
def _pairwise(iterable: Iterable) -> Generator:
# pairwise('ABCDEFG') → AB BC CD DE EF FG
iterator = iter(iterable)
a = next(iterator, None)
for b in iterator:
yield a, b
a = b


_PosType = tuple[int, int]


Expand Down Expand Up @@ -66,9 +54,9 @@ class LintWarning:
pos: _PosType
msg: str
replacements: list[Replacement] = dataclasses.field(
default_factory=list, init=False
default_factory=list, kw_only=True
)
notes: list[Note] = dataclasses.field(default_factory=list, init=False)
notes: list[Note] = dataclasses.field(default_factory=list, kw_only=True)

def add_replacement(self, pos: _PosType, newtext: str) -> None:
self.replacements.append(Replacement(pos, newtext))
Expand Down Expand Up @@ -102,7 +90,7 @@ def fix(self) -> str:
key=lambda replacement: replacement.pos,
)

for r1, r2 in _pairwise(sorted_replacements):
for r1, r2 in pairwise(sorted_replacements):
if r1.pos[1] > r2.pos[0]:
raise OverlappingReplacementsError(f"{r1} overlaps with {r2}")

Expand Down
Loading