From 0430fd1c5335c7388bc73b05e7d88d959f2e3dfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Br=C3=A9nainn=20Woodsend?= Date: Fri, 29 Sep 2023 00:16:49 +0100 Subject: [PATCH] Fix corrupt mirror cache when builds are control+c cancelled. --- polycotylus/_mirror.py | 43 ++++++++++++++++++++++++++++++++++++------ tests/test_mirror.py | 32 +++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/polycotylus/_mirror.py b/polycotylus/_mirror.py index 1fc98bb..1a68a5b 100644 --- a/polycotylus/_mirror.py +++ b/polycotylus/_mirror.py @@ -15,6 +15,7 @@ import email.utils import contextlib import collections +import json from polycotylus import _docker from polycotylus._docker import cache_root @@ -76,6 +77,13 @@ def base_url(self): self._base_url = self._base_url() return self._base_url.strip("/") + def local_path(self, server_path): + # Some packages may contain colons (:) in their filenames. The colon is + # a prohibited character on Windows. Replace it with a nearly identical + # unicode equivalent. + server_path = server_path.lstrip("/").replace(":", "\uA789") + return self.base_dir / server_path + def serve(self): """Enable this mirror and block until killed (via Ctrl+C).""" with self: @@ -110,16 +118,35 @@ def __enter__(self): distribution. """)) from None self._prune() + try: + partial_downloads = json.loads((self.base_dir / "partial-downloads.json").read_bytes()) + except FileNotFoundError: + pass + else: + for file in partial_downloads: + with contextlib.suppress(FileNotFoundError): + self.local_path(file).unlink() + (self.base_dir / "partial-downloads.json").unlink() + thread = threading.Thread(target=self._httpd.serve_forever, daemon=True) + self._premature_abort = False thread.start() self._thread = thread self._listeners = 1 - def __exit__(self, *_): + def __exit__(self, exc_type, exc_value, traceback): with self._lock: self._listeners -= 1 if self._listeners: return + + if isinstance(exc_value, KeyboardInterrupt): + with self._lock: + self._premature_abort = True + for path in self._in_progress: + if response := getattr(self._in_progress[path], "_upstream", None): # pragma: no branch + response.close() + # Wait until all running downloads are complete to avoid competing over # ports if this mirror is re-enabled soon after. while self._in_progress: # pragma: no cover @@ -183,8 +210,7 @@ def cache(self): # Some packages may contain colons (:) in their filenames. The colon is # a prohibited character on Windows. Replace it with a nearly identical # unicode equivalent. - path = self.path.lstrip("/").replace(":", "\uA789") - return self.parent.base_dir / path + return self.parent.local_path(self.path) def do_GET(self): if any(fnmatch(self.path, i) for i in self.parent.ignore_patterns): @@ -243,9 +269,11 @@ def do_GET(self): with self.parent._lock: if self.command != "HEAD": if self.path not in self.parent._in_progress and not use_cache: - t = threading.Thread(target=self._download) - self.parent._in_progress[self.path] = t - t.start() + self._thread = threading.Thread(target=self._download) + self.parent._in_progress[self.path] = self + with open(self.parent.base_dir / "partial-downloads.json", "w") as f: + json.dump(sorted(self.parent._in_progress), f) + self._thread.start() if self.path in self.parent._in_progress \ or (self.command == "HEAD" and not use_cache): @@ -299,6 +327,9 @@ def _download(self): finally: with self.parent._lock: del self.parent._in_progress[self.path] + if not self.parent._premature_abort: + with open(self.parent.base_dir / "partial-downloads.json", "w") as f: + json.dump(sorted(self.parent._in_progress), f) def _in_progress_send(self): """Send a file from the cache whilst the cache is being written.""" diff --git a/tests/test_mirror.py b/tests/test_mirror.py index 7d026f3..d004202 100644 --- a/tests/test_mirror.py +++ b/tests/test_mirror.py @@ -12,6 +12,8 @@ import contextlib from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http import HTTPStatus +import json +import textwrap import pytest @@ -252,6 +254,36 @@ def _bogus_copy(source, dest, length=None): gzip.decompress(response.read()) +def test_control_c(tmp_path): + def upstream_get(self): + self.send_response(HTTPStatus.OK) + self.send_header("Content-Length", 1_000_000) + self.end_headers() + with contextlib.suppress(ConnectionResetError): + for i in range(10_000): + self.wfile.write(os.urandom(100)) + time.sleep(0.001) + if i == 1000: + os.kill(p.pid, 2) + + with fake_upstream(upstream_get): + with subprocess.Popen([sys.executable, "-c", textwrap.dedent(f""" + import sys + sys.path.insert(0, "{os.path.dirname(__file__)}") + from test_mirror import * + with CachedMirror("http://localhost:8899", Path("{tmp_path}"), [], [], 9989, "", ()): + with urlopen("http://localhost:9989/foo") as response: + response.read() + """)], stderr=subprocess.PIPE) as p: + assert p.wait(3) == -2, p.stderr + foo_cache = tmp_path / "foo" + assert foo_cache.stat().st_size < 1_000_000 + assert json.loads((tmp_path / "partial-downloads.json").read_bytes()) == ["/foo"] + + with CachedMirror("http://localhost:8899", tmp_path, [], [], 9989, "", ()): + assert not foo_cache.exists() + + obsolete_caches = { "alpine": [ "./v3.17/main/aarch64/curl-7.87.0-r1.apk",