diff --git a/.gitignore b/.gitignore index 3d4c7df..f068c87 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__/ *.egg-info/ logs/ None/ +dist/ .pytest_cache/ .DS_Store loco_mujoco/datasets/ diff --git a/dist/loco-mujoco-0.1.tar.gz b/dist/loco-mujoco-0.1.tar.gz deleted file mode 100644 index 88c836f..0000000 Binary files a/dist/loco-mujoco-0.1.tar.gz and /dev/null differ diff --git a/download_datasets.py b/download_datasets.py deleted file mode 100644 index 22c427d..0000000 --- a/download_datasets.py +++ /dev/null @@ -1,3 +0,0 @@ -import loco_mujoco - -loco_mujoco.download_all_datasets() diff --git a/loco_mujoco/__init__.py b/loco_mujoco/__init__.py index 877e34e..2d0a9ec 100644 --- a/loco_mujoco/__init__.py +++ b/loco_mujoco/__init__.py @@ -1,16 +1,11 @@ __version__ = '0.1' - -from .environments import LocoEnv - try: + from .environments import LocoEnv + def get_all_task_names(): return LocoEnv.get_all_task_names() - - def download_all_datasets(): - return LocoEnv.download_all_datasets() - -except ImportError: - pass +except ImportError as e: + print(e) diff --git a/loco_mujoco/environments/base.py b/loco_mujoco/environments/base.py index 901f6c3..02c63e2 100644 --- a/loco_mujoco/environments/base.py +++ b/loco_mujoco/environments/base.py @@ -1,7 +1,5 @@ import os -import wget -import zipfile import warnings from pathlib import Path from copy import deepcopy @@ -953,36 +951,6 @@ def get_all_task_names(cls): _registered_envs = dict() - @classmethod - def download_all_datasets(cls): - """ - Download and installs all datasets. - - """ - dataset_path = Path(loco_mujoco.__file__).resolve().parent / "datasets" - - print("Downloading Humanoid Datasets ...\n") - dataset_path_humanoid = dataset_path / "humanoids/real" - dataset_path_humanoid_str = str(dataset_path_humanoid) - humanoid_url = "https://zenodo.org/records/10102870/files/humanoid_datasets_v0.1.zip?download=1" - wget.download(humanoid_url, out=dataset_path_humanoid_str) - file_name = "humanoid_datasets_v0.1.zip" - file_path = str(dataset_path_humanoid / file_name) - with zipfile.ZipFile(file_path, "r") as zip_ref: - zip_ref.extractall(dataset_path_humanoid_str) - os.remove(file_path) - - print("Downloading Quadruped Datasets ...\n") - dataset_path_quadrupeds = dataset_path / "quadrupeds/real" - dataset_path_quadrupeds_str = str(dataset_path_quadrupeds) - quadruped_url = "https://zenodo.org/records/10102870/files/quadruped_datasets_v0.1.zip?download=1" - wget.download(quadruped_url, out=dataset_path_quadrupeds_str) - file_name = "quadruped_datasets_v0.1.zip" - file_path = str(dataset_path_quadrupeds / file_name) - with zipfile.ZipFile(file_path, "r") as zip_ref: - zip_ref.extractall(dataset_path_quadrupeds_str) - os.remove(file_path) - class ValidTaskConf: diff --git a/loco_mujoco/utils/__init__.py b/loco_mujoco/utils/__init__.py index 3ed8891..ca1f42e 100644 --- a/loco_mujoco/utils/__init__.py +++ b/loco_mujoco/utils/__init__.py @@ -3,3 +3,4 @@ from .checks import * from .video import video2gif from .domain_randomization import * +from .dataset import download_all_datasets, download_real_datasets, download_perfect_dataset diff --git a/loco_mujoco/utils/dataset.py b/loco_mujoco/utils/dataset.py new file mode 100644 index 0000000..43a232b --- /dev/null +++ b/loco_mujoco/utils/dataset.py @@ -0,0 +1,82 @@ +import os +import wget +import zipfile +from pathlib import Path +import loco_mujoco + + +def download_all_datasets(): + """ + Download and installs all datasets. + + """ + download_real_datasets() + download_perfect_dataset() + + +def download_real_datasets(): + """ + Download and installs real datasets. + + """ + + dataset_path = Path(loco_mujoco.__file__).resolve().parent / "datasets" + print(dataset_path) + + print("Downloading Humanoid Datasets ...\n") + dataset_path_humanoid = dataset_path / "humanoids/real" + dataset_path_humanoid_str = str(dataset_path_humanoid) + os.makedirs(dataset_path_humanoid_str, exist_ok=True) + humanoid_url = "https://zenodo.org/records/10102870/files/humanoid_datasets_v0.1.zip?download=1" + wget.download(humanoid_url, out=dataset_path_humanoid_str) + file_name = "humanoid_datasets_v0.1.zip" + file_path = str(dataset_path_humanoid / file_name) + with zipfile.ZipFile(file_path, "r") as zip_ref: + zip_ref.extractall(dataset_path_humanoid_str) + os.remove(file_path) + + print("Downloading Quadruped Datasets ...\n") + dataset_path_quadrupeds = dataset_path / "quadrupeds/real" + dataset_path_quadrupeds_str = str(dataset_path_quadrupeds) + os.makedirs(dataset_path_quadrupeds_str, exist_ok=True) + quadruped_url = "https://zenodo.org/records/10102870/files/quadruped_datasets_v0.1.zip?download=1" + wget.download(quadruped_url, out=dataset_path_quadrupeds_str) + file_name = "quadruped_datasets_v0.1.zip" + file_path = str(dataset_path_quadrupeds / file_name) + with zipfile.ZipFile(file_path, "r") as zip_ref: + zip_ref.extractall(dataset_path_quadrupeds_str) + os.remove(file_path) + + +def download_perfect_dataset(): + """ + Download and installs perfect datasets. + + """ + dataset_path = Path(loco_mujoco.__file__).resolve().parent / "datasets" + + print("Downloading Perfect Humanoid Datasets ...\n") + dataset_path_humanoid = dataset_path / "humanoids/perfect" + dataset_path_humanoid_str = str(dataset_path_humanoid) + os.makedirs(dataset_path_humanoid_str, exist_ok=True) + humanoid_url = "https://zenodo.org/records/10102870/files/humanoid_datasets_v0.1.zip?download=1" + wget.download(humanoid_url, out=dataset_path_humanoid_str) + file_name = "humanoid_datasets_v0.1.zip" + file_path = str(dataset_path_humanoid / file_name) + with zipfile.ZipFile(file_path, "r") as zip_ref: + zip_ref.extractall(dataset_path_humanoid_str) + os.remove(file_path) + + + print("Downloading Perfect Quadruped Datasets ...\n") + dataset_path_quadrupeds = dataset_path / "quadrupeds/perfect" + dataset_path_quadrupeds_str = str(dataset_path_quadrupeds) + os.makedirs(dataset_path_quadrupeds_str, exist_ok=True) + quadruped_url = "https://zenodo.org/records/10102870/files/quadruped_datasets_v0.1.zip?download=1" + wget.download(quadruped_url, out=dataset_path_quadrupeds_str) + file_name = "quadruped_datasets_v0.1.zip" + file_path = str(dataset_path_quadrupeds / file_name) + with zipfile.ZipFile(file_path, "r") as zip_ref: + zip_ref.extractall(dataset_path_quadrupeds_str) + os.remove(file_path) + diff --git a/pyproject.toml b/pyproject.toml index 1450994..a000edd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools", "wheel"] +requires = ["setuptools", "wheel", "wget"] [project] name = "loco-mujoco" @@ -35,4 +35,6 @@ Repository = "https://github.com/robfiras/loco-mujoco" Issues = "https://github.com/robfiras/loco-mujoco/issues" [project.scripts] -loco-mujoco-download = "loco_mujoco:download_all_datasets" \ No newline at end of file +loco-mujoco-download = "loco_mujoco.utils:download_all_datasets" +loco-mujoco-download-real = "loco_mujoco.utils:download_real_datasets" +loco-mujoco-download-perfect = "loco_mujoco.utils:download_perfect_datasets" \ No newline at end of file diff --git a/setup.py b/setup.py index 88d775e..0836eae 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,4 @@ from setuptools import setup, find_packages -from os import path import glob from loco_mujoco import __version__