Support DROID Policy Learning/Evaluation #144

Draft · wants to merge 44 commits into base: r2d2

Changes from all commits (44 commits)
c28c407
first commit, lots of WIP needed
ashwin-balakrishna96 Feb 28, 2024
851ca6e
remove act stuff
ashwin-balakrishna96 Feb 28, 2024
94fe84b
start clean
ashwin-balakrishna96 Feb 28, 2024
c0ddf13
release code
ashwin-balakrishna96 Mar 7, 2024
721ab48
fix args
ashwin-balakrishna96 Mar 7, 2024
1943abf
fix, using this as default from now on
ashwin-balakrishna96 Mar 8, 2024
c49ec03
cleanup, r2d2 --> droid
ashwin-balakrishna96 Mar 8, 2024
fe2e60c
update README, install deps, tested fresh install
ashwin-balakrishna96 Mar 8, 2024
21cab26
clean up requirements
ashwin-balakrishna96 Mar 8, 2024
03b19a8
update todos
ashwin-balakrishna96 Mar 8, 2024
fa1215f
clean
ashwin-balakrishna96 Mar 8, 2024
dfe3ed2
update README
ashwin-balakrishna96 Mar 8, 2024
bbe0c5b
clean
ashwin-balakrishna96 Mar 9, 2024
0b022fe
clean
ashwin-balakrishna96 Mar 9, 2024
56808a5
more cleanup
ashwin-balakrishna96 Mar 9, 2024
03a731d
more clean
ashwin-balakrishna96 Mar 9, 2024
e1d2e41
clean visual core
ashwin-balakrishna96 Mar 9, 2024
f814f8f
some hdf5 cleanup
suraj-nair-tri Mar 9, 2024
b21c4d7
Cleanup and removing unused files
suraj-nair-tri Mar 9, 2024
a804476
Update README.md
kpertsch Mar 9, 2024
d551ced
fix readmes
ashwin-balakrishna96 Mar 10, 2024
858f27f
fix merge
ashwin-balakrishna96 Mar 10, 2024
3b6f36c
clean more, gotta do bc transformer stuff
ashwin-balakrishna96 Mar 11, 2024
0d8c41b
remove unnecessary files
ashwin-balakrishna96 Mar 11, 2024
c96fa71
clean config
ashwin-balakrishna96 Mar 11, 2024
77239a8
clean configs
ashwin-balakrishna96 Mar 11, 2024
fe27799
clean
ashwin-balakrishna96 Mar 11, 2024
1269b0f
minor fix
ashwin-balakrishna96 Mar 11, 2024
c38876a
small fix
ashwin-balakrishna96 Mar 11, 2024
8f2a0f9
clean up install instructions
ashwin-balakrishna96 Mar 11, 2024
d55376c
octo dataloader fixes
ashwin-balakrishna96 Mar 11, 2024
51e1221
clean readme
ashwin-balakrishna96 Mar 12, 2024
f31a742
fix tiny bug
ashwin-balakrishna96 Mar 12, 2024
50d17eb
add dataloader example
ashwin-balakrishna96 Mar 12, 2024
e2df06b
fix
ashwin-balakrishna96 Mar 12, 2024
14c1c42
add in filter support
ashwin-balakrishna96 Mar 12, 2024
713de70
fix
ashwin-balakrishna96 Mar 12, 2024
1e06296
fix shuffle buffer size
ashwin-balakrishna96 Mar 12, 2024
2232921
full clean
ashwin-balakrishna96 Mar 13, 2024
cb5a935
clean out norm stuff
ashwin-balakrishna96 Mar 13, 2024
0725feb
fix comment
ashwin-balakrishna96 Mar 13, 2024
4ffb383
fix typo
ashwin-balakrishna96 Mar 13, 2024
eae359b
fix abs action
ashwin-balakrishna96 Mar 13, 2024
58e48fd
Update README.md
kpertsch Mar 13, 2024
130 changes: 69 additions & 61 deletions README.md
@@ -1,88 +1,96 @@
# robomimic
# DROID Policy Learning and Evaluation

<p align="center">
<img width="24.0%" src="docs/images/task_lift.gif">
<img width="24.0%" src="docs/images/task_can.gif">
<img width="24.0%" src="docs/images/task_tool_hang.gif">
<img width="24.0%" src="docs/images/task_square.gif">
<img width="24.0%" src="docs/images/task_lift_real.gif">
<img width="24.0%" src="docs/images/task_can_real.gif">
<img width="24.0%" src="docs/images/task_tool_hang_real.gif">
<img width="24.0%" src="docs/images/task_transport.gif">
</p>
This repository contains code for training and evaluating policies on the DROID (TODO(Karl/Sasha): add link to DROID website) dataset. DROID is a large-scale, in-the-wild robot manipulation dataset. This codebase is built as a fork of [`robomimic`](https://robomimic.github.io/), a popular repository for imitation learning algorithm development. For more information about DROID, please see the following links:

[**[Homepage]**](https://robomimic.github.io/) &ensp; [**[Documentation]**](https://robomimic.github.io/docs/introduction/overview.html) &ensp; [**[Study Paper]**](https://arxiv.org/abs/2108.03298) &ensp; [**[Study Website]**](https://robomimic.github.io/study/) &ensp; [**[ARISE Initiative]**](https://github.com/ARISE-Initiative)
[**[Homepage]**](XXX) &ensp; [**[Documentation]**](XXX) &ensp; [**[Paper]**](XXX) &ensp; [**[Dataset Visualizer]**](XXX).

-------
## Latest Updates
- [05/23/2022] **v0.2.1**: Updated website and documentation to feature more tutorials :notebook_with_decorative_cover:
- [12/16/2021] **v0.2.0**: Modular observation modalities and encoders :wrench:, support for [MOMART](https://sites.google.com/view/il-for-mm/home) datasets :open_file_folder: [[release notes]](https://github.com/ARISE-Initiative/robomimic/releases/tag/v0.2.0) [[documentation]](https://robomimic.github.io/docs/v0.2/introduction/overview.html)
- [08/09/2021] **v0.1.0**: Initial code and paper release
## Installation
Create a conda environment (tested with Python 3.10) and set up the codebase as follows:

-------
1. Create a Python 3.10 conda environment: `conda create --name droid_policy_learning python=3.10`
2. Activate the conda environment: `conda activate droid_policy_learning`
3. Install [octo](https://github.com/octo-models/octo) (used for data loading)
4. Run `pip install -e .` in `robomimic`. Make sure you are on the `r2d2` branch.

## Colab quickstart
Get started with a quick colab notebook demo of robomimic without installing anything locally.
With this, you are all set up for training policies on DROID. If you want to evaluate your policies on a real DROID robot setup,
please install the DROID robot controller in the same conda environment (follow the instructions [here](https://github.com/AlexanderKhazatsky/DROID)).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1b62r_km9pP40fKF0cBdpdTO2P_2eIbC6?usp=sharing)
-------
## Preparing Datasets
We provide all DROID datasets in RLDS format, which makes it easy to co-train with various other robot-learning datasets (such as those in the [Open X-Embodiment](https://robotics-transformer-x.github.io/) collection).

To download the DROID dataset from the Google Cloud bucket, install the [gsutil package](https://cloud.google.com/storage/docs/gsutil_install) and run the following command (note: the full dataset is XXX TB in size):
```
gsutil -m cp -r XXX <path_to_your_target_dir>
```

-------
We also provide a small (2GB) example dataset with 100 DROID trajectories that uses the same format as the full RLDS dataset and can be used for code prototyping and debugging:
```
gsutil -m cp -r XXX <path_to_your_target_dir>
```

**robomimic** is a framework for robot learning from demonstration.
It offers a broad set of demonstration datasets collected on robot manipulation domains and offline learning algorithms to learn from these datasets.
**robomimic** aims to make robot learning broadly *accessible* and *reproducible*, allowing researchers and practitioners to benchmark tasks and algorithms fairly and to develop the next generation of robot learning algorithms.
For good performance of DROID policies in your target setting, it is helpful to include a small number of demonstrations from your target domain in the training mix ("co-training").
Please follow the instructions [here](XXX) for collecting a small teleoperated dataset in your target domain and converting it to the RLDS training format.
Make sure that all datasets you want to train on are under the same root directory `DATA_PATH`.
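
To sanity-check a download before training, the snippet below loads a single episode with `tensorflow_datasets`. This is a minimal sketch, assuming the dataset was copied under `<path_to_your_target_dir>` and is registered under the name `droid`; adjust both to your setup.

```python
import tensorflow_datasets as tfds

# Assumption: the gsutil copy above produced <path_to_your_target_dir>/droid/...
builder = tfds.builder("droid", data_dir="<path_to_your_target_dir>")
ds = builder.as_dataset(split="train")

for episode in ds.take(1):
    # RLDS stores each episode's transitions in a nested "steps" dataset.
    for step in episode["steps"].take(1):
        print(step["observation"].keys())
        print(step["language_instruction"])
```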

## Core Features
-------
## Training
To train policies, update `DATA_PATH`, `EXP_LOG_PATH`, and `EXP_NAMES` in `robomimic/scripts/config_gen/droid_runs_language_conditioned_rlds.py` and then run:

<p align="center">
<img width="50.0%" src="docs/images/core_features.png">
</p>
`python robomimic/scripts/config_gen/droid_runs_language_conditioned_rlds.py --wandb_proj_name <WANDB_PROJ_NAME>`

<!-- **Standardized Datasets**
- Simulated and real-world tasks
- Multiple environments and robots
- Diverse human-collected and machine-generated datasets
This will generate a Python command that can be run to launch training. You can also update other training parameters within `robomimic/scripts/config_gen/droid_runs_language_conditioned_rlds.py`; please see the `robomimic` documentation for more information on how `robomimic` configs are defined. The three most important parameters in this file are:

**Suite of Learning Algorithms**
- Imitation Learning algorithms (BC, BC-RNN, HBC)
- Offline RL algorithms (BCQ, CQL, IRIS, TD3-BC)
- `DATA_PATH`: This is the directory in which all RLDS datasets were prepared.
- `EXP_LOG_PATH`: This is the path at which experiment data (e.g., policy checkpoints) will be stored.
- `EXP_NAMES`: This defines the name of each experiment (as it will be logged in `wandb`), the RLDS datasets corresponding to that experiment, and the desired sample weights between those datasets. See `robomimic/scripts/config_gen/droid_runs_language_conditioned_rlds.py` for a template of how this should be formatted, and the illustrative sketch below.
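
A purely hypothetical illustration of how these variables might be set (the variable names come from the script above, but the exact structure of `EXP_NAMES` is an assumption here; the template in the script itself is authoritative):

```python
# Hypothetical values for illustration only -- see
# robomimic/scripts/config_gen/droid_runs_language_conditioned_rlds.py for the real template.
DATA_PATH = "/data/rlds"      # root directory containing all prepared RLDS datasets
EXP_LOG_PATH = "/logs/droid"  # where checkpoints and experiment data are written
EXP_NAMES = {
    "droid_cotrain_example": {                      # experiment name as logged to wandb
        "datasets": ["droid", "my_target_dataset"], # RLDS dataset names under DATA_PATH
        "sample_weights": [0.9, 0.1],               # relative sampling weights between datasets
    },
}
```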

**Modular Design**
- Low-dim + Visuomotor policies
- Diverse network architectures
- Support for external datasets
During training, we use a [_shuffle buffer_](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle) to ensure that training samples are properly randomized, so it is important to use a large enough shuffle buffer size.
The default `shuffle_buffer_size` is set to `500000`, but you may need to reduce this based on your RAM availability. For best results, we recommend using `shuffle_buffer_size >= 100000` if possible. All policies were trained on a single NVIDIA A100 GPU.

**Flexible Workflow**
- Hyperparameter sweep tools
- Dataset visualization tools
- Generating new datasets -->
To specify your information for Weights and Biases logging, make sure to update the `WANDB_ENTITY` and `WANDB_API_KEY` values in `robomimic/macros.py`.
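
For example, in `robomimic/macros.py` (the values below are placeholders; substitute your own credentials):

```python
# robomimic/macros.py -- placeholder values, replace with your own
WANDB_ENTITY = "your-wandb-entity"
WANDB_API_KEY = "your-wandb-api-key"
```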

We also provide a stand-alone example to load data from DROID [here](examples/droid_dataloader.py).

## Reproducing benchmarks
-------
## Code Structure

| Component | File | Description |
|---------------------------|---------------------------------------------------------|-------------------------------------------------------------------------------|
| Hyperparameters | [droid_runs_language_conditioned_rlds.py](robomimic/scripts/config_gen/droid_runs_language_conditioned_rlds.py) | Generates a config based on defined hyperparameters. |
| Training Loop | [train.py](robomimic/scripts/train.py) | Main training script. |
| Datasets | [dataset.py](octo/data/dataset.py) | Functions for creating datasets and computing dataset statistics. |
| RLDS Data Processing | [rlds_utils.py](robomimic/utils/rlds_utils.py) | Converts RLDS datasets into a format compatible with DROID training. |
| General Algorithm Class | [algo.py](robomimic/algo/algo.py) | Defines a high-level template for all algorithms (e.g., diffusion policy) to extend. |
| Diffusion Policy | [diffusion_policy.py](robomimic/algo/diffusion_policy.py) | Implementation of diffusion policy. |
| Observation Processing | [obs_nets.py](robomimic/models/obs_nets.py) | General observation pre-processing/encoding. |
| Visualization | [vis_utils.py](robomimic/utils/vis_utils.py) | Utilities for generating trajectory visualizations. |

The robomimic framework also makes reproducing the results from different benchmarks and datasets easy. See the [datasets page](https://robomimic.github.io/docs/datasets/overview.html) for more information on downloading datasets and reproducing experiments.
-------

## Troubleshooting
## Evaluating Trained Policies
To evaluate policies, make sure that you additionally install [DROID](https://github.com/AlexanderKhazatsky/DROID) in your conda environment and then run:
```
python scripts/evaluation/evaluate_policy.py
```
from the DROID root directory. Make sure to use the appropriate command-line arguments for the model checkpoint path and for choosing goal or language conditioning, then follow the resulting prompts in the terminal. To replicate experiments from the paper, use the language-conditioning mode.

Please see the [troubleshooting](https://robomimic.github.io/docs/miscellaneous/troubleshooting.html) section for common fixes, or [submit an issue](https://github.com/ARISE-Initiative/robomimic/issues) on our github page.
-------

## Contributing to robomimic
This project is part of the broader [Advancing Robot Intelligence through Simulated Environments (ARISE) Initiative](https://github.com/ARISE-Initiative), with the aim of lowering the barriers of entry for cutting-edge research at the intersection of AI and Robotics.
The project originally began development in late 2018 by researchers in the [Stanford Vision and Learning Lab](http://svl.stanford.edu/) (SVL).
Now it is actively maintained and used for robotics research projects across multiple labs.
We welcome community contributions to this project.
For details please check our [contributing guidelines](https://robomimic.github.io/docs/miscellaneous/contributing.html).
## Training Policies with HDF5 Format
Natively, robomimic uses HDF5 files to store and load data. While we mainly support RLDS as the data format for training with DROID, [here](robomimic/README_hdf5.md) are instructions for how to run training with the HDF5 data format.

------------
## Citation

Please cite [this paper](https://arxiv.org/abs/2108.03298) if you use this framework in your work:

```bibtex
@inproceedings{robomimic2021,
  title={What Matters in Learning from Offline Human Demonstrations for Robot Manipulation},
  author={Ajay Mandlekar and Danfei Xu and Josiah Wong and Soroush Nasiriany and Chen Wang and Rohun Kulkarni and Li Fei-Fei and Silvio Savarese and Yuke Zhu and Roberto Mart\'{i}n-Mart\'{i}n},
  booktitle={Conference on Robot Learning (CoRL)},
  year={2021}
}
```

```bibtex
@misc{droid_2024,
  title={DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
  author={XXX},
  howpublished={\url{XXX}},
  year={2024},
}
```
38 changes: 38 additions & 0 deletions README_hdf5.md
@@ -0,0 +1,38 @@
### Convert from raw DROID data format to HDF5

First, install the ZED SDK (follow the instructions [here](https://www.stereolabs.com/docs/installation/linux/) for your CUDA version) and the accompanying `pyzed` package. Then run
`python robomimic/scripts/conversion/convert_droid.py --folder <PATH_TO_DROID_DATA_FOLDER> --imsize 128`
which will populate each demo folder with an HDF5 file, `trajectory_im128.h5`, containing the full observations and actions for that demo.

### Composing a manifest file
You may want to subselect certain demos to train on. To support this, we assume you define a manifest JSON file containing a list of demos, with the path to each HDF5 file and the associated language instruction. For example:
```json
[
    {
        "path": "/fullpathA/trajectory_im128.h5",
        "lang": "Put the apple on the plate"
    },
    {
        "path": "/fullpathB/trajectory_im128.h5",
        "lang": "Move the fork to the sink"
    },
    ...
]
```
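
Since the manifest is plain JSON, it can also be generated programmatically. A minimal sketch (the scan root is an assumption, and you must supply each demo's actual language instruction):

```python
import json
from pathlib import Path

# Assumption: converted demos live somewhere under this root directory.
DATA_ROOT = Path("/path/to/droid_data")

demos = [
    {
        "path": str(h5_path),
        "lang": "TODO: fill in this demo's language instruction",
    }
    for h5_path in sorted(DATA_ROOT.rglob("trajectory_im128.h5"))
]

with open("manifest.json", "w") as f:
    json.dump(demos, f, indent=4)
```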

### Adding language embeddings to HDF5
For the files and language specified in the above manifest JSON, run:
`python robomimic/scripts/conversion/add_lang_to_converted_data.py --manifest_file <PATH_TO_MANIFEST_FILE> --imsize 128`
to compute DistilBERT embeddings of each language instruction and add them as an observation key to the HDF5 files.

### Run training
To train policies, update `MANIFEST_PATH` and `EXP_LOG_PATH` in `robomimic/scripts/config_gen/droid_runs_language_conditioned.py` and then run:

`python robomimic/scripts/config_gen/droid_runs_language_conditioned.py --wandb_proj_name <WANDB_PROJ_NAME>`

This will generate a Python command that can be run to launch training. You can also update other training parameters within `robomimic/scripts/config_gen/droid_runs_language_conditioned.py`; please see the `robomimic` documentation for more information on how `robomimic` configs are defined. The three most important parameters in this file are:

- `MANIFEST_PATH`: This is the manifest JSON for the training data you want to use.
- `MANIFEST_2_PATH`: You can optionally set a second manifest for another dataset to do 50-50 co-training with.
- `EXP_LOG_PATH`: This is the path at which experiment data (e.g., policy checkpoints) will be stored.
94 changes: 94 additions & 0 deletions examples/droid_dataloader.py
@@ -0,0 +1,94 @@
import tqdm
import torch
from torch.utils.data import DataLoader
import tensorflow as tf

from robomimic.utils.rlds_utils import droid_dataset_transform, robomimic_transform, DROID_TO_RLDS_OBS_KEY_MAP, DROID_TO_RLDS_LOW_DIM_OBS_KEY_MAP, TorchRLDSDataset
import robomimic.utils.action_utils as ActionUtils
from robomimic.utils.dataset import action_stats_to_normalization_stats

from octo.data.dataset import make_dataset_from_rlds, make_interleaved_dataset
from octo.data.utils.data_utils import combine_dataset_statistics
from octo.utils.spec import ModuleSpec

tf.config.set_visible_devices([], "GPU")

# ------------------------------ Get Dataset Information ------------------------------
DATA_PATH = "" # UPDATE WITH PATH TO RLDS DATASETS
DATASET_NAMES = ["droid"] # You can add additional co-training datasets here
sample_weights = [1] # Add to this if you add additional co-training datasets

# ------------------------------ Get Observation Information ------------------------------
obs_modalities = ["camera/image/varied_camera_1_left_image", "camera/image/varied_camera_2_left_image"]
obs_low_dim_modalities = ["robot_state/cartesian_position", "robot_state/gripper_position"]

# ------------------------------ Get Action Information ------------------------------
is_abs_action = [True] * 10

# ------------------------------ Construct Dataset ------------------------------
BASE_DATASET_KWARGS = {
    "data_dir": DATA_PATH,
    "image_obs_keys": {"primary": DROID_TO_RLDS_OBS_KEY_MAP[obs_modalities[0]], "secondary": DROID_TO_RLDS_OBS_KEY_MAP[obs_modalities[1]]},
    "state_obs_keys": [DROID_TO_RLDS_LOW_DIM_OBS_KEY_MAP[obs_key] for obs_key in obs_low_dim_modalities],
    "language_key": "language_instruction",
    "norm_skip_keys": ["proprio"],
    "action_proprio_normalization_type": "bounds",
    "absolute_action_mask": is_abs_action,
    "action_normalization_mask": is_abs_action,
    "standardize_fn": droid_dataset_transform,
}

filter_functions = [
    [ModuleSpec.create("robomimic.utils.rlds_utils:filter_success")] if d_name == "droid" else []
    for d_name in DATASET_NAMES
]
dataset_kwargs_list = [
    {"name": d_name, "filter_functions": f_functions, **BASE_DATASET_KWARGS}
    for d_name, f_functions in zip(DATASET_NAMES, filter_functions)
]

# Compute combined normalization stats. Note: can also set this to None to normalize each dataset separately
combined_dataset_statistics = combine_dataset_statistics(
    [make_dataset_from_rlds(**dataset_kwargs, train=True)[1] for dataset_kwargs in dataset_kwargs_list]
)

dataset = make_interleaved_dataset(
    dataset_kwargs_list,
    sample_weights,
    train=True,
    shuffle_buffer_size=100000,
    batch_size=None,  # batching will be handled in PyTorch DataLoader object
    balance_weights=False,
    dataset_statistics=combined_dataset_statistics,
    traj_transform_kwargs=dict(
        window_size=2,
        future_action_window_size=15,
        subsample_length=100,
        skip_unlabeled=True,  # skip all trajectories without language
    ),
    frame_transform_kwargs=dict(
        image_augment_kwargs=dict(),
        resize_size=dict(
            primary=[128, 128],
            secondary=[128, 128],
        ),
        num_parallel_calls=200,
    ),
    traj_transform_threads=48,
    traj_read_threads=48,
)

dataset = dataset.map(robomimic_transform, num_parallel_calls=48)

# ------------------------------ Create Dataloader ------------------------------

pytorch_dataset = TorchRLDSDataset(dataset)
train_loader = DataLoader(
    pytorch_dataset,
    batch_size=128,
    num_workers=0,  # important to keep this to 0 so PyTorch does not mess with the parallelism
)

for i, sample in tqdm.tqdm(enumerate(train_loader)):
    if i == 5000:
        break
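
# Optional sanity check (illustrative sketch): inspect one batch from the loader.
# The exact keys and shapes depend on robomimic_transform, so adjust as needed.
batch = next(iter(train_loader))
for key, value in batch.items():
    print(key, getattr(value, "shape", type(value)))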
5 changes: 4 additions & 1 deletion requirements.txt
@@ -9,5 +9,8 @@
imageio
imageio-ffmpeg
matplotlib
egl_probe>=1.0.1
torch
torch==2.0.1
torchvision
diffusers==0.11.1
opencv-python
transformers==4.34.0
1 change: 1 addition & 0 deletions robomimic/algo/__init__.py
@@ -9,3 +9,4 @@
from robomimic.algo.hbc import HBC
from robomimic.algo.iris import IRIS
from robomimic.algo.td3_bc import TD3_BC
from robomimic.algo.diffusion_policy import DiffusionPolicyUNet