From 61e9ca7a0822f163fe722d275645c32636cadbcf Mon Sep 17 00:00:00 2001 From: Piotr Date: Sat, 2 Mar 2024 08:16:01 +0100 Subject: [PATCH] Gracefully shutdown when received interrupt signal (#6) --- src/salesforce_archivist/cli.py | 11 ++++ .../salesforce/download.py | 32 ++++++++--- .../salesforce/salesforce.py | 6 ++- .../salesforce/validation.py | 19 ++++--- test/salesforce/test_download.py | 54 ++++++++++++++++++- test/salesforce/test_salesforce.py | 8 +-- test/salesforce/test_validation.py | 45 ++++++++++++++++ 7 files changed, 153 insertions(+), 22 deletions(-) diff --git a/src/salesforce_archivist/cli.py b/src/salesforce_archivist/cli.py index fbf1d57..ce46747 100644 --- a/src/salesforce_archivist/cli.py +++ b/src/salesforce_archivist/cli.py @@ -1,3 +1,6 @@ +import signal +from types import FrameType + import click from click import Context @@ -5,6 +8,14 @@ from simple_salesforce import Salesforce as SalesforceClient +def signal_handler(signum: int, frame: FrameType | None) -> None: + print("Received signal {}. Attempting graceful shutdown. Please wait...".format(signum)) + raise KeyboardInterrupt + + +signal.signal(signal.SIGINT, signal_handler) + + @click.group() @click.pass_context def cli(ctx: Context) -> None: diff --git a/src/salesforce_archivist/salesforce/download.py b/src/salesforce_archivist/salesforce/download.py index 5b159e4..c0322e3 100644 --- a/src/salesforce_archivist/salesforce/download.py +++ b/src/salesforce_archivist/salesforce/download.py @@ -158,6 +158,7 @@ def __init__( downloaded_version_list: DownloadedContentVersionList, max_api_usage_percent: float | None = None, wait_sec: int = 300, + max_workers: int | None = None, ): self._client = sf_client self._downloaded_versions_list = downloaded_version_list @@ -165,6 +166,8 @@ def __init__( self._wait_sec = wait_sec self._stats = DownloadStats() self._lock = threading.Lock() + self._stop_event = threading.Event() + self._max_workers = max_workers def download_content_version_from_sf(self, version: ContentVersion, download_path: str) -> None: downloaded_version = self._downloaded_versions_list.get_version(version) @@ -230,8 +233,11 @@ def download_or_wait(self, version: ContentVersion, download_path: str) -> None: ) error = False try: - self.download_content_version_from_sf(version=version, download_path=download_path) self._wait_if_api_usage_limit() + self.download_content_version_from_sf(version=version, download_path=download_path) + except StopDownloadException: + msg = "[ERROR] Stop signal received. Graceful shutdown." + error = True except Exception as e: msg = "[ERROR] Failed to download content version {id}: {error}".format(id=version.id, error=e) error = True @@ -240,11 +246,16 @@ def download_or_wait(self, version: ContentVersion, download_path: str) -> None: self._stats.add_processed(error=error) self._print_download_msg(msg, error=error) - def download(self, download_list: DownloadContentVersionList, max_workers: int | None = None) -> DownloadStats: + def download(self, download_list: DownloadContentVersionList) -> DownloadStats: self._stats.initialize(total=len(download_list)) - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - for version, download_path in download_list: - executor.submit(self.download_or_wait, version=version, download_path=download_path) + try: + with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers) as executor: + for version, download_path in download_list: + executor.submit(self.download_or_wait, version=version, download_path=download_path) + except KeyboardInterrupt as e: + self._stop_event.set() + executor.shutdown(wait=True, cancel_futures=True) + raise e return self._stats def _wait_if_api_usage_limit(self) -> None: @@ -252,5 +263,14 @@ def _wait_if_api_usage_limit(self) -> None: usage = self._client.get_api_usage() while usage.percent >= self._max_api_usage_percent: self._print_download_msg(msg="[NOTICE] Waiting for API limit to drop.") - sleep(self._wait_sec) + for counter in range(self._wait_sec): + # check every second if stop signal was received, and if so, + # raise exception to stop current download + if self._stop_event.is_set(): + raise StopDownloadException + sleep(1) usage = self._client.get_api_usage(refresh=True) + + +class StopDownloadException(Exception): + pass diff --git a/src/salesforce_archivist/salesforce/salesforce.py b/src/salesforce_archivist/salesforce/salesforce.py index 484d2dc..ec87dbd 100644 --- a/src/salesforce_archivist/salesforce/salesforce.py +++ b/src/salesforce_archivist/salesforce/salesforce.py @@ -163,8 +163,9 @@ def download_files( sf_client=self._client, downloaded_version_list=downloaded_content_version_list, max_api_usage_percent=self._max_api_usage_percent, + max_workers=max_workers, ) - return downloader.download(download_list=download_content_version_list, max_workers=max_workers) + return downloader.download(download_list=download_content_version_list) finally: downloaded_content_version_list.save() @@ -177,7 +178,8 @@ def validate_download( try: validator = ContentVersionDownloadValidator( validated_content_version_list=validated_content_version_list, + max_workers=max_workers, ) - return validator.validate(download_list=download_content_version_list, max_workers=max_workers) + return validator.validate(download_list=download_content_version_list) finally: validated_content_version_list.save() diff --git a/src/salesforce_archivist/salesforce/validation.py b/src/salesforce_archivist/salesforce/validation.py index 1e24bd3..de9cb43 100644 --- a/src/salesforce_archivist/salesforce/validation.py +++ b/src/salesforce_archivist/salesforce/validation.py @@ -106,13 +106,11 @@ def invalid(self) -> int: class ContentVersionDownloadValidator: - def __init__( - self, - validated_content_version_list: ValidatedContentVersionList, - ): + def __init__(self, validated_content_version_list: ValidatedContentVersionList, max_workers: int | None = None): self._validated_list = validated_content_version_list self._stats = ValidationStats() self._lock = threading.Lock() + self._max_workers = max_workers def _print_validated_msg(self, msg: str, invalid: bool = False) -> None: percent = self._stats.processed / self._stats.total * 100 if self._stats.total > 0 else 0.0 @@ -166,9 +164,14 @@ def validate_version(self, version: ContentVersion, download_path: str) -> bool: return not invalid - def validate(self, download_list: DownloadContentVersionList, max_workers: int | None = None) -> ValidationStats: + def validate(self, download_list: DownloadContentVersionList) -> ValidationStats: self._stats.initialize(total=len(download_list)) - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - for version, download_path in download_list: - executor.submit(self.validate_version, version=version, download_path=download_path) + try: + with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers) as executor: + for version, download_path in download_list: + executor.submit(self.validate_version, version=version, download_path=download_path) + except KeyboardInterrupt as e: + executor.shutdown(wait=True, cancel_futures=True) + raise e + return self._stats diff --git a/test/salesforce/test_download.py b/test/salesforce/test_download.py index 49f266f..d6b6646 100644 --- a/test/salesforce/test_download.py +++ b/test/salesforce/test_download.py @@ -196,6 +196,56 @@ def test_content_version_downloader_download_will_download_in_parallel(submit_mo assert submit_mock.call_count == 2 +@patch("concurrent.futures.ThreadPoolExecutor") +def test_content_version_downloader_download_will_use_defined_workers(thread_pool_mock): + archivist_obj = ArchivistObject(data_dir="/fake/dir", obj_type="User") + link_list = ContentDocumentLinkList(data_dir=archivist_obj.data_dir) + version_list = ContentVersionList(data_dir=archivist_obj.data_dir) + download_content_version_list = DownloadContentVersionList( + document_link_list=link_list, content_version_list=version_list, data_dir=archivist_obj.data_dir + ) + downloaded_version_list = DownloadedContentVersionList(data_dir=archivist_obj.data_dir) + sf_client = Mock() + max_workers = 3 + downloader = ContentVersionDownloader( + sf_client=sf_client, downloaded_version_list=downloaded_version_list, max_workers=max_workers + ) + downloader.download(download_list=download_content_version_list) + assert thread_pool_mock.call_args == call(max_workers=max_workers) + + +@patch.object(concurrent.futures.ThreadPoolExecutor, "submit", side_effect=KeyboardInterrupt) +@patch.object(concurrent.futures.ThreadPoolExecutor, "shutdown", return_value=None) +def test_content_version_downloader_download_will_gracefully_shutdown(shutdown_mock, submit_mock): + archivist_obj = ArchivistObject(data_dir="/fake/dir", obj_type="User") + link_list = ContentDocumentLinkList(data_dir=archivist_obj.data_dir) + link = ContentDocumentLink(linked_entity_id="LID", content_document_id="DOC1") + link_list.add_link(doc_link=link) + version_list = ContentVersionList(data_dir=archivist_obj.data_dir) + version_list.add_version( + version=ContentVersion( + id="VID1", + document_id=link.content_document_id, + checksum="c1", + extension="ext1", + title="version1", + version_number=1, + ) + ) + download_content_version_list = DownloadContentVersionList( + document_link_list=link_list, content_version_list=version_list, data_dir=archivist_obj.data_dir + ) + downloaded_version_list = DownloadedContentVersionList(data_dir=archivist_obj.data_dir) + sf_client = Mock() + downloader = ContentVersionDownloader( + sf_client=sf_client, + downloaded_version_list=downloaded_version_list, + ) + with pytest.raises(KeyboardInterrupt): + downloader.download(download_list=download_content_version_list) + shutdown_mock.assert_has_calls([call(wait=True), call(wait=True, cancel_futures=True)]) + + @patch("os.path.exists") def test_content_version_downloader_download_content_version_from_sf_will_add_already_downloaded_version_to_list( exist_mock, @@ -280,7 +330,7 @@ def test_content_version_downloader_download_content_version_from_sf_will_downlo assert file.read() == b"test" -@patch("salesforce_archivist.salesforce.download.sleep", spec=True, return_value=None) +@patch("salesforce_archivist.salesforce.download.sleep", return_value=None) def test_content_version_downloader_download_or_wait(sleep_mock): sf_client = MagicMock() api_usage = ApiUsage(Usage(used=50, total=100)) @@ -306,7 +356,7 @@ def usage_side_effect(refresh: bool) -> ApiUsage: ContentVersion(id="ID", document_id="DOC", checksum="c", extension="e", title="T", version_number=1), download_path="/fake/download/path", ) - sleep_mock.assert_called_once_with(wait) + sleep_mock.assert_has_calls([call(1) for _ in range(wait)]) def test_download_stats_initialize(): diff --git a/test/salesforce/test_salesforce.py b/test/salesforce/test_salesforce.py index ee6fed9..6e28bd7 100644 --- a/test/salesforce/test_salesforce.py +++ b/test/salesforce/test_salesforce.py @@ -454,8 +454,8 @@ def test_download_files_will_call_download_and_save(): ) download_mock.assert_has_calls( [ - call(download_list=download_content_version_list, max_workers=5), - call(download_list=download_content_version_list, max_workers=5), + call(download_list=download_content_version_list), + call(download_list=download_content_version_list), ] ) assert downloaded_content_version_list.save.call_count == 2 @@ -483,8 +483,8 @@ def test_validate_download_will_call_validate_and_save(): ) validate_mock.assert_has_calls( [ - call(download_list=download_content_version_list, max_workers=5), - call(download_list=download_content_version_list, max_workers=5), + call(download_list=download_content_version_list), + call(download_list=download_content_version_list), ] ) assert validated_content_version_list.save.call_count == 2 diff --git a/test/salesforce/test_validation.py b/test/salesforce/test_validation.py index fbd6e77..ea34f80 100644 --- a/test/salesforce/test_validation.py +++ b/test/salesforce/test_validation.py @@ -169,6 +169,51 @@ def test_content_version_download_validator_validate_will_validate_in_parallel(s assert submit_mock.call_count == 2 +@patch("concurrent.futures.ThreadPoolExecutor") +def test_content_version_download_validator_validate_will_use_defined_workers(thread_pool_mock): + archivist_obj = ArchivistObject(data_dir="/fake/dir", obj_type="User") + link_list = ContentDocumentLinkList(data_dir=archivist_obj.data_dir) + version_list = ContentVersionList(data_dir=archivist_obj.data_dir) + download_content_version_list = DownloadContentVersionList( + document_link_list=link_list, content_version_list=version_list, data_dir=archivist_obj.data_dir + ) + validated_version_list = ValidatedContentVersionList(data_dir=archivist_obj.data_dir) + max_workers = 3 + validator = ContentVersionDownloadValidator( + validated_content_version_list=validated_version_list, max_workers=max_workers + ) + validator.validate(download_list=download_content_version_list) + assert thread_pool_mock.call_args == call(max_workers=max_workers) + + +@patch.object(concurrent.futures.ThreadPoolExecutor, "submit", side_effect=KeyboardInterrupt) +@patch.object(concurrent.futures.ThreadPoolExecutor, "shutdown", return_value=None) +def test_content_version_download_validator_validate_will_gracefully_shutdown(shutdown_mock, submit_mock): + archivist_obj = ArchivistObject(data_dir="/fake/dir", obj_type="User") + link_list = ContentDocumentLinkList(data_dir=archivist_obj.data_dir) + link = ContentDocumentLink(linked_entity_id="LID", content_document_id="DOC1") + link_list.add_link(doc_link=link) + version_list = ContentVersionList(data_dir=archivist_obj.data_dir) + version_list.add_version( + version=ContentVersion( + id="VID1", + document_id=link.content_document_id, + checksum="c1", + extension="ext1", + title="version1", + version_number=1, + ) + ) + download_content_version_list = DownloadContentVersionList( + document_link_list=link_list, content_version_list=version_list, data_dir=archivist_obj.data_dir + ) + validated_version_list = ValidatedContentVersionList(data_dir=archivist_obj.data_dir) + validator = ContentVersionDownloadValidator(validated_content_version_list=validated_version_list) + with pytest.raises(KeyboardInterrupt): + validator.validate(download_list=download_content_version_list) + shutdown_mock.assert_has_calls([call(wait=True), call(wait=True, cancel_futures=True)]) + + def test_content_version_download_validator_validate_version_will_find_missing_file(): archivist_obj = ArchivistObject(data_dir="/fake/dir", obj_type="User") version = ContentVersion(