Skip to content

Commit

Permalink
Interrupt _CipherBackupStreamer workers (#136845)
Browse files Browse the repository at this point in the history
* Interrupt _CipherBackupStreamer workers

* Fix cleanup

* Only abort live threads
  • Loading branch information
emontnemery authored Jan 29, 2025
1 parent 3118831 commit 660653e
Showing 1 changed file with 57 additions and 37 deletions.
94 changes: 57 additions & 37 deletions homeassistant/components/backup/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pathlib import Path, PurePath
from queue import SimpleQueue
import tarfile
import threading
from typing import IO, Any, Self, cast

import aiohttp
Expand All @@ -22,6 +21,7 @@
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.json import JsonObjectType, json_loads_object
from homeassistant.util.thread import ThreadWithException

from .const import BUF_SIZE, LOGGER
from .models import AddonInfo, AgentBackup, Folder
Expand Down Expand Up @@ -57,6 +57,12 @@ class BackupEmpty(DecryptError):
_message = "No tar files found in the backup."


class AbortCipher(HomeAssistantError):
"""Abort the cipher operation."""

_message = "Abort cipher operation."


def make_backup_dir(path: Path) -> None:
"""Create a backup directory if it does not exist."""
path.mkdir(exist_ok=True)
Expand Down Expand Up @@ -252,24 +258,29 @@ def decrypt_backup(
"""Decrypt a backup."""
error: Exception | None = None
try:
with (
tarfile.open(
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
) as input_tar,
tarfile.open(
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
) as output_tar,
):
_decrypt_backup(input_tar, output_tar, password)
except (DecryptError, SecureTarError, tarfile.TarError) as err:
LOGGER.warning("Error decrypting backup: %s", err)
error = err
else:
# Pad the output stream to the requested minimum size
padding = max(minimum_size - output_stream.tell(), 0)
output_stream.write(b"\0" * padding)
try:
with (
tarfile.open(
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
) as input_tar,
tarfile.open(
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
) as output_tar,
):
_decrypt_backup(input_tar, output_tar, password)
except (DecryptError, SecureTarError, tarfile.TarError) as err:
LOGGER.warning("Error decrypting backup: %s", err)
error = err
else:
# Pad the output stream to the requested minimum size
padding = max(minimum_size - output_stream.tell(), 0)
output_stream.write(b"\0" * padding)
finally:
# Write an empty chunk to signal the end of the stream
output_stream.write(b"")
except AbortCipher:
LOGGER.debug("Cipher operation aborted")
finally:
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
on_done(error)


Expand Down Expand Up @@ -322,24 +333,29 @@ def encrypt_backup(
"""Encrypt a backup."""
error: Exception | None = None
try:
with (
tarfile.open(
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
) as input_tar,
tarfile.open(
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
) as output_tar,
):
_encrypt_backup(input_tar, output_tar, password, nonces)
except (EncryptError, SecureTarError, tarfile.TarError) as err:
LOGGER.warning("Error encrypting backup: %s", err)
error = err
else:
# Pad the output stream to the requested minimum size
padding = max(minimum_size - output_stream.tell(), 0)
output_stream.write(b"\0" * padding)
try:
with (
tarfile.open(
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
) as input_tar,
tarfile.open(
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
) as output_tar,
):
_encrypt_backup(input_tar, output_tar, password, nonces)
except (EncryptError, SecureTarError, tarfile.TarError) as err:
LOGGER.warning("Error encrypting backup: %s", err)
error = err
else:
# Pad the output stream to the requested minimum size
padding = max(minimum_size - output_stream.tell(), 0)
output_stream.write(b"\0" * padding)
finally:
# Write an empty chunk to signal the end of the stream
output_stream.write(b"")
except AbortCipher:
LOGGER.debug("Cipher operation aborted")
finally:
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
on_done(error)


Expand Down Expand Up @@ -387,7 +403,7 @@ def _encrypt_backup(
class _CipherWorkerStatus:
done: asyncio.Event
error: Exception | None = None
thread: threading.Thread
thread: ThreadWithException


class _CipherBackupStreamer:
Expand Down Expand Up @@ -440,7 +456,7 @@ def on_done(error: Exception | None) -> None:
stream = await self._open_stream()
reader = AsyncIteratorReader(self._hass, stream)
writer = AsyncIteratorWriter(self._hass)
worker = threading.Thread(
worker = ThreadWithException(
target=self._cipher_func,
args=[reader, writer, self._password, on_done, self.size(), self._nonces],
)
Expand All @@ -451,6 +467,10 @@ def on_done(error: Exception | None) -> None:

async def wait(self) -> None:
"""Wait for the worker threads to finish."""
for worker in self._workers:
if not worker.thread.is_alive():
continue
worker.thread.raise_exc(AbortCipher)
await asyncio.gather(*(worker.done.wait() for worker in self._workers))


Expand Down

0 comments on commit 660653e

Please sign in to comment.