Skip to content

Commit

Permalink
Gracefully shutdown when received interrupt signal (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrekkr authored Mar 2, 2024
1 parent 2861e66 commit 61e9ca7
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 22 deletions.
11 changes: 11 additions & 0 deletions src/salesforce_archivist/cli.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
import signal
from types import FrameType

import click
from click import Context

from salesforce_archivist.archivist import Archivist, ArchivistConfig
from simple_salesforce import Salesforce as SalesforceClient


def signal_handler(signum: int, frame: FrameType | None) -> None:
print("Received signal {}. Attempting graceful shutdown. Please wait...".format(signum))
raise KeyboardInterrupt


signal.signal(signal.SIGINT, signal_handler)


@click.group()
@click.pass_context
def cli(ctx: Context) -> None:
Expand Down
32 changes: 26 additions & 6 deletions src/salesforce_archivist/salesforce/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,16 @@ def __init__(
downloaded_version_list: DownloadedContentVersionList,
max_api_usage_percent: float | None = None,
wait_sec: int = 300,
max_workers: int | None = None,
):
self._client = sf_client
self._downloaded_versions_list = downloaded_version_list
self._max_api_usage_percent = max_api_usage_percent
self._wait_sec = wait_sec
self._stats = DownloadStats()
self._lock = threading.Lock()
self._stop_event = threading.Event()
self._max_workers = max_workers

def download_content_version_from_sf(self, version: ContentVersion, download_path: str) -> None:
downloaded_version = self._downloaded_versions_list.get_version(version)
Expand Down Expand Up @@ -230,8 +233,11 @@ def download_or_wait(self, version: ContentVersion, download_path: str) -> None:
)
error = False
try:
self.download_content_version_from_sf(version=version, download_path=download_path)
self._wait_if_api_usage_limit()
self.download_content_version_from_sf(version=version, download_path=download_path)
except StopDownloadException:
msg = "[ERROR] Stop signal received. Graceful shutdown."
error = True
except Exception as e:
msg = "[ERROR] Failed to download content version {id}: {error}".format(id=version.id, error=e)
error = True
Expand All @@ -240,17 +246,31 @@ def download_or_wait(self, version: ContentVersion, download_path: str) -> None:
self._stats.add_processed(error=error)
self._print_download_msg(msg, error=error)

def download(self, download_list: DownloadContentVersionList, max_workers: int | None = None) -> DownloadStats:
def download(self, download_list: DownloadContentVersionList) -> DownloadStats:
self._stats.initialize(total=len(download_list))
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
for version, download_path in download_list:
executor.submit(self.download_or_wait, version=version, download_path=download_path)
try:
with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers) as executor:
for version, download_path in download_list:
executor.submit(self.download_or_wait, version=version, download_path=download_path)
except KeyboardInterrupt as e:
self._stop_event.set()
executor.shutdown(wait=True, cancel_futures=True)
raise e
return self._stats

def _wait_if_api_usage_limit(self) -> None:
if self._max_api_usage_percent is not None:
usage = self._client.get_api_usage()
while usage.percent >= self._max_api_usage_percent:
self._print_download_msg(msg="[NOTICE] Waiting for API limit to drop.")
sleep(self._wait_sec)
for counter in range(self._wait_sec):
# check every second if stop signal was received, and if so,
# raise exception to stop current download
if self._stop_event.is_set():
raise StopDownloadException
sleep(1)
usage = self._client.get_api_usage(refresh=True)


class StopDownloadException(Exception):
pass
6 changes: 4 additions & 2 deletions src/salesforce_archivist/salesforce/salesforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,9 @@ def download_files(
sf_client=self._client,
downloaded_version_list=downloaded_content_version_list,
max_api_usage_percent=self._max_api_usage_percent,
max_workers=max_workers,
)
return downloader.download(download_list=download_content_version_list, max_workers=max_workers)
return downloader.download(download_list=download_content_version_list)
finally:
downloaded_content_version_list.save()

Expand All @@ -177,7 +178,8 @@ def validate_download(
try:
validator = ContentVersionDownloadValidator(
validated_content_version_list=validated_content_version_list,
max_workers=max_workers,
)
return validator.validate(download_list=download_content_version_list, max_workers=max_workers)
return validator.validate(download_list=download_content_version_list)
finally:
validated_content_version_list.save()
19 changes: 11 additions & 8 deletions src/salesforce_archivist/salesforce/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,11 @@ def invalid(self) -> int:


class ContentVersionDownloadValidator:
def __init__(
self,
validated_content_version_list: ValidatedContentVersionList,
):
def __init__(self, validated_content_version_list: ValidatedContentVersionList, max_workers: int | None = None):
self._validated_list = validated_content_version_list
self._stats = ValidationStats()
self._lock = threading.Lock()
self._max_workers = max_workers

def _print_validated_msg(self, msg: str, invalid: bool = False) -> None:
percent = self._stats.processed / self._stats.total * 100 if self._stats.total > 0 else 0.0
Expand Down Expand Up @@ -166,9 +164,14 @@ def validate_version(self, version: ContentVersion, download_path: str) -> bool:

return not invalid

def validate(self, download_list: DownloadContentVersionList, max_workers: int | None = None) -> ValidationStats:
def validate(self, download_list: DownloadContentVersionList) -> ValidationStats:
self._stats.initialize(total=len(download_list))
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
for version, download_path in download_list:
executor.submit(self.validate_version, version=version, download_path=download_path)
try:
with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers) as executor:
for version, download_path in download_list:
executor.submit(self.validate_version, version=version, download_path=download_path)
except KeyboardInterrupt as e:
executor.shutdown(wait=True, cancel_futures=True)
raise e

return self._stats
54 changes: 52 additions & 2 deletions test/salesforce/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,56 @@ def test_content_version_downloader_download_will_download_in_parallel(submit_mo
assert submit_mock.call_count == 2


@patch("concurrent.futures.ThreadPoolExecutor")
def test_content_version_downloader_download_will_use_defined_workers(thread_pool_mock):
archivist_obj = ArchivistObject(data_dir="/fake/dir", obj_type="User")
link_list = ContentDocumentLinkList(data_dir=archivist_obj.data_dir)
version_list = ContentVersionList(data_dir=archivist_obj.data_dir)
download_content_version_list = DownloadContentVersionList(
document_link_list=link_list, content_version_list=version_list, data_dir=archivist_obj.data_dir
)
downloaded_version_list = DownloadedContentVersionList(data_dir=archivist_obj.data_dir)
sf_client = Mock()
max_workers = 3
downloader = ContentVersionDownloader(
sf_client=sf_client, downloaded_version_list=downloaded_version_list, max_workers=max_workers
)
downloader.download(download_list=download_content_version_list)
assert thread_pool_mock.call_args == call(max_workers=max_workers)


@patch.object(concurrent.futures.ThreadPoolExecutor, "submit", side_effect=KeyboardInterrupt)
@patch.object(concurrent.futures.ThreadPoolExecutor, "shutdown", return_value=None)
def test_content_version_downloader_download_will_gracefully_shutdown(shutdown_mock, submit_mock):
archivist_obj = ArchivistObject(data_dir="/fake/dir", obj_type="User")
link_list = ContentDocumentLinkList(data_dir=archivist_obj.data_dir)
link = ContentDocumentLink(linked_entity_id="LID", content_document_id="DOC1")
link_list.add_link(doc_link=link)
version_list = ContentVersionList(data_dir=archivist_obj.data_dir)
version_list.add_version(
version=ContentVersion(
id="VID1",
document_id=link.content_document_id,
checksum="c1",
extension="ext1",
title="version1",
version_number=1,
)
)
download_content_version_list = DownloadContentVersionList(
document_link_list=link_list, content_version_list=version_list, data_dir=archivist_obj.data_dir
)
downloaded_version_list = DownloadedContentVersionList(data_dir=archivist_obj.data_dir)
sf_client = Mock()
downloader = ContentVersionDownloader(
sf_client=sf_client,
downloaded_version_list=downloaded_version_list,
)
with pytest.raises(KeyboardInterrupt):
downloader.download(download_list=download_content_version_list)
shutdown_mock.assert_has_calls([call(wait=True), call(wait=True, cancel_futures=True)])


@patch("os.path.exists")
def test_content_version_downloader_download_content_version_from_sf_will_add_already_downloaded_version_to_list(
exist_mock,
Expand Down Expand Up @@ -280,7 +330,7 @@ def test_content_version_downloader_download_content_version_from_sf_will_downlo
assert file.read() == b"test"


@patch("salesforce_archivist.salesforce.download.sleep", spec=True, return_value=None)
@patch("salesforce_archivist.salesforce.download.sleep", return_value=None)
def test_content_version_downloader_download_or_wait(sleep_mock):
sf_client = MagicMock()
api_usage = ApiUsage(Usage(used=50, total=100))
Expand All @@ -306,7 +356,7 @@ def usage_side_effect(refresh: bool) -> ApiUsage:
ContentVersion(id="ID", document_id="DOC", checksum="c", extension="e", title="T", version_number=1),
download_path="/fake/download/path",
)
sleep_mock.assert_called_once_with(wait)
sleep_mock.assert_has_calls([call(1) for _ in range(wait)])


def test_download_stats_initialize():
Expand Down
8 changes: 4 additions & 4 deletions test/salesforce/test_salesforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,8 @@ def test_download_files_will_call_download_and_save():
)
download_mock.assert_has_calls(
[
call(download_list=download_content_version_list, max_workers=5),
call(download_list=download_content_version_list, max_workers=5),
call(download_list=download_content_version_list),
call(download_list=download_content_version_list),
]
)
assert downloaded_content_version_list.save.call_count == 2
Expand Down Expand Up @@ -483,8 +483,8 @@ def test_validate_download_will_call_validate_and_save():
)
validate_mock.assert_has_calls(
[
call(download_list=download_content_version_list, max_workers=5),
call(download_list=download_content_version_list, max_workers=5),
call(download_list=download_content_version_list),
call(download_list=download_content_version_list),
]
)
assert validated_content_version_list.save.call_count == 2
45 changes: 45 additions & 0 deletions test/salesforce/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,51 @@ def test_content_version_download_validator_validate_will_validate_in_parallel(s
assert submit_mock.call_count == 2


@patch("concurrent.futures.ThreadPoolExecutor")
def test_content_version_download_validator_validate_will_use_defined_workers(thread_pool_mock):
archivist_obj = ArchivistObject(data_dir="/fake/dir", obj_type="User")
link_list = ContentDocumentLinkList(data_dir=archivist_obj.data_dir)
version_list = ContentVersionList(data_dir=archivist_obj.data_dir)
download_content_version_list = DownloadContentVersionList(
document_link_list=link_list, content_version_list=version_list, data_dir=archivist_obj.data_dir
)
validated_version_list = ValidatedContentVersionList(data_dir=archivist_obj.data_dir)
max_workers = 3
validator = ContentVersionDownloadValidator(
validated_content_version_list=validated_version_list, max_workers=max_workers
)
validator.validate(download_list=download_content_version_list)
assert thread_pool_mock.call_args == call(max_workers=max_workers)


@patch.object(concurrent.futures.ThreadPoolExecutor, "submit", side_effect=KeyboardInterrupt)
@patch.object(concurrent.futures.ThreadPoolExecutor, "shutdown", return_value=None)
def test_content_version_download_validator_validate_will_gracefully_shutdown(shutdown_mock, submit_mock):
archivist_obj = ArchivistObject(data_dir="/fake/dir", obj_type="User")
link_list = ContentDocumentLinkList(data_dir=archivist_obj.data_dir)
link = ContentDocumentLink(linked_entity_id="LID", content_document_id="DOC1")
link_list.add_link(doc_link=link)
version_list = ContentVersionList(data_dir=archivist_obj.data_dir)
version_list.add_version(
version=ContentVersion(
id="VID1",
document_id=link.content_document_id,
checksum="c1",
extension="ext1",
title="version1",
version_number=1,
)
)
download_content_version_list = DownloadContentVersionList(
document_link_list=link_list, content_version_list=version_list, data_dir=archivist_obj.data_dir
)
validated_version_list = ValidatedContentVersionList(data_dir=archivist_obj.data_dir)
validator = ContentVersionDownloadValidator(validated_content_version_list=validated_version_list)
with pytest.raises(KeyboardInterrupt):
validator.validate(download_list=download_content_version_list)
shutdown_mock.assert_has_calls([call(wait=True), call(wait=True, cancel_futures=True)])


def test_content_version_download_validator_validate_version_will_find_missing_file():
archivist_obj = ArchivistObject(data_dir="/fake/dir", obj_type="User")
version = ContentVersion(
Expand Down

0 comments on commit 61e9ca7

Please sign in to comment.