Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit number of download workers #86

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions b2sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,21 @@ def url_for_api(info, api_name):
class Services(object):
""" Gathers objects that provide high level logic over raw api usage. """

def __init__(
    self, session, max_upload_workers=10, max_copy_workers=10, max_download_workers=None
):
    """
    Initialize Services object using given session.

    :param b2sdk.v1.Session session: an authorized session used for all raw API calls
    :param int max_upload_workers: a number of upload threads
    :param int max_copy_workers: a number of copy threads
    :param int max_download_workers: a maximum number of download threads.
        If ``None``, the :class:`~b2sdk.v1.DownloadManager` default of
        ``4 * DEFAULT_MAX_STREAMS`` is used.
    """
    self.session = session
    self.large_file = LargeFileServices(self)
    # None is forwarded as-is; DownloadManager applies its own default.
    self.download_manager = DownloadManager(self, max_download_workers=max_download_workers)
    self.upload_manager = UploadManager(self, max_upload_workers=max_upload_workers)
    self.copy_manager = CopyManager(self, max_copy_workers=max_copy_workers)
    self.emerger = Emerger(self)
Expand Down Expand Up @@ -89,7 +93,8 @@ def __init__(
cache=None,
raw_api=None,
max_upload_workers=10,
max_copy_workers=10
max_copy_workers=10,
max_download_workers=None,
):
"""
Initialize the API using the given account info.
Expand All @@ -116,12 +121,15 @@ def __init__(

:param int max_upload_workers: a number of upload threads, default is 10
:param int max_copy_workers: a number of copy threads, default is 10
:param int max_download_workers: a maximum number of download threads.
If ``None``, the :class:`~b2sdk.v1.DownloadManager` default of ``4 * DEFAULT_MAX_STREAMS`` is used.
"""
self.session = B2Session(account_info=account_info, cache=cache, raw_api=raw_api)
self.services = Services(
self.session,
max_upload_workers=max_upload_workers,
max_copy_workers=max_copy_workers,
max_download_workers=max_download_workers,
)

@property
Expand Down
24 changes: 23 additions & 1 deletion b2sdk/transfer/inbound/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

import logging
import six
import threading

from contextlib import contextmanager

from b2sdk.download_dest import DownloadDestProgressWrapper
from b2sdk.progress import DoNothingProgressListener
Expand All @@ -30,6 +33,17 @@
logger = logging.getLogger(__name__)


class ProtectedSemaphore(object):
    """
    Hand out a shared semaphore to one thread at a time.

    An RLock serializes access to the wrapped semaphore; callers obtain
    it through the :meth:`get_semaphore` context manager.
    """

    def __init__(self, semaphore):
        """
        :param semaphore: the semaphore object to protect
        """
        self._semaphore = semaphore
        self._lock = threading.RLock()

    @contextmanager
    def get_semaphore(self):
        """Yield the wrapped semaphore while holding the protecting lock."""
        self._lock.acquire()
        try:
            yield self._semaphore
        finally:
            self._lock.release()


@six.add_metaclass(B2TraceMetaAbstract)
class DownloadManager(object):
"""
Expand All @@ -46,18 +60,26 @@ class DownloadManager(object):
MIN_CHUNK_SIZE = 8192 # ~1MB file will show ~1% progress increment
MAX_CHUNK_SIZE = 1024**2

def __init__(self, services):
def __init__(self, services, max_download_workers=None):
"""
Initialize the DownloadManager using the given services object.

:param b2sdk.v1.Services services:
:param int max_download_workers: a maximum number of download threads.
If ``None`` then ``4 * DEFAULT_MAX_STREAMS`` is used.
"""

self.services = services
self.max_download_workers = max_download_workers or 4 * self.DEFAULT_MAX_STREAMS
self.max_workers_semaphore = ProtectedSemaphore(
threading.BoundedSemaphore(self.max_download_workers)
)

self.strategies = [
ParallelDownloader(
max_streams=self.DEFAULT_MAX_STREAMS,
min_part_size=self.DEFAULT_MIN_PART_SIZE,
protected_semaphore=self.max_workers_semaphore,
min_chunk_size=self.MIN_CHUNK_SIZE,
max_chunk_size=self.MAX_CHUNK_SIZE,
),
Expand Down
93 changes: 66 additions & 27 deletions b2sdk/transfer/inbound/downloader/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ class ParallelDownloader(AbstractDownloader):
#
FINISH_HASHING_BUFFER_SIZE = 1024**2

def __init__(self, max_streams, min_part_size, *args, **kwargs):
def __init__(self, max_streams, min_part_size, protected_semaphore, *args, **kwargs):
    """
    :param max_streams: maximum number of simultaneous streams
    :param min_part_size: minimum amount of data a single stream will retrieve, in bytes
    :param protected_semaphore: a ProtectedSemaphore bounding the total number of
        download threads; each downloader thread acquires a slot from it
    """
    self.max_streams = max_streams
    self.min_part_size = min_part_size
    self.protected_semaphore = protected_semaphore
    super(ParallelDownloader, self).__init__(*args, **kwargs)

def is_suitable(self, metadata, progress_listener):
Expand Down Expand Up @@ -132,31 +133,59 @@ def _finish_hashing(self, first_part, file, hasher, content_length):
def _get_parts(
    self, response, session, writer, hasher, first_part, parts_to_download, chunk_size
):
    """
    Download all parts of the file, one downloader thread per part.

    Every thread is admitted through the shared protected semaphore so the
    total number of concurrent download threads stays bounded across all
    concurrent downloads.  A started thread releases its slot itself when it
    finishes; if thread creation or start fails, the slot is released here
    so it is never leaked.

    :param response: response to the initial request; consumed by the first part's thread
    :param session: raw_api wrapper used by the remaining parts' threads
    :param writer: WriterThread whose queue serializes chunks to the destination
    :param hasher: hash object updated by the first part's thread as data arrives
    :param first_part: PartToDownload served from the already-open response
    :param parts_to_download: remaining PartToDownload objects
    :param chunk_size: internal buffer size each thread reads/writes with
    """
    with self.protected_semaphore.get_semaphore() as semaphore:
        semaphore.acquire()
        try:
            stream = FirstPartDownloaderThread(
                response,
                hasher,
                session,
                writer,
                first_part,
                chunk_size,
                semaphore,
            )
            stream.start()
        except Exception:
            # the thread never ran, so it will not release its own slot
            semaphore.release()
            raise

        streams = [stream]

        for part in parts_to_download:
            semaphore.acquire()
            try:
                stream = NonHashingDownloaderThread(
                    response.request.url,
                    session,
                    writer,
                    part,
                    chunk_size,
                    semaphore,
                )
                stream.start()
            except Exception:
                semaphore.release()
                raise
            streams.append(stream)
    # Join outside the protecting lock so other downloads can request slots
    # while these parts finish.
    for stream in streams:
        stream.join()


class ClosableQueue(queue.Queue):
    """
    A Queue that can be marked closed, after which ``put`` raises.

    Lets producer threads fail fast instead of silently enqueueing items
    after the consumer has shut down.
    """

    def __init__(self, *args, **kwargs):
        super(ClosableQueue, self).__init__(*args, **kwargs)
        # becomes True once close() is called; checked on every put()
        self._closed = False

    def put(self, *args, **kwargs):
        """Enqueue an item, or raise RuntimeError if the queue was closed."""
        if not self._closed:
            return super(ClosableQueue, self).put(*args, **kwargs)
        raise RuntimeError('queue closed')

    def close(self):
        """Mark the queue closed; any later put() raises RuntimeError."""
        self._closed = True


class WriterThread(threading.Thread):
"""
A thread responsible for keeping a queue of data chunks to write to a file-like object and for actually writing them down.
Expand All @@ -183,7 +212,7 @@ class WriterThread(threading.Thread):

def __init__(self, file, max_queue_depth):
self.file = file
self.queue = queue.Queue(max_queue_depth)
self.queue = ClosableQueue(max_queue_depth)
self.total = 0
super(WriterThread, self).__init__()

Expand All @@ -204,25 +233,35 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
    # Presumably the shutdown sentinel consumed by run() — TODO confirm against run().
    self.queue.put((True, None, None))
    # Any thread still trying to put something on the queue will now fail with RuntimeError.
    self.queue.close()
    # Wait for the writer thread to drain the queue and exit.
    self.join()


class AbstractDownloaderThread(threading.Thread):
    """
    Base class for the per-part downloader threads.

    Subclasses implement :meth:`run_download`; the concrete :meth:`run`
    guarantees the semaphore slot held by the thread is released when the
    download ends, whether it succeeded or raised.
    """

    def __init__(self, session, writer, part_to_download, chunk_size, semaphore):
        """
        :param session: raw_api wrapper
        :param writer: where to write data
        :param part_to_download: PartToDownload object
        :param chunk_size: internal buffer size to use for writing and hashing
        :param semaphore: already acquired semaphore that downloader thread has to release on finish
        """
        self.session = session
        self.writer = writer
        self.part_to_download = part_to_download
        self.chunk_size = chunk_size
        self.semaphore = semaphore
        super(AbstractDownloaderThread, self).__init__()

    # NOT abstract: this is the concrete Thread entry point shared by all
    # subclasses (they override run_download only), so the @abstractmethod
    # decorator that appeared here was removed.
    def run(self):
        try:
            self.run_download()
        finally:
            # always give the slot back, even if the download raised
            self.semaphore.release()

    @abstractmethod
    def run_download(self):
        """Download this thread's part; implemented by subclasses."""
        pass


Expand All @@ -236,7 +275,7 @@ def __init__(self, response, hasher, *args, **kwargs):
self.hasher = hasher
super(FirstPartDownloaderThread, self).__init__(*args, **kwargs)

def run(self):
def run_download(self):
writer_queue_put = self.writer.queue.put
hasher_update = self.hasher.update
first_offset = self.part_to_download.local_range.start
Expand Down Expand Up @@ -291,7 +330,7 @@ def __init__(self, url, *args, **kwargs):
self.url = url
super(NonHashingDownloaderThread, self).__init__(*args, **kwargs)

def run(self):
def run_download(self):
writer_queue_put = self.writer.queue.put
start_range = self.part_to_download.local_range.start
actual_part_size = self.part_to_download.local_range.size()
Expand Down
2 changes: 2 additions & 0 deletions test/v0/test_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,7 @@ def setUp(self):
force_chunk_size=2,
max_streams=999,
min_part_size=2,
protected_semaphore=self.bucket.api.services.download_manager.max_workers_semaphore,
)
]

Expand Down Expand Up @@ -792,5 +793,6 @@ def setUp(self):
force_chunk_size=3,
max_streams=2,
min_part_size=2,
protected_semaphore=self.bucket.api.services.download_manager.max_workers_semaphore,
)
]
2 changes: 2 additions & 0 deletions test/v1/test_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,7 @@ def setUp(self):
force_chunk_size=2,
max_streams=999,
min_part_size=2,
protected_semaphore=self.bucket.api.services.download_manager.max_workers_semaphore,
)
]

Expand Down Expand Up @@ -892,5 +893,6 @@ def setUp(self):
force_chunk_size=3,
max_streams=2,
min_part_size=2,
protected_semaphore=self.bucket.api.services.download_manager.max_workers_semaphore,
)
]