diff --git a/b2/console_tool.py b/b2/console_tool.py index bdf570be2..2f99f56c7 100644 --- a/b2/console_tool.py +++ b/b2/console_tool.py @@ -11,6 +11,8 @@ ###################################################################### from __future__ import annotations +import tempfile + from b2._cli.autocomplete_cache import AUTOCOMPLETE # noqa AUTOCOMPLETE.autocomplete_from_cache() @@ -1606,10 +1608,50 @@ def _represent_legal_hold(cls, legal_hold: LegalHold): def _print_file_attribute(self, label, value): self._print((label + ':').ljust(20) + ' ' + value) - def get_local_output_filepath(self, filename: str) -> pathlib.Path: + def get_local_output_filepath( + self, filename: str, file_request: DownloadedFile + ) -> pathlib.Path: if filename == '-': return STDOUT_FILEPATH - return pathlib.Path(filename) + + output_filepath = pathlib.Path(filename) + + # As longs as it's not a directory, we're overwriting everything. + if not output_filepath.is_dir(): + return output_filepath + + # If the output is directory, we're expected to download the file right there. + # Normally, we overwrite the target without asking any questions, but in this case + # user might be oblivious of the actual mistake he's about to commit. + # If he, e.g.: downloads file by ID, he might not know the name of the file + # and actually overwrite something unintended. + output_directory = output_filepath + output_filepath = output_directory / file_request.download_version.file_name + # If it doesn't exist, we stop worrying. + if not output_filepath.exists(): + return output_filepath + + # If it does exist, we make a unique file prefixed with the actual file name. + file_name_as_path = pathlib.Path(file_request.download_version.file_name) + file_name = file_name_as_path.stem + file_extension = file_name_as_path.suffix + + # Default permissions are: readable and writable by this user only, executable by noone. + # This "temporary" file is not automatically removed, but still created in the safest way possible. + fd_handle, output_filepath_str = tempfile.mkstemp( + prefix=file_name, + suffix=file_extension, + dir=output_directory, + ) + # Close the handle, so the file is not locked. + # This file is no longer 100% "safe", but that's acceptable. + os.close(fd_handle) + + # "Normal" file created by Python has readable for everyone, writable for user only. + # We change the permissions, to match the default ones. + os.chmod(output_filepath_str, 0o644) + + return pathlib.Path(output_filepath_str) class DownloadFileBase( @@ -1645,7 +1687,7 @@ def _run(self, args): ) self._print_download_info(downloaded_file) - output_filepath = self.get_local_output_filepath(args.localFileName) + output_filepath = self.get_local_output_filepath(args.localFileName, downloaded_file) downloaded_file.save_to(output_filepath) self._print('Download finished') @@ -1711,7 +1753,7 @@ def _run(self, args): file_request = self.api.download_file_by_uri( args.B2_URI, progress_listener=progress_listener, encryption=encryption_setting ) - output_filepath = self.get_local_output_filepath(target_filename) + output_filepath = self.get_local_output_filepath(target_filename, file_request) file_request.save_to(output_filepath) return 0 diff --git a/changelog.d/+downloading_to_directory.added.md b/changelog.d/+downloading_to_directory.added.md new file mode 100644 index 000000000..b21c2124b --- /dev/null +++ b/changelog.d/+downloading_to_directory.added.md @@ -0,0 +1 @@ +Whenever target filename is a directory, file is downloaded into that directory. diff --git a/test/integration/test_b2_command_line.py b/test/integration/test_b2_command_line.py index 2dce5e1e3..88050ad55 100755 --- a/test/integration/test_b2_command_line.py +++ b/test/integration/test_b2_command_line.py @@ -17,6 +17,7 @@ import json import os import os.path +import pathlib import re import sys import time @@ -2682,6 +2683,46 @@ def test_download_file_stdout( ).replace("\r", "") == sample_filepath.read_text() +def test_download_file_to_directory( + b2_tool, bucket_name, sample_filepath, tmp_path, uploaded_sample_file +): + downloads_directory = 'downloads' + target_directory = tmp_path / downloads_directory + target_directory.mkdir() + filename_as_path = pathlib.Path(uploaded_sample_file['fileName']) + + sample_file_content = sample_filepath.read_text() + b2_tool.should_succeed( + [ + 'download-file', + '--quiet', + f"b2://{bucket_name}/{uploaded_sample_file['fileName']}", + str(target_directory), + ], + ) + downloaded_file = target_directory / filename_as_path + assert downloaded_file.read_text() == sample_file_content, \ + f'{downloaded_file}, {downloaded_file.read_text()}, {sample_file_content}' + + b2_tool.should_succeed( + [ + 'download-file', + '--quiet', + f"b2id://{uploaded_sample_file['fileId']}", + str(target_directory), + ], + ) + # A second file should be created. + new_files = [ + filepath + for filepath in target_directory.glob(f'{filename_as_path.stem}*{filename_as_path.suffix}') + if filepath.name != filename_as_path.name + ] + assert len(new_files) == 1, f'{new_files}' + assert new_files[0].read_text() == sample_file_content, \ + f'{new_files}, {new_files[0].read_text()}, {sample_file_content}' + + def test_cat(b2_tool, bucket_name, sample_filepath, tmp_path, uploaded_sample_file): assert b2_tool.should_succeed( ['cat', f"b2://{bucket_name}/{uploaded_sample_file['fileName']}"], diff --git a/test/unit/test_console_tool.py b/test/unit/test_console_tool.py index 8593797e0..fdfc3f4a5 100644 --- a/test/unit/test_console_tool.py +++ b/test/unit/test_console_tool.py @@ -1117,6 +1117,56 @@ def test_download_by_name_1_thread(self): def test_download_by_name_10_threads(self): self._test_download_threads(download_by='name', num_threads=10) + def _test_download_to_directory(self, download_by: str): + self._authorize_account() + self._create_my_bucket() + + base_filename = 'file' + extension = '.txt' + source_filename = f'{base_filename}{extension}' + + with TempDir() as temp_dir: + local_file = self._make_local_file(temp_dir, source_filename) + local_file_content = self._read_file(local_file) + + self._run_command( + ['upload-file', '--noProgress', 'my-bucket', local_file, source_filename], + remove_version=True, + ) + + b2uri = f'b2://my-bucket/{source_filename}' if download_by == 'name' else 'b2id://9999' + command = [ + 'download-file', + '--noProgress', + b2uri, + ] + + target_directory = os.path.join(temp_dir, 'target') + os.mkdir(target_directory) + command += [target_directory] + self._run_command(command) + self.assertEqual( + local_file_content, + self._read_file(os.path.join(target_directory, source_filename)) + ) + + # Download the file second time, to check the override behavior. + self._run_command(command) + # We should get another file. + target_directory_files = [ + elem + for elem in pathlib.Path(target_directory).glob(f'{base_filename}*{extension}') + if elem.name != source_filename + ] + assert len(target_directory_files) == 1, f'{target_directory_files}' + self.assertEqual(local_file_content, self._read_file(target_directory_files[0])) + + def test_download_by_id_to_directory(self): + self._test_download_to_directory(download_by='id') + + def test_download_by_name_to_directory(self): + self._test_download_to_directory(download_by='name') + def test_copy_file_by_id(self): self._authorize_account() self._create_my_bucket()