Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to download a repo folder (including CLI) #389

Merged
merged 20 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 140 additions & 1 deletion dagshub/common/api/repo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import logging
from pathlib import Path, PurePosixPath

import rich.progress

from dagshub.common.api.responses import (
RepoAPIResponse,
Expand All @@ -8,7 +11,10 @@
ContentAPIEntry,
StorageContentAPIResult,
)
from dagshub.common.download import download_files
from dagshub.common.rich_util import get_rich_progress
from dagshub.common.util import multi_urljoin
from functools import partial

try:
from functools import cached_property
Expand All @@ -22,7 +28,7 @@
import dagshub.auth
from dagshub.common import config

from dagshub.common.helpers import http_request
from dagshub.common.helpers import http_request, log_message

logger = logging.getLogger("dagshub")

Expand Down Expand Up @@ -254,6 +260,139 @@ def get_storage_file(self, path: str) -> bytes:
raise RuntimeError(error_msg)
return res.content

def _get_files_in_path(
    self, path, revision=None, recursive=False, traverse_storages=False
) -> List[ContentAPIEntry]:
    """
    Walks through the path of the repo, returning non-dir entries.

    Performs a breadth-first traversal starting at ``path``, listing each
    directory with the content API and collecting every entry of type "file".

    Args:
        path: Path in the repo to walk. Storage paths are recognized by their
            ``s3:/``, ``gs:/`` or ``azure:/`` prefix and listed through the
            storage API instead.
        revision: Branch or revision used when listing repo (non-storage) paths.
        recursive: If ``True``, descend into subdirectories; otherwise only the
            top-level listing is returned.
        traverse_storages: If ``True``, also descend into connected storage
            buckets encountered during traversal.

    Returns:
        List of ``ContentAPIEntry`` objects for every file found.
    """
    from collections import deque

    # BFS queue of (path, listing function) pairs.
    # deque.popleft() is O(1), unlike list.pop(0) which shifts the whole list (O(n)).
    dir_queue = deque()
    files = []

    list_fn_folder = partial(self.list_path, revision=revision)
    list_fn_storage = self.list_storage_path

    def push_folder(folder_path):
        dir_queue.append((folder_path, list_fn_folder))

    def push_storage(storage_path):
        dir_queue.append((storage_path, list_fn_storage))

    # Initialize the queue
    is_storage_path = str(path).lstrip("/").split("/")[0] in {"s3:", "gs:", "azure:"}
    if is_storage_path:
        # Handle storage paths - they start with s3:/, gs:/ or azure:/
        # Drop the colon: the listing API addresses buckets as "s3/...", "gs/..." etc.
        path = str(path).replace(":", "", 1)
        push_storage(path)
    else:
        push_folder(path)

    progress = get_rich_progress(rich.progress.MofNCompleteColumn())
    task = progress.add_task("Traversing directories...", total=None)

    def step(step_path, list_fn):
        """
        step_path: path in the repo to list
        list_fn: which function to use to list it (can be list_path or list_storage_path)
        """
        res = list_fn(step_path)
        for entry in res:
            if entry.type == "file":
                files.append(entry)
            elif recursive:
                if entry.versioning == "bucket":
                    # Connected storage bucket - only descend when explicitly requested
                    if traverse_storages:
                        push_storage(entry.path)
                else:
                    push_folder(entry.path)
        # One progress tick per directory listed
        progress.update(task, advance=1)

    with progress:
        while dir_queue:
            query_path, list_fn = dir_queue.popleft()
            step(query_path, list_fn)

    return files

def download(
    self,
    remote_path,
    local_path=".",
    revision=None,
    recursive=True,
    keep_source_prefix=False,
    redownload=False,
    download_storages=False,
):
    """
    Downloads the contents of the repository at "remote_path" to the "local_path"

    Args:
        remote_path: Path in the repository of the folder or file to download.
        local_path: Where to download the files. Defaults to current working directory.
        revision: Repo revision or branch, if not specified - uses default repo branch.
            Ignored for downloading from buckets.
        recursive: Whether to download files recursively.
        keep_source_prefix: | Whether to keep the path of the folder in the download path or not.
            | Example: Given remote_path ``src/data`` and file ``test/file.txt``
            | if ``True``: will download to ``<local_path>/src/data/test/file.txt``
            | if ``False``: will download to ``<local_path>/test/file.txt``
        redownload: Whether to redownload files that already exist on the local filesystem.
            The downloader doesn't do any hash comparisons and only checks
            if a file already exists in the local filesystem or not.
        download_storages: If downloading the whole repo, by default we're not downloading the integrated storages.
            Toggle this to ``True`` to change this behavior
    """
    # Only skip connected storages when downloading the entire repo root;
    # explicit storage paths (s3:/..., etc.) are always traversed.
    traverse_storages = True
    if str(remote_path) == "/" and not download_storages:
        log_message(
            "Skipping downloading from connected storages. "
            "Set the `download_storages` flag if you want "
            "to download the whole content of the connected storages."
        )
        traverse_storages = False

    files = self._get_files_in_path(remote_path, revision, recursive, traverse_storages=traverse_storages)
    file_tuples = []
    if local_path is None:
        local_path = "."
    local_path = Path(local_path)
    # Strip the slashes from the beginning so the relative_to logic works
    remote_path = str(remote_path).lstrip("/")
    if not remote_path:
        remote_path = "/"
    # For storage paths get rid of the colon in the beginning of the schema, the download urls won't have it either
    if remote_path.split("/")[0] in {"s3:", "gs:", "azure:"}:
        remote_path = remote_path.replace(":", "", 1)
    # Edge case - if the user requested a single file - different output path semantics
    if len(files) == 1 and files[0].path == remote_path:
        f = files[0]
        remote_path = PurePosixPath(f.path)
        # If local_path was specified, assume that the local_path is the exact name of the file
        if local_path != Path("."):
            # Saving to existing dir - append the name of remote file to the end a-la cp
            if local_path.exists() and local_path.is_dir():
                remote_path = remote_path if keep_source_prefix else remote_path.name
                file_path = local_path / remote_path
            else:
                # local_path names a file that doesn't exist (yet) - use it verbatim
                file_path = local_path
        else:
            # No local_path given - download into the CWD under the remote name
            file_path = remote_path if keep_source_prefix else remote_path.name
        file_tuples.append((f.download_url, file_path))
    else:
        # Directory download: map each remote file path to a local path,
        # optionally stripping the requested remote_path prefix.
        for f in files:
            file_path_in_remote = PurePosixPath(f.path)
            remote_path_obj = PurePosixPath(remote_path)
            if not keep_source_prefix and remote_path != "/":
                file_path = file_path_in_remote.relative_to(remote_path_obj)
            else:
                file_path = file_path_in_remote
            file_path = local_path / file_path
            file_tuples.append((f.download_url, file_path))
    # skip_if_exists is a pure existence check - no hash comparison is performed
    download_files(file_tuples, skip_if_exists=not redownload)
    log_message(f"Downloaded {len(files)} file(s) to {local_path.resolve()}")

@cached_property
def default_branch(self) -> str:
"""
Expand Down
50 changes: 50 additions & 0 deletions dagshub/common/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import dagshub.common.logging_util
from dagshub import init, __version__
from dagshub.common import config, rich_console
from dagshub.common.api.repo import RepoAPI
from dagshub.upload import create_repo, Repo
from dagshub.common.helpers import http_request, log_message
from dagshub.upload.errors import UpdateNotAllowedError
Expand Down Expand Up @@ -118,6 +119,55 @@ def to_log_level(verbosity):
return logging.DEBUG


# Help text for the CLI --keep-prefix flag; mirrors the keep_source_prefix
# argument documented on RepoAPI.download.
KEEP_PREFIX_HELP = """ Whether to keep the path of the folder in the download path or not.
Example: Given remote_path "src/data" and file "test/file.txt"
if True: will download to "<local_path>/src/data/test/file.txt"
if False: will download to "<local_path>/test/file.txt"
"""


@cli.command()
@click.argument("repo", callback=validate_repo)
@click.argument("remote_path")
@click.argument("local_path", required=False, type=click.Path())
@click.option(
    "-b", "--branch", help="Branch or revision to download from. If left unspecified, use the default branch."
)
@click.option("--keep-prefix", is_flag=True, default=False, help=KEEP_PREFIX_HELP)
@click.option("--not-recursive", is_flag=True, help="Don't download nested folders")
@click.option("--redownload", is_flag=True, help="Redownload files, even if they already exist locally")
@click.option("-v", "--verbose", default=0, count=True, help="Verbosity level")
@click.option("-q", "--quiet", is_flag=True, help="Suppress print output")
# Fixed copy-paste from the login command: this option selects the host to download from.
@click.option("--host", help="DagsHub instance the repo is hosted on")
@click.pass_context
def download(ctx, repo, remote_path, local_path, branch, not_recursive, keep_prefix, verbose, quiet, host, redownload):
    """
    Download REMOTE_PATH from REPO to LOCAL_PATH

    REMOTE_PATH can be either directory or a file

    If LOCAL_PATH is left blank, downloads to the current directory

    Example:
        dagshub download nirbarazida/CheXNet data_labeling/data ./data
    """
    # Fall back to the group-level --host/--quiet settings from the CLI context
    host = host or ctx.obj["host"]
    config.quiet = quiet or ctx.obj["quiet"]

    logger = logging.getLogger()
    logger.setLevel(to_log_level(verbose))

    # repo is (owner, name) after the validate_repo callback
    repo_api = RepoAPI(f"{repo[0]}/{repo[1]}", host=host)
    repo_api.download(
        remote_path,
        local_path,
        revision=branch,
        recursive=not not_recursive,
        keep_source_prefix=keep_prefix,
        redownload=redownload,
    )


@cli.command()
@click.argument("repo", callback=validate_repo)
@click.argument("filename", type=click.Path(exists=True))
Expand Down