From b0ea0d54c20d9650893d507b5152ffeb3e6ea10c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gregor=20Jer=C5=A1e?= Date: Thu, 18 Jan 2024 13:46:23 +0100 Subject: [PATCH] Show progress bar when downloading files --- docs/CHANGELOG.rst | 1 + src/resdk/resolwe.py | 28 +++++++++++++++++++++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/docs/CHANGELOG.rst b/docs/CHANGELOG.rst index 07270273..9f335a27 100644 --- a/docs/CHANGELOG.rst +++ b/docs/CHANGELOG.rst @@ -26,6 +26,7 @@ Added update - Add ``get_annotations`` on ``Sample`` which returns all annotations on a sample as a dictionary +- Optionally show progress bar when downloading files Fixed ----- diff --git a/src/resdk/resolwe.py b/src/resdk/resolwe.py index 791cad18..fad92a1c 100644 --- a/src/resdk/resolwe.py +++ b/src/resdk/resolwe.py @@ -9,6 +9,7 @@ """ import getpass +import json import logging import ntpath import os @@ -17,11 +18,13 @@ import webbrowser from contextlib import suppress from importlib.metadata import version as package_version -from typing import Optional, TypedDict +from pathlib import Path +from typing import Optional, TypedDict, Union from urllib.parse import urlencode, urljoin, urlparse import requests import slumber +import tqdm from packaging import version from resdk.uploader import Uploader @@ -455,7 +458,9 @@ def get_or_run(self, slug=None, input={}): model_data = self.api.data.get_or_create.post(data) return Data(resolwe=self, **model_data) - def _download_files(self, files, download_dir=None): + def _download_files( + self, files: list[Union[str, Path]], download_dir=None, show_progress=True + ): """Download files. Download files from the Resolwe server to the download @@ -481,6 +486,9 @@ def _download_files(self, files, download_dir=None): else: self.logger.info("Downloading files to %s:", download_dir) + # Store the sizes of files in the given directory. + # Use the dictionary to cache the responses. + sizes: dict[str, dict[str, int]] = dict() for file_uri in files: file_name = os.path.basename(file_uri) @@ -495,7 +503,20 @@ def _download_files(self, files, download_dir=None): self.logger.info("* %s", os.path.join(file_path, file_name)) - with open( + file_directory = os.path.dirname(file_url) + if file_directory not in sizes: + sizes[file_directory] = { + entry["name"]: entry["size"] + for entry in json.loads( + self.session.get(file_directory, auth=self.auth).content + ) + if entry["type"] == "file" + } + file_size = sizes[file_directory][file_name] + + with tqdm.tqdm( + total=file_size, disable=not show_progress + ) as progress_bar, open( os.path.join(download_dir, file_path, file_name), "wb" ) as file_handle: response = self.session.get(file_url, stream=True, auth=self.auth) @@ -505,6 +526,7 @@ def _download_files(self, files, download_dir=None): else: for chunk in response.iter_content(chunk_size=CHUNK_SIZE): file_handle.write(chunk) + progress_bar.update(len(chunk)) def data_usage(self, **query_params): """Get per-user data usage information.