-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 9315a24
Showing
82 changed files
with
11,973 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Ignore compiled files | ||
*.class | ||
*.o | ||
*.pyc | ||
|
||
# Ignore build output | ||
build/ | ||
dist/ | ||
|
||
*~ | ||
.hydra | ||
*pyc | ||
*.egg-info/ | ||
datasets | ||
datasets_scripted | ||
*videos | ||
*.mp4 | ||
results | ||
outputs/* | ||
/MUJOCO_LOG.TXT | ||
wandb | ||
experiments/ | ||
experiments_clip/* | ||
evaluations_clip/* | ||
experiments_finetune/* | ||
eval_outs | ||
pace_models/* | ||
*.zip | ||
*.out | ||
*.hdf5 | ||
*.pth | ||
experiments_saved/ | ||
ppo_experiments/ | ||
slurm/ | ||
# Ignore IDE and editor files | ||
.idea/ | ||
.vscode/ | ||
*.sublime-project | ||
*.sublime-workspace | ||
|
||
# Ignore dependencies | ||
node_modules/ | ||
vendor/ | ||
|
||
# Ignore logs and temporary files | ||
*.log | ||
*.tmp | ||
|
||
# Ignore sensitive information | ||
config.ini | ||
secrets.json | ||
|
||
# Ignore generated documentation | ||
docs/ | ||
|
||
# Ignore OS-specific files | ||
.DS_Store | ||
Thumbs.db | ||
|
||
data | ||
experiments | ||
zplots |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2024 Atharva Mete | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# QueST: Self-Supervised Skill Abstractions for Continuous Control | ||
|
||
Atharva Mete, Haotian Xue, Albert Wilcox, Yongxin Chen, Animesh Garg | ||
|
||
[![Static Badge](https://img.shields.io/badge/Project-Page-green?style=for-the-badge)](https://quest-model.github.io/) | ||
[![arXiv](https://img.shields.io/badge/arXiv-2406.09246-df2a2a.svg?style=for-the-badge)](https://arxiv.org/abs/2407.15840) | ||
[![PyTorch](https://img.shields.io/badge/PyTorch-2.2.0-EE4C2C.svg?style=for-the-badge&logo=pytorch)](https://pytorch.org/get-started/locally/) | ||
[![Python](https://img.shields.io/badge/python-3.10-blue?style=for-the-badge)](https://www.python.org) | ||
[![License](https://img.shields.io/github/license/TRI-ML/prismatic-vlms?style=for-the-badge)](LICENSE) | ||
|
||
[**Installation**](#installation) | [**Dataset Download**](#dataset-download) | [**Training**](#training) | [**Evaluation**](#evaluating) | [**Project Website**](https://quest-model.github.io/) | ||
|
||
|
||
<hr style="border: 2px solid gray;"></hr> | ||
|
||
## Latest Updates | ||
- [2024-09-09] Initial release | ||
|
||
<hr style="border: 2px solid gray;"></hr> | ||
|
||
## Installation | ||
|
||
Please run the following commands in the given order to install the dependency for QueST | ||
``` | ||
conda create -n quest python=3.10.14 | ||
conda activate quest | ||
git clone https://github.com/atharvamete/QueST.git | ||
cd quest | ||
python -m pip install torch==2.2.0 torchvision==0.17.0 | ||
python -m pip install -e . | ||
``` | ||
Note: Above automatically installs metaworld as python packages | ||
|
||
Install LIBERO seperately | ||
``` | ||
git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git | ||
cd LIBERO | ||
python -m pip install -e . | ||
``` | ||
Note: All LIBERO dependencies are already included in quest/requirements.txt | ||
|
||
## Dataset Download | ||
LIBERO: Please download the libero data seperately following their [docs](https://lifelong-robot-learning.github.io/LIBERO/html/algo_data/datasets.html#datasets). | ||
|
||
MetaWorld: We have provided the script we used to collect the data using scripted policies in the MetaWorld package. Please run the following command to collect the data. This uses configs as per [collect_data.yaml](config/collect_data.yaml). | ||
``` | ||
python scripts/generate_metaworld_dataset.py | ||
``` | ||
We generate 100 demonstrations for each of 45 pretraining tasks and 5 for downstream tasks. | ||
|
||
## Training | ||
First set the path to the dataset `data_prefix` and `output_prefix` in [train_base](config/train_base.yaml). `output_prefix` is where all the logs and checkpoints will be stored. | ||
|
||
We provide detailed sample commands for training all stages and for all baselines in the [scripts](scripts) directory. For all methods, [autoencoder.sh](scripts/quest/autoencoder.sh) trains the autoencoder (only used in QueST and VQ-BeT), [main.sh](scripts/quest/main.sh) trains the main algorithm (skill-prior incase of QueST), and [finetune.sh](scripts/quest/finetune.sh) finetunes the model on downstream tasks. | ||
|
||
Run the following command to train QueST's stage-0 i.e. the autoencoder. (ref: [autoencoder.sh](scripts/quest/autoencoder.sh)) | ||
``` | ||
python train.py --config-name=train_autoencoder.yaml \ | ||
task=libero_90 \ | ||
algo=quest \ | ||
exp_name=final \ | ||
variant_name=block_32_ds_4 \ | ||
algo.skill_block_size=32 \ | ||
algo.downsample_factor=4 \ | ||
seed=0 | ||
``` | ||
The above command trains the autoencoder on the libero-90 dataset with a block size of 32 and a downsample factor of 4. The run directory will be created at `<output_prefix>/<benchmark_name>/<task>/<algo>/<exp_name>/<variant_name>/<seed>/<stage>`. For above command, it will be `./experiments/libero/libero_90/quest/final/block_32_ds_4/0/stage_0`. | ||
|
||
Run the following command to train QueST's stage-1 i.e. the skill-prior. (ref: [main.sh](scripts/quest/main.sh)) | ||
``` | ||
python train.py --config-name=train_prior.yaml \ | ||
task=libero_90 \ | ||
algo=quest \ | ||
exp_name=final \ | ||
variant_name=block_32_ds_4 \ | ||
algo.skill_block_size=32 \ | ||
algo.downsample_factor=4 \ | ||
training.auto_continue=true \ | ||
seed=0 | ||
``` | ||
Here, training.auto_continue will automatically load the latest checkpoint from the previous training stage. | ||
|
||
Run the following command to finetune QueST on a downstream tasks. (ref: [finetune.sh](scripts/quest/finetune.sh)) | ||
``` | ||
python train.py --config-name=train_fewshot.yaml \ | ||
task=libero_long \ | ||
algo=quest \ | ||
exp_name=final \ | ||
variant_name=block_32_ds_4 \ | ||
algo.skill_block_size=32 \ | ||
algo.downsample_factor=4 \ | ||
algo.l1_loss_scale=10 \ | ||
training.auto_continue=true \ | ||
seed=0 | ||
``` | ||
Here, algo.l1_loss_scale is used to finetune the decoder of the autoencoder while finetuning. | ||
|
||
## Evaluating | ||
Run the following command to evaluate the trained model. (ref: [eval.sh](scripts/eval.sh)) | ||
``` | ||
python evaluate.py \ | ||
task=libero_90 \ | ||
algo=quest \ | ||
exp_name=final \ | ||
variant_name=block_32_ds_4 \ | ||
stage=1 \ | ||
training.use_tqdm=false \ | ||
seed=0 | ||
``` | ||
This will automatically load the latest checkpoint as per your exp_name, variant_name, algo, and stage. Else you can specify the checkpoint_path to load a specific checkpoint. | ||
|
||
## Citation | ||
If you find this work useful, please consider citing: | ||
``` | ||
@misc{mete2024questselfsupervisedskillabstractions, | ||
title={QueST: Self-Supervised Skill Abstractions for Learning Continuous Control}, | ||
author={Atharva Mete and Haotian Xue and Albert Wilcox and Yongxin Chen and Animesh Garg}, | ||
year={2024}, | ||
eprint={2407.15840}, | ||
archivePrefix={arXiv}, | ||
primaryClass={cs.RO}, | ||
url={https://arxiv.org/abs/2407.15840}, | ||
} | ||
``` | ||
|
||
## License | ||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. | ||
|
||
<hr style="border: 2px solid gray;"></hr> | ||
|
||
## Acknowledgements | ||
1. We would like to thank the authors of [LIBERO](https://lifelong-robot-learning.github.io/LIBERO/) and [MetaWorld](https://meta-world.github.io/) for providing the datasets and environments for our experiments. | ||
2. We would also like to thank the authors of our baselines [VQ-BeT](https://github.com/jayLEE0301/vq_bet_official), [ACT](), and [Diffusion Policy]() for providing the codebase for their methods; and the authors of [Robomimic]() from which we adapted the utility files for our codebase. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
defaults: | ||
- base | ||
- _self_ | ||
|
||
policy: | ||
_target_: quest.algos.act.ACT | ||
act_model: | ||
_target_: quest.algos.baseline_modules.act_utils.detr_vae.DETRVAE | ||
transformer: | ||
_target_: quest.algos.baseline_modules.act_utils.transformer.build_transformer | ||
hidden_dim: ${algo.embed_dim} | ||
dropout: 0.1 | ||
nheads: 8 | ||
dim_feedforward: ${eval:'${algo.embed_dim} * 4'} | ||
enc_layers: 4 | ||
dec_layers: 7 | ||
pre_norm: false | ||
encoder: | ||
_target_: quest.algos.baseline_modules.act_utils.detr_vae.build_encoder | ||
d_model: ${algo.embed_dim} | ||
nheads: 8 | ||
dim_feedforward: ${eval:'${algo.embed_dim} * 4'} | ||
enc_layers: 4 | ||
pre_norm: false | ||
dropout: 0.1 | ||
state_dim: ${task.shape_meta.action_dim} | ||
proprio_dim: 8 | ||
shape_meta: ${task.shape_meta} | ||
num_queries: ${algo.skill_block_size} | ||
loss_fn: | ||
_target_: torch.nn.L1Loss | ||
kl_weight: ${algo.kl_weight} | ||
lr_backbone: ${algo.lr} | ||
action_horizon: ${algo.action_horizon} | ||
obs_reduction: 'none' | ||
|
||
name: act | ||
|
||
lr: 0.0001 | ||
weight_decay: 0.0001 | ||
|
||
kl_weight: 10.0 | ||
embed_dim: 256 | ||
action_horizon: 2 | ||
|
||
skill_block_size: 16 # this is output action sequence length, for ACT 16 works better than 32 | ||
frame_stack: 1 # this is input observation sequence length | ||
|
||
dataset: | ||
seq_len: ${algo.skill_block_size} | ||
frame_stack: ${algo.frame_stack} | ||
obs_seq_len: 1 | ||
lowdim_obs_seq_len: null | ||
load_obs_for_pretrain: true | ||
load_next_obs: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
_target_: quest.algos.utils.data_augmentation.IdentityAug | ||
_partial_: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
_target_: quest.algos.utils.data_augmentation.DataAugGroup | ||
_partial_: true | ||
aug_list: | ||
- _target_: quest.algos.utils.data_augmentation.BatchWiseImgColorJitterAug | ||
_partial_: true | ||
brightness: 0.3 | ||
contrast: 0.3 | ||
saturation: 0.3 | ||
hue: 0.3 | ||
epsilon: 0.1 | ||
- _target_: quest.algos.utils.data_augmentation.TranslationAug | ||
_partial_: true | ||
translation: 4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
defaults: | ||
- encoder/lowdim: mlp | ||
- encoder/image: resnet | ||
- aug: image | ||
|
||
embed_dim: 256 | ||
lowdim_embed_dim: ${algo.embed_dim} | ||
image_embed_dim: ${algo.embed_dim} | ||
|
||
policy: | ||
image_encoder_factory: ${algo.encoder.image} | ||
lowdim_encoder_factory: ${algo.encoder.lowdim} | ||
aug_factory: ${algo.aug} | ||
optimizer_factory: | ||
_target_: torch.optim.AdamW | ||
_partial_: true | ||
lr: ${algo.lr} | ||
betas: [0.9, 0.999] | ||
weight_decay: ${algo.weight_decay} | ||
scheduler_factory: | ||
_target_: torch.optim.lr_scheduler.CosineAnnealingLR | ||
_partial_: true | ||
eta_min: 1e-5 | ||
last_epoch: -1 | ||
T_max: ${training.n_epochs} | ||
embed_dim: ${algo.embed_dim} | ||
shape_meta: ${task.shape_meta} | ||
obs_reduction: 'none' | ||
device: ${device} | ||
|
||
dataset: | ||
lowdim_obs_seq_len: ${algo.skill_block_size} | ||
load_next_obs: false | ||
dataset_keys: [actions] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
defaults: | ||
- base | ||
- _self_ | ||
|
||
policy: | ||
_target_: quest.algos.bc_transformer.BCTransformerPolicy | ||
transformer_model: | ||
_target_: quest.algos.baseline_modules.bc_transformer_modules.TransformerDecoder | ||
input_size: ${algo.embed_dim} | ||
num_layers: 4 | ||
num_heads: 6 | ||
head_output_size: 64 | ||
mlp_hidden_size: 256 | ||
dropout: 0.1 | ||
policy_head: | ||
_target_: quest.algos.baseline_modules.bc_transformer_modules.GMMHead | ||
input_size: ${algo.embed_dim} | ||
output_size: ${task.shape_meta.action_dim} | ||
hidden_size: 1024 | ||
num_layers: 2 | ||
min_std: 0.0001 | ||
num_modes: 5 | ||
low_eval_noise: false | ||
activation: "softplus" | ||
loss_coef: 1.0 | ||
positional_encoding: | ||
_target_: quest.algos.baseline_modules.bc_transformer_modules.SinusoidalPositionEncoding | ||
input_size: ${algo.embed_dim} | ||
inv_freq_factor: 10 | ||
loss_reduction: 'mean' | ||
obs_reduction: stack | ||
device: ${device} | ||
|
||
name: bc_transformer_policy | ||
|
||
lr: 0.0001 | ||
weight_decay: 0.0001 | ||
|
||
embed_dim: 128 | ||
skill_block_size: 1 # bc_transformer does not do action chunking | ||
frame_stack: 10 # this is input observation sequence length | ||
|
||
dataset: | ||
seq_len: ${algo.skill_block_size} | ||
frame_stack: ${algo.frame_stack} | ||
obs_seq_len: 1 | ||
lowdim_obs_seq_len: null | ||
load_obs_for_pretrain: false | ||
load_next_obs: false |
Oops, something went wrong.