Skip to content

Commit

Permalink
QueST Release
Browse files Browse the repository at this point in the history
  • Loading branch information
atharvamete committed Sep 15, 2024
0 parents commit 9315a24
Show file tree
Hide file tree
Showing 82 changed files with 11,973 additions and 0 deletions.
62 changes: 62 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Ignore compiled files
*.class
*.o
*.pyc

# Ignore build output
build/
dist/

*~
.hydra
*pyc
*.egg-info/
datasets
datasets_scripted
*videos
*.mp4
results
outputs/*
/MUJOCO_LOG.TXT
wandb
experiments/
experiments_clip/*
evaluations_clip/*
experiments_finetune/*
eval_outs
pace_models/*
*.zip
*.out
*.hdf5
*.pth
experiments_saved/
ppo_experiments/
slurm/
# Ignore IDE and editor files
.idea/
.vscode/
*.sublime-project
*.sublime-workspace

# Ignore dependencies
node_modules/
vendor/

# Ignore logs and temporary files
*.log
*.tmp

# Ignore sensitive information
config.ini
secrets.json

# Ignore generated documentation
docs/

# Ignore OS-specific files
.DS_Store
Thumbs.db

data
experiments
zplots
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Atharva Mete

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
133 changes: 133 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# QueST: Self-Supervised Skill Abstractions for Continuous Control

Atharva Mete, Haotian Xue, Albert Wilcox, Yongxin Chen, Animesh Garg

[![Static Badge](https://img.shields.io/badge/Project-Page-green?style=for-the-badge)](https://quest-model.github.io/)
[![arXiv](https://img.shields.io/badge/arXiv-2406.09246-df2a2a.svg?style=for-the-badge)](https://arxiv.org/abs/2407.15840)
[![PyTorch](https://img.shields.io/badge/PyTorch-2.2.0-EE4C2C.svg?style=for-the-badge&logo=pytorch)](https://pytorch.org/get-started/locally/)
[![Python](https://img.shields.io/badge/python-3.10-blue?style=for-the-badge)](https://www.python.org)
[![License](https://img.shields.io/github/license/TRI-ML/prismatic-vlms?style=for-the-badge)](LICENSE)

[**Installation**](#installation) | [**Dataset Download**](#dataset-download) | [**Training**](#training) | [**Evaluation**](#evaluating) | [**Project Website**](https://quest-model.github.io/)


<hr style="border: 2px solid gray;"></hr>

## Latest Updates
- [2024-09-09] Initial release

<hr style="border: 2px solid gray;"></hr>

## Installation

Please run the following commands in the given order to install the dependency for QueST
```
conda create -n quest python=3.10.14
conda activate quest
git clone https://github.com/atharvamete/QueST.git
cd quest
python -m pip install torch==2.2.0 torchvision==0.17.0
python -m pip install -e .
```
Note: Above automatically installs metaworld as python packages

Install LIBERO seperately
```
git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git
cd LIBERO
python -m pip install -e .
```
Note: All LIBERO dependencies are already included in quest/requirements.txt

## Dataset Download
LIBERO: Please download the libero data seperately following their [docs](https://lifelong-robot-learning.github.io/LIBERO/html/algo_data/datasets.html#datasets).

MetaWorld: We have provided the script we used to collect the data using scripted policies in the MetaWorld package. Please run the following command to collect the data. This uses configs as per [collect_data.yaml](config/collect_data.yaml).
```
python scripts/generate_metaworld_dataset.py
```
We generate 100 demonstrations for each of 45 pretraining tasks and 5 for downstream tasks.

## Training
First set the path to the dataset `data_prefix` and `output_prefix` in [train_base](config/train_base.yaml). `output_prefix` is where all the logs and checkpoints will be stored.

We provide detailed sample commands for training all stages and for all baselines in the [scripts](scripts) directory. For all methods, [autoencoder.sh](scripts/quest/autoencoder.sh) trains the autoencoder (only used in QueST and VQ-BeT), [main.sh](scripts/quest/main.sh) trains the main algorithm (skill-prior incase of QueST), and [finetune.sh](scripts/quest/finetune.sh) finetunes the model on downstream tasks.

Run the following command to train QueST's stage-0 i.e. the autoencoder. (ref: [autoencoder.sh](scripts/quest/autoencoder.sh))
```
python train.py --config-name=train_autoencoder.yaml \
task=libero_90 \
algo=quest \
exp_name=final \
variant_name=block_32_ds_4 \
algo.skill_block_size=32 \
algo.downsample_factor=4 \
seed=0
```
The above command trains the autoencoder on the libero-90 dataset with a block size of 32 and a downsample factor of 4. The run directory will be created at `<output_prefix>/<benchmark_name>/<task>/<algo>/<exp_name>/<variant_name>/<seed>/<stage>`. For above command, it will be `./experiments/libero/libero_90/quest/final/block_32_ds_4/0/stage_0`.

Run the following command to train QueST's stage-1 i.e. the skill-prior. (ref: [main.sh](scripts/quest/main.sh))
```
python train.py --config-name=train_prior.yaml \
task=libero_90 \
algo=quest \
exp_name=final \
variant_name=block_32_ds_4 \
algo.skill_block_size=32 \
algo.downsample_factor=4 \
training.auto_continue=true \
seed=0
```
Here, training.auto_continue will automatically load the latest checkpoint from the previous training stage.

Run the following command to finetune QueST on a downstream tasks. (ref: [finetune.sh](scripts/quest/finetune.sh))
```
python train.py --config-name=train_fewshot.yaml \
task=libero_long \
algo=quest \
exp_name=final \
variant_name=block_32_ds_4 \
algo.skill_block_size=32 \
algo.downsample_factor=4 \
algo.l1_loss_scale=10 \
training.auto_continue=true \
seed=0
```
Here, algo.l1_loss_scale is used to finetune the decoder of the autoencoder while finetuning.

## Evaluating
Run the following command to evaluate the trained model. (ref: [eval.sh](scripts/eval.sh))
```
python evaluate.py \
task=libero_90 \
algo=quest \
exp_name=final \
variant_name=block_32_ds_4 \
stage=1 \
training.use_tqdm=false \
seed=0
```
This will automatically load the latest checkpoint as per your exp_name, variant_name, algo, and stage. Else you can specify the checkpoint_path to load a specific checkpoint.

## Citation
If you find this work useful, please consider citing:
```
@misc{mete2024questselfsupervisedskillabstractions,
title={QueST: Self-Supervised Skill Abstractions for Learning Continuous Control},
author={Atharva Mete and Haotian Xue and Albert Wilcox and Yongxin Chen and Animesh Garg},
year={2024},
eprint={2407.15840},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2407.15840},
}
```

## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

<hr style="border: 2px solid gray;"></hr>

## Acknowledgements
1. We would like to thank the authors of [LIBERO](https://lifelong-robot-learning.github.io/LIBERO/) and [MetaWorld](https://meta-world.github.io/) for providing the datasets and environments for our experiments.
2. We would also like to thank the authors of our baselines [VQ-BeT](https://github.com/jayLEE0301/vq_bet_official), [ACT](), and [Diffusion Policy]() for providing the codebase for their methods; and the authors of [Robomimic]() from which we adapted the utility files for our codebase.
55 changes: 55 additions & 0 deletions config/algo/act.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
defaults:
- base
- _self_

policy:
_target_: quest.algos.act.ACT
act_model:
_target_: quest.algos.baseline_modules.act_utils.detr_vae.DETRVAE
transformer:
_target_: quest.algos.baseline_modules.act_utils.transformer.build_transformer
hidden_dim: ${algo.embed_dim}
dropout: 0.1
nheads: 8
dim_feedforward: ${eval:'${algo.embed_dim} * 4'}
enc_layers: 4
dec_layers: 7
pre_norm: false
encoder:
_target_: quest.algos.baseline_modules.act_utils.detr_vae.build_encoder
d_model: ${algo.embed_dim}
nheads: 8
dim_feedforward: ${eval:'${algo.embed_dim} * 4'}
enc_layers: 4
pre_norm: false
dropout: 0.1
state_dim: ${task.shape_meta.action_dim}
proprio_dim: 8
shape_meta: ${task.shape_meta}
num_queries: ${algo.skill_block_size}
loss_fn:
_target_: torch.nn.L1Loss
kl_weight: ${algo.kl_weight}
lr_backbone: ${algo.lr}
action_horizon: ${algo.action_horizon}
obs_reduction: 'none'

name: act

lr: 0.0001
weight_decay: 0.0001

kl_weight: 10.0
embed_dim: 256
action_horizon: 2

skill_block_size: 16 # this is output action sequence length, for ACT 16 works better than 32
frame_stack: 1 # this is input observation sequence length

dataset:
seq_len: ${algo.skill_block_size}
frame_stack: ${algo.frame_stack}
obs_seq_len: 1
lowdim_obs_seq_len: null
load_obs_for_pretrain: true
load_next_obs: false
2 changes: 2 additions & 0 deletions config/algo/aug/identity.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_target_: quest.algos.utils.data_augmentation.IdentityAug
_partial_: true
13 changes: 13 additions & 0 deletions config/algo/aug/image.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_target_: quest.algos.utils.data_augmentation.DataAugGroup
_partial_: true
aug_list:
- _target_: quest.algos.utils.data_augmentation.BatchWiseImgColorJitterAug
_partial_: true
brightness: 0.3
contrast: 0.3
saturation: 0.3
hue: 0.3
epsilon: 0.1
- _target_: quest.algos.utils.data_augmentation.TranslationAug
_partial_: true
translation: 4
34 changes: 34 additions & 0 deletions config/algo/base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
defaults:
- encoder/lowdim: mlp
- encoder/image: resnet
- aug: image

embed_dim: 256
lowdim_embed_dim: ${algo.embed_dim}
image_embed_dim: ${algo.embed_dim}

policy:
image_encoder_factory: ${algo.encoder.image}
lowdim_encoder_factory: ${algo.encoder.lowdim}
aug_factory: ${algo.aug}
optimizer_factory:
_target_: torch.optim.AdamW
_partial_: true
lr: ${algo.lr}
betas: [0.9, 0.999]
weight_decay: ${algo.weight_decay}
scheduler_factory:
_target_: torch.optim.lr_scheduler.CosineAnnealingLR
_partial_: true
eta_min: 1e-5
last_epoch: -1
T_max: ${training.n_epochs}
embed_dim: ${algo.embed_dim}
shape_meta: ${task.shape_meta}
obs_reduction: 'none'
device: ${device}

dataset:
lowdim_obs_seq_len: ${algo.skill_block_size}
load_next_obs: false
dataset_keys: [actions]
49 changes: 49 additions & 0 deletions config/algo/bc_transformer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
defaults:
- base
- _self_

policy:
_target_: quest.algos.bc_transformer.BCTransformerPolicy
transformer_model:
_target_: quest.algos.baseline_modules.bc_transformer_modules.TransformerDecoder
input_size: ${algo.embed_dim}
num_layers: 4
num_heads: 6
head_output_size: 64
mlp_hidden_size: 256
dropout: 0.1
policy_head:
_target_: quest.algos.baseline_modules.bc_transformer_modules.GMMHead
input_size: ${algo.embed_dim}
output_size: ${task.shape_meta.action_dim}
hidden_size: 1024
num_layers: 2
min_std: 0.0001
num_modes: 5
low_eval_noise: false
activation: "softplus"
loss_coef: 1.0
positional_encoding:
_target_: quest.algos.baseline_modules.bc_transformer_modules.SinusoidalPositionEncoding
input_size: ${algo.embed_dim}
inv_freq_factor: 10
loss_reduction: 'mean'
obs_reduction: stack
device: ${device}

name: bc_transformer_policy

lr: 0.0001
weight_decay: 0.0001

embed_dim: 128
skill_block_size: 1 # bc_transformer does not do action chunking
frame_stack: 10 # this is input observation sequence length

dataset:
seq_len: ${algo.skill_block_size}
frame_stack: ${algo.frame_stack}
obs_seq_len: 1
lowdim_obs_seq_len: null
load_obs_for_pretrain: false
load_next_obs: false
Loading

0 comments on commit 9315a24

Please sign in to comment.