Skip to content

Commit

Permalink
Show progress bar when downloading files
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorjerse committed Jan 18, 2024
1 parent 7bf360f commit b0ea0d5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Added
update
- Add ``get_annotations`` on ``Sample`` which returns all annotations on a
sample as a dictionary
- Optionally show a progress bar when downloading files

Fixed
-----
Expand Down
28 changes: 25 additions & 3 deletions src/resdk/resolwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""
import getpass
import json
import logging
import ntpath
import os
Expand All @@ -17,11 +18,13 @@
import webbrowser
from contextlib import suppress
from importlib.metadata import version as package_version
from typing import Optional, TypedDict
from pathlib import Path
from typing import Optional, TypedDict, Union
from urllib.parse import urlencode, urljoin, urlparse

import requests
import slumber
import tqdm
from packaging import version

from resdk.uploader import Uploader
Expand Down Expand Up @@ -455,7 +458,9 @@ def get_or_run(self, slug=None, input={}):
model_data = self.api.data.get_or_create.post(data)
return Data(resolwe=self, **model_data)

def _download_files(self, files, download_dir=None):
def _download_files(
self, files: list[Union[str, Path]], download_dir=None, show_progress=True
):
"""Download files.
Download files from the Resolwe server to the download
Expand All @@ -481,6 +486,9 @@ def _download_files(self, files, download_dir=None):

else:
self.logger.info("Downloading files to %s:", download_dir)
# Map each remote directory URL to a {file name: size} dictionary.
# The dictionary caches the listing responses, so each directory is requested only once.
sizes: dict[str, dict[str, int]] = dict()

for file_uri in files:
file_name = os.path.basename(file_uri)
Expand All @@ -495,7 +503,20 @@ def _download_files(self, files, download_dir=None):

self.logger.info("* %s", os.path.join(file_path, file_name))

with open(
file_directory = os.path.dirname(file_url)
if file_directory not in sizes:
sizes[file_directory] = {
entry["name"]: entry["size"]
for entry in json.loads(
self.session.get(file_directory, auth=self.auth).content
)
if entry["type"] == "file"
}
file_size = sizes[file_directory][file_name]

with tqdm.tqdm(
total=file_size, disable=not show_progress
) as progress_bar, open(
os.path.join(download_dir, file_path, file_name), "wb"
) as file_handle:
response = self.session.get(file_url, stream=True, auth=self.auth)
Expand All @@ -505,6 +526,7 @@ def _download_files(self, files, download_dir=None):
else:
for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
file_handle.write(chunk)
progress_bar.update(len(chunk))

def data_usage(self, **query_params):
"""Get per-user data usage information.
Expand Down

0 comments on commit b0ea0d5

Please sign in to comment.