diff --git a/README.md b/README.md index 1c9063f..8dd5082 100644 --- a/README.md +++ b/README.md @@ -108,15 +108,21 @@ There are several available subcommands: - `clean` -- performs garbage collection - `upload` -- uploads files to the backend (no chunking, no encryption, keeping original names) -> ⚠️ **WARNING**: actions that read from or upload to the repository can safely be run -> concurrently; however, there are presently no guards in place that would make it safe -> for you to run destructive actions (`delete`, `clean`) concurrently with those actions -> *unless* you use independent keys (see the explanation above). I do plan to implement them -> soon-ish, but in the meantime **DO NOT** use shared keys (or, naturally, the same key) -> to `snapshot` and `clean` at the same time, for example. -> -> As far as the upcoming implementation of such guards, it'll be based on locks. I'm familiar -> with the lock-free deduplication strategy (like in Duplicacy), but I don't like it much. +It's always safe to read from and upload to a Replicat repository concurrently. In order to +make it possible for you to run destructive actions (`delete`, `clean`) concurrently with +uploads and reads, Replicat uses lock-based guards. Here's what you should know: + + - locks are designed to protect the integrity of data in the case of concurrent operations + performed with shared keys (or, naturally, the same key), meaning that locks do not lock + the whole repository, unless the repository is unencrypted. If you're sure that you're + the sole user of the repository, or that no one is using the repository with the same + (or shared) key at the same time, then you can safely use the repository in exclusive mode + + - Replicat will terminate if it detects a conflicting operation being performed with + the same (or shared) key. 
It may have to wait a few extra seconds to make sure all of the + locks are visible + + - during shutdown Replicat will attempt to delete the locks it created There are several command line arguments that are common to all subcommands: @@ -129,8 +135,10 @@ There are several command line arguments that are common to all subcommands: destinations). If the backend requires additional arguments, they will appear in the `--help` output. Refer to the section on backends for more detailed information. + - `-x`/`--exclusive` -- enables the exclusive mode (see above) - `-q`/`--hide-progress` -- suppresses progress indication for commands that support it - - `-c`/`--concurrent` -- the number of concurrent connections to the backend + - `-c`/`--concurrent` -- the number of concurrent connections to the backend. + Normal lock operations don't respect this limit - `--cache-directory` -- specifies the directory to use for cache. `--no-cache` disables cache completely. - `-v`/`--verbose` -- specifies the logging verbosity. 
The default verbosity is `WARNING`, diff --git a/replicat/__main__.py b/replicat/__main__.py index 32dcdf4..0c26a29 100644 --- a/replicat/__main__.py +++ b/replicat/__main__.py @@ -42,6 +42,7 @@ async def _cmd_handler(args, unknown, error): concurrent=args.concurrent, quiet=args.quiet, cache_directory=args.cache_directory, + exclusive=args.exclusive, ) if args.action == 'init': diff --git a/replicat/exceptions.py b/replicat/exceptions.py index 2d5e9f6..a7e1725 100644 --- a/replicat/exceptions.py +++ b/replicat/exceptions.py @@ -8,3 +8,7 @@ class DecryptionError(ReplicatError): class AuthRequired(ReplicatError): pass + + +class Locked(ReplicatError): + pass diff --git a/replicat/repository.py b/replicat/repository.py index eb613dc..68113f8 100644 --- a/replicat/repository.py +++ b/replicat/repository.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import asyncio import collections.abc +import contextvars import dataclasses import inspect import io @@ -13,12 +16,13 @@ import sys import threading import time -from collections import namedtuple +from collections import OrderedDict, namedtuple from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from datetime import datetime from decimal import Decimal -from functools import cached_property +from enum import Enum +from functools import cached_property, wraps from pathlib import Path from random import Random from typing import Any, Dict, Optional @@ -91,10 +95,83 @@ def chunkify(self, it): return self.chunker(it, params=params) +class LockTypes(str, Enum): + create_read = 'cr' + delete = 'd' + + +class RepositoryLock: + """An instance of this class can be used as a decorator for methods of Repository. 
+ It manages the lock worker for the specified lock type""" + + _locked_at = contextvars.ContextVar('locked_at') + _repository = contextvars.ContextVar('repository') + + @property + def locked_at(self): + try: + return self._locked_at.get() + except LookupError as e: + raise RuntimeError( + 'lock can only be managed from inside decorated methods' + ) from e + + @locked_at.setter + def locked_at(self, value): + self._locked_at.set(value) + + @property + def repository(self): + try: + return self._repository.get() + except LookupError as e: + raise RuntimeError( + 'lock can only be managed from inside decorated methods' + ) from e + + @repository.setter + def repository(self, value): + self._repository.set(value) + + def __call__(self, lock_type: LockTypes): + def _decorator(func): + @wraps(func) + async def _wrapper(repository: Repository, *args, **kwargs): + self.repository = repository + self.locked_at = started_at = utils.utc_timestamp() + + worker = asyncio.create_task( + repository.lock_worker(started_at, lock_type) + ) + task = asyncio.create_task(func(repository, *args, **kwargs)) + task.add_done_callback(lambda _: worker.cancel()) + + await asyncio.wait([worker, task], return_when=asyncio.FIRST_EXCEPTION) + try: + await worker + except asyncio.CancelledError: + pass + + return await task + + return _wrapper + + return _decorator + + async def wait(self, wait_time, lock_type: LockTypes): + await self.repository.wait_for_lock( + wait_time + self.locked_at - utils.utc_timestamp(), lock_type + ) + + +lock = RepositoryLock() + + class Repository: # NOTE: trailing slashes CHUNK_PREFIX = 'data/' SNAPSHOT_PREFIX = 'snapshots/' + LOCK_PREFIX = 'locks/' # These correspond to the names of adapters DEFAULT_CHUNKER_NAME = 'gclmulchunker' DEFAULT_CIPHER_NAME = 'aes_gcm' @@ -105,6 +182,8 @@ class Repository: DEFAULT_SHARED_KDF_NAME = 'blake2b' EMPTY_TABLE_VALUE = '--' DEFAULT_CACHE_DIRECTORY = utils.fs.DEFAULT_CACHE_DIRECTORY + LOCK_TTL = 15 * 60 + LOCK_TTP = 15 def 
__init__( self, @@ -113,10 +192,13 @@ def __init__( concurrent, quiet=True, cache_directory=DEFAULT_CACHE_DIRECTORY, + exclusive=False, ): - self._concurrent = concurrent - self._quiet = quiet - self._cache_directory = cache_directory + self.backend = backend + self.concurrent = concurrent + self.quiet = quiet + self.cache_directory = cache_directory + self.exclusive = exclusive self._slots = asyncio.Queue(maxsize=concurrent) # We need actual integers for TQDM slot management in CLI, but this queue @@ -125,18 +207,16 @@ def __init__( for slot in range(2, concurrent + 2): self._slots.put_nowait(slot) - self.backend = backend - def display_status(self, message): print(ef.bold + message + ef.rs, file=sys.stderr) def _get_cached(self, path): - assert self._cache_directory is not None - return Path(self._cache_directory, path).read_bytes() + assert self.cache_directory is not None + return Path(self.cache_directory, path).read_bytes() def _store_cached(self, path, data): - assert self._cache_directory is not None - file = Path(self._cache_directory, path) + assert self.cache_directory is not None + file = Path(self.cache_directory, path) file.parent.mkdir(parents=True, exist_ok=True) file.write_bytes(data) @@ -147,7 +227,7 @@ def _unlocked(self): @cached_property def executor(self): """Executor for non-async methods of the backend instance""" - return ThreadPoolExecutor(max_workers=self._concurrent) + return ThreadPoolExecutor(max_workers=self.concurrent) def _awrap(self, func, *args, **kwargs): if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func): @@ -191,7 +271,7 @@ async def _download(self, location): async def _upload_data(self, location, data): async with self._acquire_slot(): - logger.info('Uploading binary data to %s', location) + logger.info('Uploading binary data as %s', location) await self._awrap(self.backend.upload, location, data) async def _delete(self, location): @@ -279,6 +359,113 @@ def _snapshot_digest_to_location_parts(self, 
digest, /): digest_mac = self.props.mac(digest) if self.props.encrypted else digest return LocationParts(name=digest.hex(), tag=digest_mac.hex()) + def get_lock_location(self, *, name, tag): + """Build POSIX-style storage path for the lock using its name and tag. + The tag is included for ownership verification. The part after the last slash + (actual filename on the filesystem) must be under 255 bytes for compatibility + with most filesystems. You can assume that both name and tag are hex-strings + each no longer than 128 characters. The returned path must start with + LOCK_PREFIX and is allowed to contain forward slashes, characters from name + and tag, and hyphens.""" + return posixpath.join(self.LOCK_PREFIX, f'{name}-{tag}') + + def parse_lock_location(self, location, /): + """Parse the storage path for the lock, extract its name and tag""" + if not location.startswith(self.LOCK_PREFIX): + raise ValueError('Not a lock location') + + head, _, tag = location.rpartition('-') + return LocationParts(name=head.rpartition('/')[2], tag=tag) + + def get_lock_frame(self, ts): + """Return the lock frame range for this timestamp""" + since_start = ts % self.LOCK_TTL + current_start = ts - since_start + next_start = current_start + self.LOCK_TTL + return current_start, next_start + + def _lock_ts_to_location_parts(self, ts, lock_type): + ts_hex = ts.to_bytes(ts.bit_length() // 8 + 1, 'little').hex() + name = f'{lock_type}-{ts_hex}' + + if self.props.encrypted: + tag = self.props.mac(name.encode('ascii')).hex() + else: + tag = 'none' + + return LocationParts(name=name, tag=tag) + + def _lock_ts_to_location(self, ts, lock_type): + name, tag = self._lock_ts_to_location_parts(ts, lock_type) + return self.get_lock_location(name=name, tag=tag) + + async def _upload_lock(self, path): + logger.info('Uploading lock %s', path) + unknown = self.backend.upload(path, b'') + if inspect.isawaitable(unknown): + await unknown + + async def _check_lock(self, path): + logger.info('Checking 
lock at %s', path) + unknown = self.backend.exists(path) + if inspect.isawaitable(unknown): + return await unknown + else: + return unknown + + async def _delete_lock(self, path): + logger.info('Deleting lock %s', path) + unknown = self.backend.delete(path) + if inspect.isawaitable(unknown): + await unknown + + def lock_frames(self, ts): + frame_start, _ = self.get_lock_frame(ts) + yield frame_start + + while True: + ts = utils.utc_timestamp() + frame_start, frame_end = self.get_lock_frame(ts) + yield frame_end if frame_end - ts <= self.LOCK_TTP else frame_start + + async def lock_worker(self, ts, lock_type, delay=1): + """Create a lock of this type for every lock frame indefinitely, starting + from the given timestamp. By the time it terminates (e.g. if it's canceled), + all of the locks must be deleted""" + if self.exclusive: + return + + created = OrderedDict() + try: + for frame_start in self.lock_frames(ts): + if frame_start not in created: + location = self._lock_ts_to_location(frame_start, lock_type) + await self._upload_lock(location) + created[frame_start] = location + + if len(created) > 2: + _, to_delete = created.popitem(last=False) + await self._delete_lock(to_delete) + + await asyncio.sleep(delay) + finally: + logger.info('Running lock cleanup') + await asyncio.gather(*map(self._delete_lock, created.values())) + + async def wait_for_lock(self, wait_time, lock_type): + """Wait this many seconds, check if the lock of the given type exists, + raise exception if that's the case""" + if self.exclusive: + return + + logger.info('Waiting for %s lock for %d seconds', lock_type, wait_time) + await asyncio.sleep(wait_time) + + frame_start, _ = self.get_lock_frame(utils.utc_timestamp()) + location = self._lock_ts_to_location(frame_start, lock_type) + if await self._check_lock(location): + raise exceptions.Locked('Repository is locked by another operation') + def read_metadata(self, path, /): # TODO: Cache stat result? 
stat_result = os.stat(path) @@ -490,6 +677,12 @@ async def init(self, *, password=None, settings=None, key_output_path=None): json.dumps(key, indent=4, default=self.default_serialization_hook) ) else: + if password is not None or key_output_path is not None: + raise exceptions.ReplicatError( + 'Password and key output path can only be provided to initialise ' + 'encrypted repositories' + ) + key = None self.display_status('Uploading config') @@ -521,6 +714,10 @@ async def unlock(self, *, password=None, key=None): props, **self._instantiate_key(key, password=password, cipher=props.cipher), ) + elif password is not None or key is not None: + raise exceptions.ReplicatError( + 'Cannot provide password or key to unlock unencrypted repositories' + ) self.props = props @@ -603,7 +800,7 @@ async def _load_snapshots(self, *, snapshot_regex=None): async def _download_snapshot(path, digest): contents = empty - if self._cache_directory is not None: + if self.cache_directory is not None: try: contents = self._get_cached(path) except FileNotFoundError: @@ -615,7 +812,7 @@ async def _download_snapshot(path, digest): if self.props.hash_digest(contents) != digest: raise exceptions.ReplicatError(f'Snapshot at {path!r} is corrupted') - if self._cache_directory is not None: + if self.cache_directory is not None: logger.info('Caching %s', path) self._store_cached(path, contents) @@ -817,6 +1014,7 @@ async def list_files(self, *, snapshot_regex=None, files_regex=None): def _flatten_paths(self, paths): return list(utils.fs.flatten_paths(path.resolve(strict=True) for path in paths)) + @lock(LockTypes.create_read) async def snapshot(self, *, paths, note=None, rate_limit=None): self.display_status('Collecting files') files = self._flatten_paths(paths) @@ -843,7 +1041,7 @@ async def snapshot(self, *, paths, note=None, rate_limit=None): unit_scale=True, total=None, position=0, - disable=self._quiet, + disable=self.quiet, leave=True, ) finished_tracker = tqdm( @@ -851,11 +1049,11 @@ async def 
snapshot(self, *, paths, note=None, rate_limit=None): unit='', total=len(files), position=1, - disable=self._quiet, + disable=self.quiet, leave=True, ) loop = asyncio.get_running_loop() - chunk_queue = queue.Queue(maxsize=self._concurrent * 10) + chunk_queue = queue.Queue(maxsize=self.concurrent * 10) abort = threading.Event() if rate_limit is not None: @@ -1002,7 +1200,7 @@ async def _worker(queue_timeout=0.025): desc=f'Chunk #{chunk.counter:06}', total=length, position=slot, - disable=self._quiet, + disable=self.quiet, rate_limiter=rate_limiter, ) with stream, iowrapper: @@ -1021,7 +1219,10 @@ async def _worker(queue_timeout=0.025): ) with finished_tracker, bytes_tracker: try: - await asyncio.gather(*(_worker() for _ in range(self._concurrent))) + await asyncio.gather( + lock.wait(self.LOCK_TTP, LockTypes.delete), + *(_worker() for _ in range(self.concurrent)), + ) except: abort.set() raise @@ -1103,7 +1304,7 @@ async def restore(self, *, snapshot_regex=None, files_regex=None, path=None): chunks_references = {} executor = ThreadPoolExecutor( - max_workers=self._concurrent * 5, thread_name_prefix='file-writer' + max_workers=self.concurrent * 5, thread_name_prefix='file-writer' ) glock = threading.Lock() flocks = {} @@ -1227,16 +1428,16 @@ async def _worker(): unit_scale=True, total=total_bytes, position=0, - disable=self._quiet, + disable=self.quiet, leave=True, ) with bytes_tracker: - await asyncio.gather(*(_worker() for _ in range(self._concurrent))) + await asyncio.gather(*(_worker() for _ in range(self.concurrent))) return utils.DefaultNamespace(files=list(seen_files)) + @lock(LockTypes.delete) async def delete_snapshots(self, snapshots): - # TODO: locking self.display_status('Loading snapshots') to_delete = set() to_keep = set() @@ -1264,12 +1465,14 @@ async def delete_snapshots(self, snapshots): to_delete.difference_update(to_keep) + await lock.wait(self.LOCK_TTP, LockTypes.create_read) + finished_snapshots_tracker = tqdm( desc='Snapshots deleted', unit='', 
total=len(snapshots_locations), position=0, - disable=self._quiet, + disable=self.quiet, leave=True, ) @@ -1285,7 +1488,7 @@ async def _delete_snapshot(location): unit='', total=len(to_delete), position=0, - disable=self._quiet, + disable=self.quiet, leave=True, ) @@ -1297,8 +1500,8 @@ async def _delete_chunk(digest): with finished_chunks_tracker: await asyncio.gather(*map(_delete_chunk, to_delete)) - async def clean(self): - # TODO: locking + @lock(LockTypes.delete) + async def delete_unreferenced_chunks(self): self.display_status('Loading snapshots') referenced_digests = { y async for _, x in self._load_snapshots() for y in x['chunks'] @@ -1324,14 +1527,17 @@ async def clean(self): to_delete.add(location) if not to_delete: + print('No unreferenced chunks found') return + await lock.wait(self.LOCK_TTP, LockTypes.create_read) + finished_tracker = tqdm( desc='Unreferenced chunks deleted', unit='', total=len(to_delete), position=0, - disable=self._quiet, + disable=self.quiet, leave=True, ) @@ -1342,8 +1548,55 @@ async def _delete_chunk(location): with finished_tracker: await asyncio.gather(*map(_delete_chunk, to_delete)) - self.display_status('Running post-deletion cleanup') - await self._clean() + return to_delete + + async def delete_stale_locks(self): + self.display_status('Checking for stale locks') + to_delete = [] + + async for location in self._aiter(self.backend.list_files, self.LOCK_PREFIX): + name, tag = self.parse_lock_location(location) + if self.props.encrypted: + logger.info('Validating tag for %s', location) + if self.props.mac(name.encode('ascii')) != bytes.fromhex(tag): + logger.info('Tag for %s did not match, skipping', location) + continue + + encoded_lock_ts = name.rpartition('-')[2] + lock_ts = int.from_bytes(bytes.fromhex(encoded_lock_ts), 'little') + if utils.utc_timestamp() - lock_ts > self.LOCK_TTL * 2: + to_delete.append(location) + + if not to_delete: + print('No stale locks found') + return + + finished_tracker = tqdm( + desc='Stale locks 
deleted', + unit='', + total=len(to_delete), + position=0, + disable=self.quiet, + leave=True, + ) + + async def _delete_lock(location): + await self._delete(location) + finished_tracker.update() + + with finished_tracker: + await asyncio.gather(*map(_delete_lock, to_delete)) + + return to_delete + + async def clean(self): + deleted = any( + [await self.delete_unreferenced_chunks(), await self.delete_stale_locks()] + ) + + if deleted: + self.display_status('Running post-deletion cleanup') + await self._clean() @utils.disable_gc def _benchmark_chunker(self, adapter, number=10, size=512_000_000): @@ -1401,7 +1654,7 @@ async def upload(self, paths, *, rate_limit=None, skip_existing=False): unit_scale=True, total=None, position=0, - disable=self._quiet, + disable=self.quiet, leave=True, ) finished_tracker = tqdm( @@ -1409,7 +1662,7 @@ async def upload(self, paths, *, rate_limit=None, skip_existing=False): unit='', total=len(files), position=1, - disable=self._quiet, + disable=self.quiet, leave=True, ) base_directory = Path.cwd() @@ -1436,7 +1689,7 @@ async def _upload_path(path): desc=name, total=length, position=slot, - disable=self._quiet, + disable=self.quiet, rate_limiter=rate_limiter, ) with stream, iowrapper: diff --git a/replicat/tests/conftest.py b/replicat/tests/conftest.py index 8ffcba1..1044194 100644 --- a/replicat/tests/conftest.py +++ b/replicat/tests/conftest.py @@ -13,4 +13,6 @@ def local_backend(tmpdir): @pytest.fixture def local_repo(local_backend, tmpdir): - return Repository(local_backend, concurrent=5, cache_directory=tmpdir / '.cache') + return Repository( + local_backend, concurrent=5, cache_directory=tmpdir / '.cache', exclusive=True + ) diff --git a/replicat/tests/test_repository.py b/replicat/tests/test_repository.py index 6c0734f..8fb672d 100644 --- a/replicat/tests/test_repository.py +++ b/replicat/tests/test_repository.py @@ -1,15 +1,17 @@ +import asyncio import os import re import threading import time +from itertools import islice from 
random import Random -from unittest.mock import ANY, patch +from unittest.mock import ANY, DEFAULT, AsyncMock, call, patch import pytest -from replicat import exceptions +from replicat import exceptions, utils from replicat.backends.local import Local -from replicat.repository import Repository +from replicat.repository import LockTypes, Repository from replicat.utils import adapters @@ -38,6 +40,165 @@ def test_parse_snapshot_location(self, local_repo): assert name == 'GHIJKLMNOPQR' assert tag == '0123456789ABCDEF' + def test_get_lock_location(self, local_repo): + location = local_repo.get_lock_location(name='a-b-c-d-e', tag='012345678') + assert location == local_repo.LOCK_PREFIX + 'a-b-c-d-e-012345678' + + def test_parse_lock_location(self, local_repo): + location = local_repo.LOCK_PREFIX + 'a-b-c-d-e-012345678' + name, tag = local_repo.parse_lock_location(location) + assert name == 'a-b-c-d-e' + assert tag == '012345678' + + @pytest.mark.parametrize( + 'ts, start, end', + [ + (0, 0, Repository.LOCK_TTL), + (Repository.LOCK_TTL - 1, 0, Repository.LOCK_TTL), + (Repository.LOCK_TTL, Repository.LOCK_TTL, Repository.LOCK_TTL * 2), + (Repository.LOCK_TTL + 1, Repository.LOCK_TTL, Repository.LOCK_TTL * 2), + (Repository.LOCK_TTL * 2 - 1, Repository.LOCK_TTL, Repository.LOCK_TTL * 2), + ], + ) + def test_get_lock_frame(self, local_repo, ts, start, end): + assert local_repo.get_lock_frame(ts) == (start, end) + + @pytest.mark.parametrize( + 'tss, lock_ttl, lock_ttp, values', + [ + ([123, 187, 188, 189], 10, 1, [120, 180, 180, 190]), + ([125, 127, 127, 127], 10, 3, [120, 130, 130, 130]), + ([123, 188, 188, 188, 188], 10, 2, [120, 190, 190, 190, 190]), + ([100, 109, 2_198, 2_199, 2_200], 1_000, 1, [0, 0, 2_000, 2_000, 2_000]), + ([39, 39], 20, 0, [20, 20]), + ([39, 39], 20, 1, [20, 40]), + ([39, 100], 100, 0, [0, 100]), + ], + ) + def test_lock_frames(self, local_backend, tss, lock_ttl, lock_ttp, values): + class _repository(Repository): + LOCK_TTL = lock_ttl + LOCK_TTP = 
lock_ttp + + repository = _repository(local_backend, concurrent=5) + it = iter(tss) + + with patch.object(utils, 'utc_timestamp', side_effect=lambda: next(it)): + frames_it = repository.lock_frames(next(it)) + generated = list(islice(frames_it, len(values))) + + assert generated == values + + +class TestLockWorker: + @pytest.mark.parametrize('lock_type', LockTypes) + @pytest.mark.asyncio + async def test_exclusive(self, local_backend, tmp_path, lock_type): + repository = Repository(local_backend, concurrent=5, exclusive=True) + await repository.init( + password=b'', settings={'encryption': {'kdf': {'n': 4}}} + ) + with patch.object(local_backend, 'upload') as upload_mock, patch.object( + local_backend, 'delete' + ) as delete_mock: + await repository.lock_worker(12_345, LockTypes.create_read) + upload_mock.assert_not_called() + delete_mock.assert_not_called() + + @pytest.mark.parametrize('lock_type', LockTypes) + @pytest.mark.asyncio + async def test_not_exclusive(self, local_backend, tmp_path, lock_type): + repository = Repository(local_backend, concurrent=5) + await repository.init( + password=b'', settings={'encryption': {'kdf': {'n': 4}}} + ) + + with patch.object(local_backend, 'upload') as upload_mock, patch.object( + local_backend, 'delete' + ) as delete_mock, patch.object( + repository, + 'lock_frames', + return_value=[19, 19, 23, 29, 29, 29, 31, 37, 37, 41, 43, 47], + ) as timestamp_mock: + await repository.lock_worker(19, lock_type, delay=0) + + unique_tss = sorted(set(timestamp_mock.return_value)) + assert upload_mock.call_count == len(unique_tss) + lock_locations = [ + repository._lock_ts_to_location(x, lock_type) for x in unique_tss + ] + upload_mock.assert_has_calls([call(x, b'') for x in lock_locations]) + + assert delete_mock.call_count == len(unique_tss) + delete_mock.assert_has_calls([call(x) for x in lock_locations], any_order=True) + + +class TestWaitForLock: + @pytest.mark.parametrize('lock_type', LockTypes) + @pytest.mark.asyncio + async def 
test_exclusive(self, local_backend, tmp_path, lock_type): + repository = Repository(local_backend, concurrent=5, exclusive=True) + await repository.init( + password=b'', settings={'encryption': {'kdf': {'n': 4}}} + ) + with patch.object(local_backend, 'exists') as exists_mock, patch.object( + asyncio, 'sleep' + ) as sleep_mock: + await repository.wait_for_lock(1, lock_type) + sleep_mock.assert_not_awaited() + exists_mock.assert_not_called() + + @pytest.mark.parametrize('lock_type', LockTypes) + @pytest.mark.parametrize('wait_time', [-11, 0, 7]) + @pytest.mark.asyncio + async def test_exists(self, local_backend, tmp_path, lock_type, wait_time): + repository = Repository(local_backend, concurrent=5) + await repository.init( + password=b'', settings={'encryption': {'kdf': {'n': 4}}} + ) + with patch.object( + local_backend, 'exists', return_value=True + ) as exists_mock, patch.object(utils, 'utc_timestamp') as ts_mock, patch.object( + repository, + 'get_lock_frame', + return_value=(787, 997), + ) as frame_mock, patch.object( + asyncio, 'sleep' + ) as sleep_mock: + with pytest.raises(exceptions.Locked): + await repository.wait_for_lock(wait_time, lock_type) + + sleep_mock.assert_awaited_once_with(wait_time) + frame_mock.assert_called_once_with(ts_mock.return_value) + exists_mock.assert_called_once_with( + repository._lock_ts_to_location(frame_mock.return_value[0], lock_type) + ) + + @pytest.mark.parametrize('lock_type', LockTypes) + @pytest.mark.parametrize('wait_time', [-11, 0, 7]) + @pytest.mark.asyncio + async def test_does_not_exist(self, local_backend, tmp_path, lock_type, wait_time): + repository = Repository(local_backend, concurrent=5) + await repository.init( + password=b'', settings={'encryption': {'kdf': {'n': 4}}} + ) + with patch.object( + local_backend, 'exists', return_value=False + ) as exists_mock, patch.object(utils, 'utc_timestamp') as ts_mock, patch.object( + repository, + 'get_lock_frame', + return_value=(787, 997), + ) as frame_mock, patch.object( 
+ asyncio, 'sleep' + ) as sleep_mock: + await repository.wait_for_lock(wait_time, lock_type) + + sleep_mock.assert_awaited_once_with(wait_time) + frame_mock.assert_called_once_with(ts_mock.return_value) + exists_mock.assert_called_once_with( + repository._lock_ts_to_location(frame_mock.return_value[0], lock_type) + ) + class TestInit: @pytest.mark.asyncio @@ -113,7 +274,6 @@ async def test_encrypted_ok(self, local_backend, local_repo, tmp_path): @pytest.mark.asyncio async def test_unencrypted_ok(self, local_backend, local_repo): result = await local_repo.init( - password=b'', settings={'encryption': None, 'chunking': {'max_length': 128_129}}, ) @@ -570,6 +730,95 @@ def upldstream(name, contents, length): upload_mock.assert_called_once_with(result.location, ANY) assert upload_stream_mock.call_count == len(result.data['files'][0]['chunks']) + @pytest.mark.parametrize('encryption', [None, {'kdf': {'n': 4}}]) + @pytest.mark.asyncio + async def test_not_locked(self, monkeypatch, local_backend, tmp_path, encryption): + local_repo = Repository(local_backend, concurrent=5) + await local_repo.init( + password=encryption and b'', + settings={'encryption': encryption}, + ) + + file = tmp_path / 'file' + file.write_bytes(Random(0).randbytes(1)) + + event = threading.Event() + wait_for_lock_mock = AsyncMock() + + async def _wait_for_lock(*a, **ka): + await wait_for_lock_mock(*a, **ka) + while not event.is_set(): + await asyncio.sleep(0) + + monkeypatch.setattr(local_repo, 'wait_for_lock', _wait_for_lock) + + with patch.object( + utils, + 'utc_timestamp', + side_effect=lambda it=iter( + [local_repo.LOCK_TTP * 23, local_repo.LOCK_TTP * 131] + ): next(it), + ), patch.object(local_repo, 'lock_worker') as lock_worker_mock, patch.object( + local_backend, 'upload' + ) as upload_mock, patch.object( + local_backend, 'upload_stream', side_effect=lambda *a, **ka: event.set() + ) as upload_stream_mock: + result = await local_repo.snapshot(paths=[file]) + + 
lock_worker_mock.assert_awaited_once_with( + local_repo.LOCK_TTP * 23, LockTypes.create_read + ) + wait_for_lock_mock.assert_awaited_once_with( + local_repo.LOCK_TTP * (23 - 131 + 1), LockTypes.delete + ) + upload_mock.assert_called_once_with(result.location, ANY) + upload_stream_mock.assert_called_once() + + @pytest.mark.parametrize('encryption', [None, {'kdf': {'n': 4}}]) + @pytest.mark.asyncio + async def test_locked(self, monkeypatch, local_backend, tmp_path, encryption): + local_repo = Repository(local_backend, concurrent=5) + await local_repo.init( + password=encryption and b'', + settings={'encryption': encryption}, + ) + + file = tmp_path / 'file' + file.write_bytes(Random(0).randbytes(1)) + + event = threading.Event() + wait_for_lock_mock = AsyncMock() + + async def _wait_for_lock(*a, **ka): + await wait_for_lock_mock(*a, **ka) + while not event.is_set(): + await asyncio.sleep(0) + raise exceptions.Locked + + monkeypatch.setattr(local_repo, 'wait_for_lock', _wait_for_lock) + + with pytest.raises(exceptions.Locked), patch.object( + utils, + 'utc_timestamp', + side_effect=lambda it=iter( + [local_repo.LOCK_TTP * 19, local_repo.LOCK_TTP * 117] + ): next(it), + ), patch.object(local_repo, 'lock_worker') as lock_worker_mock, patch.object( + local_backend, 'upload' + ) as upload_mock, patch.object( + local_backend, 'upload_stream', side_effect=lambda *a, **ka: event.set() + ) as upload_stream_mock: + await local_repo.snapshot(paths=[file]) + + lock_worker_mock.assert_awaited_once_with( + local_repo.LOCK_TTP * 19, LockTypes.create_read + ) + wait_for_lock_mock.assert_awaited_once_with( + local_repo.LOCK_TTP * (19 - 117 + 1), LockTypes.delete + ) + upload_mock.assert_not_called() + upload_stream_mock.assert_called_once() + class TestRestore: @pytest.mark.asyncio @@ -899,6 +1148,72 @@ async def test_unencrypted_unreferenced(self, local_backend, local_repo, tmp_pat assert not any(map(local_backend.exists, snapshot_paths)) assert not 
local_backend.exists(snapshot.location) + @pytest.mark.parametrize('encryption', [None, {'kdf': {'n': 4}}]) + @pytest.mark.asyncio + async def test_not_locked(self, monkeypatch, local_backend, encryption): + local_repo = Repository(local_backend, concurrent=5) + await local_repo.init( + password=encryption and b'', + settings={'encryption': encryption}, + ) + + with patch.multiple(local_repo, lock_worker=DEFAULT, wait_for_lock=DEFAULT): + snapshot = await local_repo.snapshot(paths=[]) + + with patch.object( + utils, + 'utc_timestamp', + side_effect=lambda it=iter( + [local_repo.LOCK_TTP * 37, local_repo.LOCK_TTP * 157] + ): next(it), + ), patch.object(local_repo, 'lock_worker') as lock_worker_mock, patch.object( + local_repo, 'wait_for_lock' + ) as wait_for_lock_mock, patch.object( + local_backend, 'delete' + ) as delete_mock: + await local_repo.delete_snapshots([snapshot.name]) + + lock_worker_mock.assert_awaited_once_with( + local_repo.LOCK_TTP * 37, LockTypes.delete + ) + wait_for_lock_mock.assert_awaited_once_with( + local_repo.LOCK_TTP * (37 - 157 + 1), LockTypes.create_read + ) + delete_mock.assert_called_once_with(snapshot.location) + + @pytest.mark.parametrize('encryption', [None, {'kdf': {'n': 4}}]) + @pytest.mark.asyncio + async def test_locked(self, monkeypatch, local_backend, encryption): + local_repo = Repository(local_backend, concurrent=5) + await local_repo.init( + password=encryption and b'', + settings={'encryption': encryption}, + ) + + with patch.multiple(local_repo, lock_worker=DEFAULT, wait_for_lock=DEFAULT): + snapshot = await local_repo.snapshot(paths=[]) + + with pytest.raises(exceptions.Locked), patch.object( + utils, + 'utc_timestamp', + side_effect=lambda it=iter( + [local_repo.LOCK_TTP * 41, local_repo.LOCK_TTP * 161] + ): next(it), + ), patch.object(local_repo, 'lock_worker') as lock_worker_mock, patch.object( + local_repo, 'wait_for_lock', side_effect=exceptions.Locked + ) as wait_for_lock_mock, patch.object( + local_backend, 'delete' 
+ ) as delete_mock: + await local_repo.delete_snapshots([snapshot.name]) + + lock_worker_mock.assert_awaited_once_with( + local_repo.LOCK_TTP * 41, LockTypes.delete + ) + wait_for_lock_mock.assert_awaited_once_with( + local_repo.LOCK_TTP * (41 - 161 + 1), LockTypes.create_read + ) + delete_mock.assert_not_called() + class TestClean: @pytest.mark.asyncio @@ -1097,6 +1412,84 @@ async def test_unencrypted_referenced(self, local_backend, local_repo, tmp_path) assert all(map(local_backend.exists, chunks_paths)) + @pytest.mark.parametrize('encryption', [None, {'kdf': {'n': 4}}]) + @pytest.mark.asyncio + async def test_not_locked(self, monkeypatch, local_backend, tmp_path, encryption): + local_repo = Repository(local_backend, concurrent=5) + await local_repo.init( + password=encryption and b'', + settings={'encryption': encryption}, + ) + + file = tmp_path / 'file' + file.write_bytes(Random(0).randbytes(1)) + + with patch.multiple(local_repo, lock_worker=DEFAULT, wait_for_lock=DEFAULT): + snapshot = await local_repo.snapshot(paths=[file]) + + # Delete the created reference + local_backend.delete(snapshot.location) + + with patch.object( + utils, + 'utc_timestamp', + side_effect=lambda it=iter( + [local_repo.LOCK_TTP * 29, local_repo.LOCK_TTP * 167] + ): next(it), + ), patch.object(local_repo, 'lock_worker') as lock_worker_mock, patch.object( + local_repo, 'wait_for_lock' + ) as wait_for_lock_mock, patch.object( + local_backend, 'delete' + ) as delete_mock: + await local_repo.clean() + + lock_worker_mock.assert_awaited_once_with( + local_repo.LOCK_TTP * 29, LockTypes.delete + ) + wait_for_lock_mock.assert_awaited_once_with( + local_repo.LOCK_TTP * (29 - 167 + 1), LockTypes.create_read + ) + delete_mock.assert_called_once() + + @pytest.mark.parametrize('encryption', [None, {'kdf': {'n': 4}}]) + @pytest.mark.asyncio + async def test_locked(self, monkeypatch, local_backend, tmp_path, encryption): + local_repo = Repository(local_backend, concurrent=5) + await local_repo.init( 
+ password=encryption and b'', + settings={'encryption': encryption}, + ) + + file = tmp_path / 'file' + file.write_bytes(Random(0).randbytes(1)) + + with patch.multiple(local_repo, lock_worker=DEFAULT, wait_for_lock=DEFAULT): + snapshot = await local_repo.snapshot(paths=[file]) + + # Delete the created reference + local_backend.delete(snapshot.location) + + with pytest.raises(exceptions.Locked), patch.object( + utils, + 'utc_timestamp', + side_effect=lambda it=iter( + [local_repo.LOCK_TTP * 43, local_repo.LOCK_TTP * 173] + ): next(it), + ), patch.object(local_repo, 'lock_worker') as lock_worker_mock, patch.object( + local_repo, 'wait_for_lock', side_effect=exceptions.Locked + ) as wait_for_lock_mock, patch.object( + local_backend, 'delete' + ) as delete_mock: + await local_repo.clean() + + lock_worker_mock.assert_awaited_once_with( + local_repo.LOCK_TTP * 43, LockTypes.delete + ) + wait_for_lock_mock.assert_awaited_once_with( + local_repo.LOCK_TTP * (43 - 173 + 1), LockTypes.create_read + ) + delete_mock.assert_not_called() + class TestUpload: @pytest.fixture(autouse=True) @@ -1109,7 +1502,7 @@ def change_cwd(self, tmp_path): @pytest.mark.asyncio async def test_within_cwd(self, tmp_path): backend = Local(tmp_path / 'backend') - repository = Repository(backend, concurrent=5) + repository = Repository(backend, concurrent=5, exclusive=True) files_base_path = tmp_path / 'files' contents = { @@ -1144,7 +1537,7 @@ async def test_within_cwd(self, tmp_path): ) @pytest.mark.asyncio - async def test_overwrites(self, local_backend, tmp_path): + async def test_overwrites(self, tmp_path): backend = Local(tmp_path / 'backend') backend.upload('files/file', b'') @@ -1152,14 +1545,14 @@ async def test_overwrites(self, local_backend, tmp_path): files_base_path.mkdir() (files_base_path / 'file').write_bytes(b'') - repository = Repository(backend, concurrent=5) + repository = Repository(backend, concurrent=5, exclusive=True) await repository.upload([files_base_path / 'file']) assert 
set(backend.list_files()) == {'files/file'} assert backend.download('files/file') == b'' @pytest.mark.asyncio - async def test_skip_existing(self, local_backend, tmp_path): + async def test_skip_existing(self, tmp_path): backend = Local(tmp_path / 'backend') backend.upload('files/file', b'') @@ -1167,7 +1560,7 @@ async def test_skip_existing(self, local_backend, tmp_path): files_base_path.mkdir() (files_base_path / 'file').write_bytes(b'') - repository = Repository(backend, concurrent=5) + repository = Repository(backend, concurrent=5, exclusive=True) await repository.upload([files_base_path / 'file'], skip_existing=True) assert set(backend.list_files()) == {'files/file'} diff --git a/replicat/tests/test_utils.py b/replicat/tests/test_utils.py index 917924c..e486c4a 100644 --- a/replicat/tests/test_utils.py +++ b/replicat/tests/test_utils.py @@ -5,6 +5,8 @@ import time from base64 import standard_b64encode from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timezone +from unittest.mock import patch import pytest @@ -44,7 +46,7 @@ def authenticate(self): # Simulate work, wait for all the calls to finish if self.counter: while True: - time.sleep(0.5) + time.sleep(0.1) if self.results.count('ERROR') >= self.raise_on - 1: break @@ -79,7 +81,7 @@ async def authenticate(self): # Simulate work, wait for all the calls to finish if self.counter: while True: - await asyncio.sleep(0.5) + await asyncio.sleep(0.1) if self.results.count('ERROR') >= self.raise_on - 1: break @@ -333,3 +335,22 @@ def test_iterative_scandir(tmp_path): str(tmp_path / 'A/B/K/differentfile'), str(tmp_path / 'Y/yetanotherfile'), ] + + +@pytest.mark.parametrize( + 'now, expected', + [ + (datetime(1970, 1, 1, tzinfo=timezone.utc), 0), + (datetime(1970, 1, 15, 1, 2, 3, tzinfo=timezone.utc), 1_213_323), + ( + datetime(2022, 12, 11, 10, 9, 8, 765432, tzinfo=timezone.utc), + 1_670_753_348, + ), + ], +) +def test_utc_timestamp(now, expected): + with 
patch('replicat.utils.datetime') as datetime_mock: + datetime_mock.side_effect = datetime + datetime_mock.now.return_value = now + assert utils.utc_timestamp() == expected + datetime_mock.now.assert_called_once_with(timezone.utc) diff --git a/replicat/utils/__init__.py b/replicat/utils/__init__.py index d1e2671..bf8a94c 100644 --- a/replicat/utils/__init__.py +++ b/replicat/utils/__init__.py @@ -13,6 +13,7 @@ import threading import time import weakref +from datetime import datetime, timezone from decimal import Decimal from pathlib import Path from types import SimpleNamespace @@ -98,6 +99,12 @@ def _get_environb(var, default=None): ) common_options_parser = argparse.ArgumentParser(add_help=False) +common_options_parser.add_argument( + '-x', + '--exclusive', + action='store_true', + help='Assume exclusive access to the repository', +) common_options_parser.add_argument( '-q', '--hide-progress', @@ -591,3 +598,9 @@ async def as_completed(tasks): for _ in tasks: yield await queue.get() + + +def utc_timestamp(): + epoch = datetime(1970, 1, 1, tzinfo=timezone.utc) + now = datetime.now(timezone.utc) + return int((now - epoch).total_seconds())