diff --git a/milatools/datasets/torch.py b/milatools/datasets/torch.py index 6d396d89..a910ff1d 100644 --- a/milatools/datasets/torch.py +++ b/milatools/datasets/torch.py @@ -15,13 +15,14 @@ def fetch_imagenet(local_directory=None): train_directory = os.path.join(dataset_home, "train") validation_directory = os.path.join(dataset_home, "val") - subprocess.run(f"mkdir -p {train_directory}/ {validation_directory}/") - subprocess.run(f"tar -xf /network/datasets/imagenet/ILSVRC2012_img_train.tar -C {train_directory}") + subprocess.run(f"mkdir -p {train_directory}/ {validation_directory}/", shell=True) + subprocess.run(f"tar -xf /network/datasets/imagenet/ILSVRC2012_img_train.tar -C {train_directory}", shell=True) p = subprocess.Popen(['cp', '-r', f'/network/datasets/imagenet.var/imagenet_torchvision/val {dataset_home}/']) subprocess.run( - 'find ' + train_directory + ' -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done', shell=True - ) + 'find ' + train_directory + + ' -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done', + shell=True) p.wait() return train_directory, validation_directory else: @@ -49,8 +50,8 @@ def __init__(self, local_directory: str = None, *dataset_args, **dataset_kwargs) if milatools.running_on_mila_cluster: local_directory = os.path.join(os.environ["SLURM_TMPDIR"], "MNIST") mnist_path = DATASET_PATH.format("mnist", "mnist") - subprocess.run(f"mkdir -p {local_directory}/") - subprocess.run(f"tar -xf {mnist_path} -C {mnist_path}") + subprocess.run("mkdir -p {local_directory}/", shell=True) + subprocess.run(f"tar -xf {mnist_path} -C {mnist_path}", shell=True) super().__init__(local_directory, *dataset_args, **dataset_kwargs)