diff --git a/supervisor/api/backups.py b/supervisor/api/backups.py index a1a1a59b404..0a5ca5aecb9 100644 --- a/supervisor/api/backups.py +++ b/supervisor/api/backups.py @@ -83,6 +83,7 @@ def _ensure_list(item: Any) -> list: { vol.Optional(ATTR_PASSWORD): vol.Maybe(str), vol.Optional(ATTR_BACKGROUND, default=False): vol.Boolean(), + vol.Optional(ATTR_LOCATION): vol.Maybe(str), } ) @@ -379,8 +380,10 @@ async def backup_partial(self, request: web.Request): async def restore_full(self, request: web.Request): """Full restore of a backup.""" backup = self._extract_slug(request) - self._validate_cloud_backup_location(request, backup.location) body = await api_validate(SCHEMA_RESTORE_FULL, request) + self._validate_cloud_backup_location( + request, body.get(ATTR_LOCATION, backup.location) + ) background = body.pop(ATTR_BACKGROUND) restore_task, job_id = await self._background_backup_task( self.sys_backups.do_restore_full, backup, **body @@ -397,8 +400,10 @@ async def restore_full(self, request: web.Request): async def restore_partial(self, request: web.Request): """Partial restore a backup.""" backup = self._extract_slug(request) - self._validate_cloud_backup_location(request, backup.location) body = await api_validate(SCHEMA_RESTORE_PARTIAL, request) + self._validate_cloud_backup_location( + request, body.get(ATTR_LOCATION, backup.location) + ) background = body.pop(ATTR_BACKGROUND) restore_task, job_id = await self._background_backup_task( self.sys_backups.do_restore_partial, backup, **body diff --git a/supervisor/backups/backup.py b/supervisor/backups/backup.py index af6406ab45a..fa330bb7926 100644 --- a/supervisor/backups/backup.py +++ b/supervisor/backups/backup.py @@ -3,7 +3,8 @@ import asyncio from base64 import b64decode, b64encode from collections import defaultdict -from collections.abc import Awaitable +from collections.abc import AsyncGenerator, Awaitable +from contextlib import asynccontextmanager from copy import deepcopy from datetime import timedelta from functools import cached_property @@ -12,6 +13,7 @@ import logging from pathlib import Path import tarfile +from tarfile import TarFile from tempfile import TemporaryDirectory import time from typing import Any, Self @@ -56,6 +58,7 @@ from ..utils import remove_folder from ..utils.dt import parse_datetime, utcnow from ..utils.json import json_bytes +from ..utils.sentinel import DEFAULT from .const import BUF_SIZE, LOCATION_CLOUD_BACKUP, BackupType from .utils import key_to_iv, password_to_key from .validate import SCHEMA_BACKUP @@ -86,7 +89,6 @@ def __init__( self._data: dict[str, Any] = data or {ATTR_SLUG: slug} self._tmp = None self._outer_secure_tarfile: SecureTarFile | None = None - self._outer_secure_tarfile_tarfile: tarfile.TarFile | None = None self._key: bytes | None = None self._aes: Cipher | None = None self._locations: dict[str | None, Path] = {location: tar_file} @@ -375,59 +377,68 @@ def _load_file(): return True - async def __aenter__(self): - """Async context to open a backup.""" + @asynccontextmanager + async def create(self) -> AsyncGenerator[None]: + """Create new backup file.""" + if self.tarfile.is_file(): + raise BackupError( + f"Cannot make new backup at {self.tarfile.as_posix()}, file already exists!", + _LOGGER.error, + ) - # create a backup - if not self.tarfile.is_file(): - self._outer_secure_tarfile = SecureTarFile( - self.tarfile, - "w", - gzip=False, - bufsize=BUF_SIZE, + self._outer_secure_tarfile = SecureTarFile( + self.tarfile, + "w", + gzip=False, + bufsize=BUF_SIZE, + ) + try: + with self._outer_secure_tarfile as outer_tarfile: + yield + await self._create_cleanup(outer_tarfile) + finally: + self._outer_secure_tarfile = None + + @asynccontextmanager + async def open(self, location: str | None | type[DEFAULT]) -> AsyncGenerator[None]: + """Open backup for restore.""" + if location != DEFAULT and location not in self.all_locations: + raise BackupError( + f"Backup {self.slug} does not exist in location {location}", + _LOGGER.error, + ) + + backup_tarfile = ( + self.tarfile if location == DEFAULT else self.all_locations[location] + ) + if not backup_tarfile.is_file(): + raise BackupError( + f"Cannot open backup at {backup_tarfile.as_posix()}, file does not exist!", + _LOGGER.error, ) - self._outer_secure_tarfile_tarfile = self._outer_secure_tarfile.__enter__() - return # extract an existing backup - self._tmp = TemporaryDirectory(dir=str(self.tarfile.parent)) + self._tmp = TemporaryDirectory(dir=str(backup_tarfile.parent)) def _extract_backup(): """Extract a backup.""" - with tarfile.open(self.tarfile, "r:") as tar: + with tarfile.open(backup_tarfile, "r:") as tar: tar.extractall( path=self._tmp.name, members=secure_path(tar), filter="fully_trusted", ) - await self.sys_run_in_executor(_extract_backup) - - async def __aexit__(self, exception_type, exception_value, traceback): - """Async context to close a backup.""" - # exists backup or exception on build - try: - await self._aexit(exception_type, exception_value, traceback) - finally: - if self._tmp: - self._tmp.cleanup() - if self._outer_secure_tarfile: - self._outer_secure_tarfile.__exit__( - exception_type, exception_value, traceback - ) - self._outer_secure_tarfile = None - self._outer_secure_tarfile_tarfile = None + with self._tmp: + await self.sys_run_in_executor(_extract_backup) + yield - async def _aexit(self, exception_type, exception_value, traceback): + async def _create_cleanup(self, outer_tarfile: TarFile) -> None: """Cleanup after backup creation. - This is a separate method to allow it to be called from __aexit__ to ensure + Separate method to be called from create to ensure that cleanup is always performed, even if an exception is raised. """ - # If we're not creating a new backup, or if an exception was raised, we're done - if not self._outer_secure_tarfile or exception_type is not None: - return - # validate data try: self._data = SCHEMA_BACKUP(self._data) @@ -445,7 +456,7 @@ def _add_backup_json(): tar_info = tarfile.TarInfo(name="./backup.json") tar_info.size = len(raw_bytes) tar_info.mtime = int(time.time()) - self._outer_secure_tarfile_tarfile.addfile(tar_info, fileobj=fileobj) + outer_tarfile.addfile(tar_info, fileobj=fileobj) try: await self.sys_run_in_executor(_add_backup_json) diff --git a/supervisor/backups/manager.py b/supervisor/backups/manager.py index 6ecd000e2d4..e817cc45e59 100644 --- a/supervisor/backups/manager.py +++ b/supervisor/backups/manager.py @@ -405,7 +405,7 @@ async def _do_backup( try: self.sys_core.state = CoreState.FREEZE - async with backup: + async with backup.create(): # HomeAssistant Folder is for v1 if homeassistant: self._change_stage(BackupJobStage.HOME_ASSISTANT, backup) @@ -575,6 +575,7 @@ async def _do_restore( folder_list: list[str], homeassistant: bool, replace: bool, + location: str | None | type[DEFAULT], ) -> bool: """Restore from a backup. @@ -585,7 +586,7 @@ async def _do_restore( try: task_hass: asyncio.Task | None = None - async with backup: + async with backup.open(location): # Restore docker config self._change_stage(RestoreJobStage.DOCKER_CONFIG, backup) backup.restore_dockerconfig(replace) @@ -671,7 +672,10 @@ async def _do_restore( cleanup=False, ) async def do_restore_full( - self, backup: Backup, password: str | None = None + self, + backup: Backup, + password: str | None = None, + location: str | None | type[DEFAULT] = DEFAULT, ) -> bool: """Restore a backup.""" # Add backup ID to job @@ -702,7 +706,12 @@ async def do_restore_full( await self.sys_core.shutdown() success = await self._do_restore( - backup, backup.addon_list, backup.folders, True, True + backup, + backup.addon_list, + backup.folders, + homeassistant=True, + replace=True, + location=location, ) finally: self.sys_core.state = CoreState.RUNNING @@ -731,6 +740,7 @@ async def do_restore_partial( addons: list[str] | None = None, folders: list[Path] | None = None, password: str | None = None, + location: str | None | type[DEFAULT] = DEFAULT, ) -> bool: """Restore a backup.""" # Add backup ID to job @@ -766,7 +776,12 @@ async def do_restore_partial( try: success = await self._do_restore( - backup, addon_list, folder_list, homeassistant, False + backup, + addon_list, + folder_list, + homeassistant=homeassistant, + replace=False, + location=location, ) finally: self.sys_core.state = CoreState.RUNNING diff --git a/tests/api/test_backups.py b/tests/api/test_backups.py index 0c00d71ec75..cb607f59487 100644 --- a/tests/api/test_backups.py +++ b/tests/api/test_backups.py @@ -810,3 +810,50 @@ async def test_partial_backup_all_addons( ) assert resp.status == 200 store_addons.assert_called_once_with([install_addon_ssh]) + + +async def test_restore_backup_from_location( + api_client: TestClient, coresys: CoreSys, tmp_supervisor_data: Path +): + """Test restoring a backup from a specific location.""" + coresys.core.state = CoreState.RUNNING + coresys.hardware.disk.get_disk_free_space = lambda x: 5000 + + # Make a backup and a file to test with + (test_file := coresys.config.path_share / "test.txt").touch() + resp = await api_client.post( + "/backups/new/partial", + json={ + "name": "Test", + "folders": ["share"], + "location": [None, ".cloud_backup"], + }, + ) + assert resp.status == 200 + body = await resp.json() + backup = coresys.backups.get(body["data"]["slug"]) + assert set(backup.all_locations) == {None, ".cloud_backup"} + + # The use case of this is user might want to pick a particular mount if one is flaky + # To simulate this, remove the file from one location and show one works and the other doesn't + assert backup.location is None + backup.all_locations[None].unlink() + test_file.unlink() + + resp = await api_client.post( + f"/backups/{backup.slug}/restore/partial", + json={"location": None, "folders": ["share"]}, + ) + assert resp.status == 400 + body = await resp.json() + assert ( + body["message"] + == f"Cannot open backup at {backup.all_locations[None].as_posix()}, file does not exist!" + ) + + resp = await api_client.post( + f"/backups/{backup.slug}/restore/partial", + json={"location": ".cloud_backup", "folders": ["share"]}, + ) + assert resp.status == 200 + assert test_file.is_file() diff --git a/tests/backups/test_backup.py b/tests/backups/test_backup.py index aaf45c932e5..45e1e90b978 100644 --- a/tests/backups/test_backup.py +++ b/tests/backups/test_backup.py @@ -14,7 +14,7 @@ async def test_new_backup_stays_in_folder(coresys: CoreSys, tmp_path: Path): backup.new("test", "2023-07-21T21:05:00.000000+00:00", BackupType.FULL) assert not listdir(tmp_path) - async with backup: + async with backup.create(): assert len(listdir(tmp_path)) == 1 assert backup.tarfile.exists()