diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6a7f13e --- /dev/null +++ b/.gitignore @@ -0,0 +1,62 @@ +# Ignore compiled files +*.class +*.o +*.pyc + +# Ignore build output +build/ +dist/ + +*~ +.hydra +*pyc +*.egg-info/ +datasets +datasets_scripted +*videos +*.mp4 +results +outputs/* +/MUJOCO_LOG.TXT +wandb +experiments/ +experiments_clip/* +evaluations_clip/* +experiments_finetune/* +eval_outs +pace_models/* +*.zip +*.out +*.hdf5 +*.pth +experiments_saved/ +ppo_experiments/ +slurm/ +# Ignore IDE and editor files +.idea/ +.vscode/ +*.sublime-project +*.sublime-workspace + +# Ignore dependencies +node_modules/ +vendor/ + +# Ignore logs and temporary files +*.log +*.tmp + +# Ignore sensitive information +config.ini +secrets.json + +# Ignore generated documentation +docs/ + +# Ignore OS-specific files +.DS_Store +Thumbs.db + +data +experiments +zplots \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..9b1f94e --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Atharva Mete + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..b5f3c45 --- /dev/null +++ b/README.md @@ -0,0 +1,133 @@ +# QueST: Self-Supervised Skill Abstractions for Continuous Control + +Atharva Mete, Haotian Xue, Albert Wilcox, Yongxin Chen, Animesh Garg + +[![Static Badge](https://img.shields.io/badge/Project-Page-green?style=for-the-badge)](https://quest-model.github.io/) +[![arXiv](https://img.shields.io/badge/arXiv-2406.09246-df2a2a.svg?style=for-the-badge)](https://arxiv.org/abs/2407.15840) +[![PyTorch](https://img.shields.io/badge/PyTorch-2.2.0-EE4C2C.svg?style=for-the-badge&logo=pytorch)](https://pytorch.org/get-started/locally/) +[![Python](https://img.shields.io/badge/python-3.10-blue?style=for-the-badge)](https://www.python.org) +[![License](https://img.shields.io/github/license/TRI-ML/prismatic-vlms?style=for-the-badge)](LICENSE) + +[**Installation**](#installation) | [**Dataset Download**](#dataset-download) | [**Training**](#training) | [**Evaluation**](#evaluating) | [**Project Website**](https://quest-model.github.io/) + + +
+ +## Latest Updates +- [2024-09-09] Initial release + +
+ +## Installation + +Please run the following commands in the given order to install the dependency for QueST +``` +conda create -n quest python=3.10.14 +conda activate quest +git clone https://github.com/atharvamete/QueST.git +cd quest +python -m pip install torch==2.2.0 torchvision==0.17.0 +python -m pip install -e . +``` +Note: Above automatically installs metaworld as python packages + +Install LIBERO seperately +``` +git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git +cd LIBERO +python -m pip install -e . +``` +Note: All LIBERO dependencies are already included in quest/requirements.txt + +## Dataset Download +LIBERO: Please download the libero data seperately following their [docs](https://lifelong-robot-learning.github.io/LIBERO/html/algo_data/datasets.html#datasets). + +MetaWorld: We have provided the script we used to collect the data using scripted policies in the MetaWorld package. Please run the following command to collect the data. This uses configs as per [collect_data.yaml](config/collect_data.yaml). +``` +python scripts/generate_metaworld_dataset.py +``` +We generate 100 demonstrations for each of 45 pretraining tasks and 5 for downstream tasks. + +## Training +First set the path to the dataset `data_prefix` and `output_prefix` in [train_base](config/train_base.yaml). `output_prefix` is where all the logs and checkpoints will be stored. + +We provide detailed sample commands for training all stages and for all baselines in the [scripts](scripts) directory. For all methods, [autoencoder.sh](scripts/quest/autoencoder.sh) trains the autoencoder (only used in QueST and VQ-BeT), [main.sh](scripts/quest/main.sh) trains the main algorithm (skill-prior incase of QueST), and [finetune.sh](scripts/quest/finetune.sh) finetunes the model on downstream tasks. + +Run the following command to train QueST's stage-0 i.e. the autoencoder. (ref: [autoencoder.sh](scripts/quest/autoencoder.sh)) +``` +python train.py --config-name=train_autoencoder.yaml \ + task=libero_90 \ + algo=quest \ + exp_name=final \ + variant_name=block_32_ds_4 \ + algo.skill_block_size=32 \ + algo.downsample_factor=4 \ + seed=0 +``` +The above command trains the autoencoder on the libero-90 dataset with a block size of 32 and a downsample factor of 4. The run directory will be created at `///////`. For above command, it will be `./experiments/libero/libero_90/quest/final/block_32_ds_4/0/stage_0`. + +Run the following command to train QueST's stage-1 i.e. the skill-prior. (ref: [main.sh](scripts/quest/main.sh)) +``` +python train.py --config-name=train_prior.yaml \ + task=libero_90 \ + algo=quest \ + exp_name=final \ + variant_name=block_32_ds_4 \ + algo.skill_block_size=32 \ + algo.downsample_factor=4 \ + training.auto_continue=true \ + seed=0 +``` +Here, training.auto_continue will automatically load the latest checkpoint from the previous training stage. + +Run the following command to finetune QueST on a downstream tasks. (ref: [finetune.sh](scripts/quest/finetune.sh)) +``` +python train.py --config-name=train_fewshot.yaml \ + task=libero_long \ + algo=quest \ + exp_name=final \ + variant_name=block_32_ds_4 \ + algo.skill_block_size=32 \ + algo.downsample_factor=4 \ + algo.l1_loss_scale=10 \ + training.auto_continue=true \ + seed=0 +``` +Here, algo.l1_loss_scale is used to finetune the decoder of the autoencoder while finetuning. + +## Evaluating +Run the following command to evaluate the trained model. (ref: [eval.sh](scripts/eval.sh)) +``` +python evaluate.py \ + task=libero_90 \ + algo=quest \ + exp_name=final \ + variant_name=block_32_ds_4 \ + stage=1 \ + training.use_tqdm=false \ + seed=0 +``` +This will automatically load the latest checkpoint as per your exp_name, variant_name, algo, and stage. Else you can specify the checkpoint_path to load a specific checkpoint. + +## Citation +If you find this work useful, please consider citing: +``` +@misc{mete2024questselfsupervisedskillabstractions, + title={QueST: Self-Supervised Skill Abstractions for Learning Continuous Control}, + author={Atharva Mete and Haotian Xue and Albert Wilcox and Yongxin Chen and Animesh Garg}, + year={2024}, + eprint={2407.15840}, + archivePrefix={arXiv}, + primaryClass={cs.RO}, + url={https://arxiv.org/abs/2407.15840}, +} +``` + +## License +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. + +
+ +## Acknowledgements +1. We would like to thank the authors of [LIBERO](https://lifelong-robot-learning.github.io/LIBERO/) and [MetaWorld](https://meta-world.github.io/) for providing the datasets and environments for our experiments. +2. We would also like to thank the authors of our baselines [VQ-BeT](https://github.com/jayLEE0301/vq_bet_official), [ACT](), and [Diffusion Policy]() for providing the codebase for their methods; and the authors of [Robomimic]() from which we adapted the utility files for our codebase. \ No newline at end of file diff --git a/config/algo/act.yaml b/config/algo/act.yaml new file mode 100644 index 0000000..fd444c9 --- /dev/null +++ b/config/algo/act.yaml @@ -0,0 +1,55 @@ +defaults: + - base + - _self_ + +policy: + _target_: quest.algos.act.ACT + act_model: + _target_: quest.algos.baseline_modules.act_utils.detr_vae.DETRVAE + transformer: + _target_: quest.algos.baseline_modules.act_utils.transformer.build_transformer + hidden_dim: ${algo.embed_dim} + dropout: 0.1 + nheads: 8 + dim_feedforward: ${eval:'${algo.embed_dim} * 4'} + enc_layers: 4 + dec_layers: 7 + pre_norm: false + encoder: + _target_: quest.algos.baseline_modules.act_utils.detr_vae.build_encoder + d_model: ${algo.embed_dim} + nheads: 8 + dim_feedforward: ${eval:'${algo.embed_dim} * 4'} + enc_layers: 4 + pre_norm: false + dropout: 0.1 + state_dim: ${task.shape_meta.action_dim} + proprio_dim: 8 + shape_meta: ${task.shape_meta} + num_queries: ${algo.skill_block_size} + loss_fn: + _target_: torch.nn.L1Loss + kl_weight: ${algo.kl_weight} + lr_backbone: ${algo.lr} + action_horizon: ${algo.action_horizon} + obs_reduction: 'none' + +name: act + +lr: 0.0001 +weight_decay: 0.0001 + +kl_weight: 10.0 +embed_dim: 256 +action_horizon: 2 + +skill_block_size: 16 # this is output action sequence length, for ACT 16 works better than 32 +frame_stack: 1 # this is input observation sequence length + +dataset: + seq_len: ${algo.skill_block_size} + frame_stack: ${algo.frame_stack} + obs_seq_len: 1 + lowdim_obs_seq_len: null + load_obs_for_pretrain: true + load_next_obs: false diff --git a/config/algo/aug/identity.yaml b/config/algo/aug/identity.yaml new file mode 100644 index 0000000..7b32ac1 --- /dev/null +++ b/config/algo/aug/identity.yaml @@ -0,0 +1,2 @@ +_target_: quest.algos.utils.data_augmentation.IdentityAug +_partial_: true \ No newline at end of file diff --git a/config/algo/aug/image.yaml b/config/algo/aug/image.yaml new file mode 100644 index 0000000..009189b --- /dev/null +++ b/config/algo/aug/image.yaml @@ -0,0 +1,13 @@ +_target_: quest.algos.utils.data_augmentation.DataAugGroup +_partial_: true +aug_list: + - _target_: quest.algos.utils.data_augmentation.BatchWiseImgColorJitterAug + _partial_: true + brightness: 0.3 + contrast: 0.3 + saturation: 0.3 + hue: 0.3 + epsilon: 0.1 + - _target_: quest.algos.utils.data_augmentation.TranslationAug + _partial_: true + translation: 4 \ No newline at end of file diff --git a/config/algo/base.yaml b/config/algo/base.yaml new file mode 100644 index 0000000..c49561d --- /dev/null +++ b/config/algo/base.yaml @@ -0,0 +1,34 @@ +defaults: + - encoder/lowdim: mlp + - encoder/image: resnet + - aug: image + +embed_dim: 256 +lowdim_embed_dim: ${algo.embed_dim} +image_embed_dim: ${algo.embed_dim} + +policy: + image_encoder_factory: ${algo.encoder.image} + lowdim_encoder_factory: ${algo.encoder.lowdim} + aug_factory: ${algo.aug} + optimizer_factory: + _target_: torch.optim.AdamW + _partial_: true + lr: ${algo.lr} + betas: [0.9, 0.999] + weight_decay: ${algo.weight_decay} + scheduler_factory: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + _partial_: true + eta_min: 1e-5 + last_epoch: -1 + T_max: ${training.n_epochs} + embed_dim: ${algo.embed_dim} + shape_meta: ${task.shape_meta} + obs_reduction: 'none' + device: ${device} + +dataset: + lowdim_obs_seq_len: ${algo.skill_block_size} + load_next_obs: false + dataset_keys: [actions] \ No newline at end of file diff --git a/config/algo/bc_transformer.yaml b/config/algo/bc_transformer.yaml new file mode 100644 index 0000000..6805ee5 --- /dev/null +++ b/config/algo/bc_transformer.yaml @@ -0,0 +1,49 @@ +defaults: + - base + - _self_ + +policy: + _target_: quest.algos.bc_transformer.BCTransformerPolicy + transformer_model: + _target_: quest.algos.baseline_modules.bc_transformer_modules.TransformerDecoder + input_size: ${algo.embed_dim} + num_layers: 4 + num_heads: 6 + head_output_size: 64 + mlp_hidden_size: 256 + dropout: 0.1 + policy_head: + _target_: quest.algos.baseline_modules.bc_transformer_modules.GMMHead + input_size: ${algo.embed_dim} + output_size: ${task.shape_meta.action_dim} + hidden_size: 1024 + num_layers: 2 + min_std: 0.0001 + num_modes: 5 + low_eval_noise: false + activation: "softplus" + loss_coef: 1.0 + positional_encoding: + _target_: quest.algos.baseline_modules.bc_transformer_modules.SinusoidalPositionEncoding + input_size: ${algo.embed_dim} + inv_freq_factor: 10 + loss_reduction: 'mean' + obs_reduction: stack + device: ${device} + +name: bc_transformer_policy + +lr: 0.0001 +weight_decay: 0.0001 + +embed_dim: 128 +skill_block_size: 1 # bc_transformer does not do action chunking +frame_stack: 10 # this is input observation sequence length + +dataset: + seq_len: ${algo.skill_block_size} + frame_stack: ${algo.frame_stack} + obs_seq_len: 1 + lowdim_obs_seq_len: null + load_obs_for_pretrain: false + load_next_obs: false \ No newline at end of file diff --git a/config/algo/bet.yaml b/config/algo/bet.yaml new file mode 100644 index 0000000..adac744 --- /dev/null +++ b/config/algo/bet.yaml @@ -0,0 +1,62 @@ +defaults: + - base + - _self_ + +policy: + _target_: quest.algos.bet.BehaviorTransformer + autoencoder: + _target_: quest.algos.baseline_modules.vq_behavior_transformer.vqvae.VqVae + input_dim_h: ${algo.skill_block_size} # length of action chunk + input_dim_w: ${task.shape_meta.action_dim} # action dim + n_latent_dims: 512 + vqvae_n_embed: 32 + vqvae_groups: 2 + hidden_dim: 128 + num_layers: 1 + device: ${device} + policy_prior: + _target_: quest.algos.baseline_modules.vq_behavior_transformer.gpt.GPT + block_size: 30 + input_dim: ${algo.embed_dim} + output_dim: 256 # fixed as per original vqbet implementation + n_layer: 6 + n_head: 6 + n_embd: ${algo.embed_dim} + dropout: 0.1 + stage: ${stage} + loss_fn: + _target_: quest.algos.bet.FocalLoss + gamma: 2.0 + offset_loss_multiplier: ${algo.offset_loss_multiplier} + secondary_code_multiplier: ${algo.beta} + frame_stack: ${algo.frame_stack} + skill_block_size: ${algo.skill_block_size} + sequentially_select: false + action_horizon: ${algo.action_horizon} + obs_reduction: cat + device: ${device} + + +name: vqbet + +lr: 5.5e-5 +weight_decay: 2e-4 + +embed_dim: 120 +lowdim_embed_dim: 128 +image_embed_dim: 256 +offset_loss_multiplier: 100 +beta: 0.5 + +frame_stack: 10 +skill_block_size: 5 # this is input sequence length to encoder + +action_horizon: 2 # mpc horizon for execution + +dataset: + seq_len: ${eval:'${algo.frame_stack} + ${algo.skill_block_size} - 1'} + frame_stack: 1 + lowdim_obs_seq_len: null + obs_seq_len: ${algo.frame_stack} + load_obs_for_pretrain: false + load_next_obs: false \ No newline at end of file diff --git a/config/algo/data_collect.yaml b/config/algo/data_collect.yaml new file mode 100644 index 0000000..7a849cb --- /dev/null +++ b/config/algo/data_collect.yaml @@ -0,0 +1,4 @@ +# This is a dummy yaml file supplying values to make the task interpolation keys work + +frame_stack: 1 +skill_block_size: 1 \ No newline at end of file diff --git a/config/algo/diffusion_policy.yaml b/config/algo/diffusion_policy.yaml new file mode 100644 index 0000000..3277de8 --- /dev/null +++ b/config/algo/diffusion_policy.yaml @@ -0,0 +1,54 @@ +defaults: + - base + - _self_ + +policy: + _target_: quest.algos.diffusion_policy.DiffusionPolicy + diffusion_model: + _target_: quest.algos.diffusion_policy.DiffusionModel + noise_scheduler: + _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler + num_train_timesteps: ${algo.diffusion_train_steps} + beta_schedule: squaredcos_cap_v2 + action_dim: ${task.shape_meta.action_dim} + global_cond_dim: ${eval:'${algo.embed_dim} + ${algo.lang_emb_dim}'} + diffusion_step_emb_dim: ${algo.diffusion_step_emb_dim} + down_dims: [256,512,1024] + ema_power: 0.75 + skill_block_size: ${algo.skill_block_size} + diffusion_inf_steps: ${algo.diffusion_inf_steps} + device: ${device} + action_horizon: ${algo.action_horizon} + obs_reduction: cat + device: ${device} + + +name: diffusion_policy + +lr: 0.0001 +weight_decay: 0.0001 + +lowdim_embed_dim: 128 +image_embed_dim: 256 +pc_embed_dim: 256 +diffusion_step_emb_dim: 256 +lang_emb_dim: 256 # clip embedding size +embed_dim: 256 + +skill_block_size: 16 # this is input sequence length to encoder + + +diffusion_train_steps: 100 +diffusion_inf_steps: 10 + +action_horizon: 2 # mpc horizon for execution + +frame_stack: 1 + +dataset: + seq_len: ${algo.skill_block_size} + frame_stack: ${algo.frame_stack} + obs_seq_len: 1 + lowdim_obs_seq_len: null + load_obs_for_pretrain: true + load_next_obs: false \ No newline at end of file diff --git a/config/algo/encoder/image/act_style.yaml b/config/algo/encoder/image/act_style.yaml new file mode 100644 index 0000000..c4891a1 --- /dev/null +++ b/config/algo/encoder/image/act_style.yaml @@ -0,0 +1,8 @@ +# I don't rember if this works +_target_: quest.algos.utils.rgb_modules.ACTEncoder +_partial_: true +embed_dim: ${algo.image_embed_dim} +backbone_name: resnet18 +# train_backbone: true +return_interm_layers: false +dilation: false \ No newline at end of file diff --git a/config/algo/encoder/image/resnet.yaml b/config/algo/encoder/image/resnet.yaml new file mode 100644 index 0000000..802b5bf --- /dev/null +++ b/config/algo/encoder/image/resnet.yaml @@ -0,0 +1,8 @@ +_target_: quest.algos.utils.rgb_modules.ResnetEncoder +_partial_: true +output_size: ${algo.image_embed_dim} +pretrained: false +freeze: false +remove_layer_num: 4 +no_stride: false +language_fusion: 'none' \ No newline at end of file diff --git a/config/algo/encoder/lowdim/mlp.yaml b/config/algo/encoder/lowdim/mlp.yaml new file mode 100644 index 0000000..35c76e4 --- /dev/null +++ b/config/algo/encoder/lowdim/mlp.yaml @@ -0,0 +1,4 @@ +_target_: quest.algos.utils.mlp_proj.MLPProj +_partial_: true +output_size: ${algo.lowdim_embed_dim} +num_layers: 1 \ No newline at end of file diff --git a/config/algo/quest.yaml b/config/algo/quest.yaml new file mode 100644 index 0000000..d39c71a --- /dev/null +++ b/config/algo/quest.yaml @@ -0,0 +1,72 @@ +defaults: + - base + - _self_ + +policy: + _target_: quest.algos.quest.QueST + autoencoder: + _target_: quest.algos.quest_modules.skill_vae.SkillVAE + action_dim: ${task.shape_meta.action_dim} + encoder_dim: 256 + decoder_dim: 256 + skill_block_size: ${algo.skill_block_size} + downsample_factor: ${algo.downsample_factor} + attn_pdrop: 0.1 + use_causal_encoder: true + use_causal_decoder: true + encoder_heads: 4 + encoder_layers: 2 + decoder_heads: 4 + decoder_layers: 4 + vq_type: "fsq" # "vq" or "fsq" + # fsq_level: [8,5,5,5] + fsq_level: null + codebook_dim: 512 # only used for vq + codebook_size: ${algo.codebook_size} # if fsq level is null then it will automatically compute it according to this + policy_prior: + _target_: quest.algos.quest_modules.skill_gpt.SkillGPT + action_dim: ${task.shape_meta.action_dim} + start_token: 1000 # should be equal to actual fsq/vq codebook size + vocab_size: 1000 # should be equal to actual fsq/vq codebook size + block_size: ${eval:'${algo.skill_block_size} // ${algo.downsample_factor}'} + n_layer: 6 + n_head: 6 + n_embd: ${algo.embed_dim} + attn_pdrop: 0.1 + embd_pdrop: 0.1 + beam_size: 5 # value of k for top k sampling + temperature: 1.0 # temperature for sampling + device: ${device} + stage: ${stage} + loss_fn: + _target_: torch.nn.L1Loss + l1_loss_scale: ${algo.l1_loss_scale} + action_horizon: ${algo.action_horizon} + obs_reduction: cat + device: ${device} + +name: quest + +# Put hyperparameters that require tuning here +lr: 0.0001 +weight_decay: 0.0001 +l1_loss_scale: 0 # scale factor used in finetuning stage + +embed_dim: 384 # stage 2 transformer hidden dim +lowdim_embed_dim: 128 # each lowdim obs modality is embedded to this dim +image_embed_dim: 256 # each image obs vision encoder's output is embedded to this dim + +codebook_size: 1024 # note for fsq this will be computed automatically to be 1000, see get_fsq_level function in quest/algos/quest_modules/skill_vae.py +skill_block_size: 32 # this is input sequence length to encoder +downsample_factor: 4 + +action_horizon: 8 # how many predicted actions to execute +frame_stack: 1 + +dataset: + seq_len: ${algo.skill_block_size} # this denotes future timestep actions + frame_stack: ${algo.frame_stack} # this denotes past timestep observations and actions + obs_seq_len: 1 # this denotes future timestep image observations + lowdim_obs_seq_len: null # this denotes future timestep lowdim observations + load_obs_for_pretrain: false # since autoencoder training stage does not require obs + load_next_obs: false # you know this \ No newline at end of file diff --git a/config/collect_data.yaml b/config/collect_data.yaml new file mode 100644 index 0000000..dcf20f1 --- /dev/null +++ b/config/collect_data.yaml @@ -0,0 +1,22 @@ +defaults: + - task: metaworld_ml45 + - algo: data_collect + - _self_ + +algo: + name: data_collection + +exp_name: ${task.suite_name} # +variant_name: ${task.benchmark_name} +seed: 10000 +device: cuda:0 +stage: 1 +output_prefix: ./experiments +data_prefix: ./data +make_unique_experiment_dir: true + + +rollout: + enabled: true + rollouts_per_env: ${task.demos_per_env} + max_episode_length: 500 diff --git a/config/evaluate.yaml b/config/evaluate.yaml new file mode 100644 index 0000000..be72c0d --- /dev/null +++ b/config/evaluate.yaml @@ -0,0 +1,32 @@ +defaults: + - task: metaworld_ml45 + - algo: quest + - _self_ + + +training: + use_tqdm: true + n_epochs: 0 + do_profile: false + resume: false + load_obs: false + +rollout: + enabled: true + interval: 10 + rollouts_per_env: 50 + max_episode_length: ${task.horizon} + num_parallel_envs: 1 + n_video: 0 + +exp_name: debug # +variant_name: null +seed: 10000 +device: cuda:0 +stage: 1 # 0 - pretrain autoencoder, 1 - train multitask, 2 - finetune multitask +output_prefix: ./experiments +data_prefix: ./data +make_unique_experiment_dir: false + +checkpoint_path: null + diff --git a/config/task/libero_90.yaml b/config/task/libero_90.yaml new file mode 100644 index 0000000..5a2d494 --- /dev/null +++ b/config/task/libero_90.yaml @@ -0,0 +1,9 @@ +defaults: + - libero_base + - _self_ + +benchmark_name: LIBERO_90 +mode: all +n_tasks: 90 +rollouts_per_env: 50 + diff --git a/config/task/libero_base.yaml b/config/task/libero_base.yaml new file mode 100644 index 0000000..1ddec6e --- /dev/null +++ b/config/task/libero_base.yaml @@ -0,0 +1,74 @@ + +suite_name: libero +benchmark_name: null +mode: all +n_tasks: 10 +demos_per_env: 50 + +task_embedding_format: clip +img_height: 128 +img_width: 128 +horizon: 600 + +shape_meta: + action_dim: 7 + observation: + rgb: + agentview_rgb: + - 3 + - ${task.img_height} + - ${task.img_width} + eye_in_hand_rgb: + - 3 + - ${task.img_height} + - ${task.img_width} + lowdim: + joint_states: 7 + ee_pos: 3 + gripper_states: 2 + task: + type: vector + dim: 512 + +dataset: + _target_: quest.utils.libero_utils.build_dataset + data_prefix: ${data_prefix} + suite_name: ${task.suite_name} + benchmark_name: ${task.benchmark_name} + mode: ${task.mode} + seq_len: ${algo.dataset.seq_len} + frame_stack: ${algo.dataset.frame_stack} + obs_seq_len: ${algo.dataset.obs_seq_len} + shape_meta: ${task.shape_meta} + load_obs: ${training.load_obs} + task_embedding_format: ${task.task_embedding_format} + n_demos: ${task.demos_per_env} + +env_factory: + _target_: quest.utils.libero_utils.LiberoWrapper + _partial_: true + shape_meta: ${task.shape_meta} + obs_key_mapping: ${task.obs_key_mapping} + img_height: ${task.img_height} + img_width: ${task.img_width} + device: ${device} + +env_runner: + _target_: quest.env_runner.libero_runner.LiberoRunner + env_factory: ${task.env_factory} + frame_stack: ${algo.frame_stack} + benchmark_name: ${task.benchmark_name} + mode: ${task.mode} + rollouts_per_env: ${rollout.rollouts_per_env} + num_parallel_envs: ${rollout.num_parallel_envs} + max_episode_length: ${rollout.max_episode_length} + fps: 24 + debug: false + task_embedding_format: ${task.task_embedding_format} + +obs_key_mapping: + agentview_rgb: agentview_image + eye_in_hand_rgb: robot0_eye_in_hand_image + gripper_states: robot0_gripper_qpos + joint_states: robot0_joint_pos + ee_pos: robot0_eef_pos \ No newline at end of file diff --git a/config/task/libero_long.yaml b/config/task/libero_long.yaml new file mode 100644 index 0000000..5ac7146 --- /dev/null +++ b/config/task/libero_long.yaml @@ -0,0 +1,9 @@ +defaults: + - libero_base + - _self_ + +benchmark_name: LIBERO_10 +mode: all +n_tasks: 10 +rollouts_per_env: 50 + diff --git a/config/task/libero_long_fewshot.yaml b/config/task/libero_long_fewshot.yaml new file mode 100644 index 0000000..dfeeef6 --- /dev/null +++ b/config/task/libero_long_fewshot.yaml @@ -0,0 +1,9 @@ +defaults: + - libero_base + - _self_ + +benchmark_name: LIBERO_10 +mode: fewshot +n_tasks: 10 +rollouts_per_env: 5 + diff --git a/config/task/metaworld_base.yaml b/config/task/metaworld_base.yaml new file mode 100644 index 0000000..bcdfd2e --- /dev/null +++ b/config/task/metaworld_base.yaml @@ -0,0 +1,61 @@ + +suite_name: metaworld +mode: train +horizon: 500 +demo_noise: 0.2 + +img_height: 128 +img_width: 128 + +shape_meta: + action_dim: 4 + observation: + rgb: + corner_rgb: + - 3 + - ${task.img_height} + - ${task.img_width} + lowdim: + robot_states: 8 + task: + type: onehot + n_tasks: ${task.n_tasks} + +dataset: + _target_: quest.utils.metaworld_utils.build_dataset + data_prefix: ${data_prefix} + suite_name: ${task.suite_name} + benchmark_name: ${task.benchmark_name} + mode: ${task.mode} + seq_len: ${algo.dataset.seq_len} + frame_stack: ${algo.dataset.frame_stack} + obs_seq_len: ${algo.dataset.obs_seq_len} + lowdim_obs_seq_len: ${algo.dataset.lowdim_obs_seq_len} + shape_meta: ${task.shape_meta} + load_obs: ${training.load_obs} + n_demos: ${task.demos_per_env} + load_next_obs: ${algo.dataset.load_next_obs} + dataset_keys: ${algo.dataset.dataset_keys} + +env_factory: + _target_: quest.utils.metaworld_utils.MetaWorldWrapper + _partial_: true + shape_meta: ${task.shape_meta} + img_height: ${task.img_height} + img_width: ${task.img_width} + cameras: ['corner2'] + env_kwargs: null + +env_runner: + _target_: quest.env_runner.metaworld_runner.MetaWorldRunner + env_factory: + _target_: quest.utils.metaworld_utils.MetaWorldFrameStack + _partial_: true + env_factory: ${task.env_factory} + num_stack: ${algo.frame_stack} + benchmark_name: ${task.benchmark_name} + mode: ${task.mode} + rollouts_per_env: ${rollout.rollouts_per_env} + fps: 24 + debug: false + diff --git a/config/task/metaworld_ml45_prise.yaml b/config/task/metaworld_ml45_prise.yaml new file mode 100644 index 0000000..8b0b975 --- /dev/null +++ b/config/task/metaworld_ml45_prise.yaml @@ -0,0 +1,9 @@ +defaults: + - metaworld_base + - _self_ + +benchmark_name: ML45_PRISE +mode: train +n_tasks: 50 +demos_per_env: 100 + diff --git a/config/task/metaworld_ml45_prise_fewshot.yaml b/config/task/metaworld_ml45_prise_fewshot.yaml new file mode 100644 index 0000000..03bff98 --- /dev/null +++ b/config/task/metaworld_ml45_prise_fewshot.yaml @@ -0,0 +1,9 @@ +defaults: + - metaworld_base + - _self_ + +benchmark_name: ML45_PRISE +mode: test +n_tasks: 50 +demos_per_env: 5 + diff --git a/config/train_autoencoder.yaml b/config/train_autoencoder.yaml new file mode 100644 index 0000000..ac2b187 --- /dev/null +++ b/config/train_autoencoder.yaml @@ -0,0 +1,23 @@ +defaults: + - train_base + - _self_ + +train_dataloader: + batch_size: 256 + +training: + n_epochs: 100 # 100 recommended for libero, 200 for metaworld + use_amp: false # this seems to cause instabilities for shorter chunks (<64) but causes speedups for larger ones + load_obs: ${algo.dataset.load_obs_for_pretrain} + +rollout: + enabled: false + rollouts_per_env: null + max_episode_length: null + +stage: 0 # 0 - pretrain autoencoder, 1 - train multitask, 2 - finetune multitask + +logging_folder: autoencoder + + + diff --git a/config/train_base.yaml b/config/train_base.yaml new file mode 100644 index 0000000..27de721 --- /dev/null +++ b/config/train_base.yaml @@ -0,0 +1,66 @@ +defaults: + - task: metaworld_ml45 + - algo: quest + - _self_ + +# set the path to your data and experiments directories +output_prefix: ./experiments +data_prefix: ./data + +exp_name: debug # +variant_name: null +seed: 10000 +device: cuda:0 +stage: null # 0 - pretrain autoencoder, 1 - train multitask, 2 - finetune multitask +make_unique_experiment_dir: true # if true, it will create unique experiment directories with name run_00X within that experiment directory, use when debugging +logging_folder: training + +checkpoint_path: null + + +train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 128 + shuffle: true + num_workers: 6 + persistent_workers: false + pin_memory: true + multiprocessing_context: fork + # prefetch_factor: 2 + +training: + n_epochs: 100 + grad_clip: 100. + save_interval: 10 + log_interval: 100 + use_amp: false + use_tqdm: true + do_profile: false + save_all_checkpoints: false + auto_continue: false # if true, it will automatically continue from the end of stage n training for stage n+1 training + load_obs: true + cut: 0 + + # resume a training run + resume: false + resume_path: "" + +rollout: + enabled: true + interval: 10 + rollouts_per_env: 1 + max_episode_length: ${task.horizon} + n_video: 0 + num_parallel_envs: 1 + + +logging: + group: null + mode: online # set logging.mode=disabled to disable wandb + project: dynaquest # TODO: change this to your wandb project name + resume: true + save_code: true + + + + diff --git a/config/train_fewshot.yaml b/config/train_fewshot.yaml new file mode 100644 index 0000000..5c8a8b2 --- /dev/null +++ b/config/train_fewshot.yaml @@ -0,0 +1,14 @@ +defaults: + - train_base + - _self_ + +stage: 2 # 0 - pretrain autoencoder, 1 - train multitask, 2 - finetune multitask + +training: + n_epochs: 200 # 100 is best for libero, 200 for metaworld + +rollout: + interval: 25 + rollouts_per_env: 5 + +logging_folder: fewshot \ No newline at end of file diff --git a/config/train_prior.yaml b/config/train_prior.yaml new file mode 100644 index 0000000..0463b2b --- /dev/null +++ b/config/train_prior.yaml @@ -0,0 +1,15 @@ +defaults: + - train_base + - _self_ + +stage: 1 # 0 - pretrain autoencoder, 1 - train multitask, 2 - finetune multitask + +logging_folder: prior + +training: + n_epochs: 100 # on libero: 20 for quest and 100 for other algos, 100 on metaworld for all algos + +rollout: + interval: 25 # 25 is best for libero, 10 for metaworld + rollouts_per_env: 5 + num_parallel_envs: 5 # 5 recommended for libero, 1 for metaworld \ No newline at end of file diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..cdfad1c --- /dev/null +++ b/evaluate.py @@ -0,0 +1,80 @@ +import os +import time +import hydra +import wandb +from hydra.utils import instantiate +from omegaconf import OmegaConf +from tqdm import tqdm + +import torch +import torch.nn as nn +import quest.utils.utils as utils +from pyinstrument import Profiler +from moviepy.editor import ImageSequenceClip +import json + +OmegaConf.register_new_resolver("eval", eval, replace=True) + + +@hydra.main(config_path="config", config_name='evaluate', version_base=None) +def main(cfg): + device = cfg.device + seed = cfg.seed + torch.manual_seed(seed) + train_cfg = cfg.training + OmegaConf.resolve(cfg) + + # create model + save_dir, _ = utils.get_experiment_dir(cfg, evaluate=True) + os.makedirs(save_dir) + + if cfg.checkpoint_path is None: + # Basically if you don't provide a checkpoint path it will automatically find one corresponding + # to the experiment/variant name you provide + checkpoint_path, _ = utils.get_experiment_dir(cfg, evaluate=False, allow_overlap=True) + checkpoint_path = utils.get_latest_checkpoint(checkpoint_path) + else: + checkpoint_path = utils.get_latest_checkpoint(cfg.checkpoint_path) + state_dict = utils.load_state(checkpoint_path) + + if 'config' in state_dict: + print('autoloading based on saved parameters') + model = instantiate(state_dict['config']['algo']['policy'], + shape_meta=cfg.task.shape_meta) + else: + model = instantiate(cfg.algo.policy, + shape_meta=cfg.task.shape_meta) + model.to(device) + model.eval() + + model.load_state_dict(state_dict['model']) + + env_runner = instantiate(cfg.task.env_runner) + + print('Saving to:', save_dir) + print('Running evaluation...') + + def save_video_fn(video_chw, env_name, idx): + video_dir = os.path.join(save_dir, 'videos', env_name) + os.makedirs(video_dir, exist_ok=True) + save_path = os.path.join(video_dir, f'{idx}.mp4') + clip = ImageSequenceClip(list(video_chw.transpose(0, 2, 3, 1)), fps=24) + clip.write_videofile(save_path, fps=24, verbose=False, logger=None) + + if train_cfg.do_profile: + profiler = Profiler() + profiler.start() + rollout_results = env_runner.run(model, n_video=cfg.rollout.n_video, do_tqdm=train_cfg.use_tqdm, save_video_fn=save_video_fn) + if train_cfg.do_profile: + profiler.stop() + profiler.print() + print( + f"[info] success rate: {rollout_results['rollout']['overall_success_rate']:1.3f} \ + | environments solved: {rollout_results['rollout']['environments_solved']}") + + with open(os.path.join(save_dir, 'data.json'), 'w') as f: + json.dump(rollout_results, f) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0e39a33 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,27 @@ +[build-system] +requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] +build-backend = "setuptools.build_meta" + +[project] +name = "quest" +version = "0.1.0" +authors = [ + {name = "Atharva Mete", email = "amete7@gatech.edu"}, + {name = "Albert Wilcox", email = "albertwilcox@gatech.edu"}, +] +description = "Code for QueST: Self-Supervised Skill Abstractions for Continuous Control" +readme = "README.md" +requires-python = ">=3.10" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.10", +] +dynamic = ["dependencies"] + +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} + +[tool.setuptools] +packages = {find = {}} \ No newline at end of file diff --git a/quest/__init__.py b/quest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/quest/algos/act.py b/quest/algos/act.py new file mode 100644 index 0000000..112b87a --- /dev/null +++ b/quest/algos/act.py @@ -0,0 +1,77 @@ +import torch +import torchvision.transforms as transforms +from quest.algos.base import ChunkPolicy +import numpy as np + +class ACT(ChunkPolicy): + def __init__( + self, + act_model, + loss_fn, + kl_weight, + lr_backbone, + **kwargs + ): + super().__init__(**kwargs) + self.loss_fn = loss_fn + self.kl_weight = kl_weight + self.lr_backbone = lr_backbone + + self.act_model = act_model.to(self.device) + + def compute_loss(self, data): + data = self.preprocess_input(data, train_mode=True) + actions = data['actions'] + perception_encodings, lowdim_encodings, lang_emb = self.get_embeddings(data) + is_pad = torch.zeros((actions.shape[0], actions.shape[1]), device=self.device, dtype=torch.bool) + pred_action, _, latent = self.act_model(lowdim_encodings, perception_encodings, lang_emb, actions, is_pad) + + # pred_action, latent = self.forward(data) + l1_loss = self.loss_fn(pred_action, data["actions"]) + total_kld, dim_wise_kld, mean_kld = kl_divergence(latent[0], latent[1]) + loss = l1_loss + total_kld[0]*self.kl_weight + info = { + 'l1_loss': l1_loss.item(), + 'total_kld': total_kld[0].item(), + 'mean_kld': mean_kld.item(), + 'total_loss': loss.item(), + } + return loss, info + + def sample_actions(self, data): + data = self.preprocess_input(data, train_mode=False) + perception_encodings, lowdim_encodings, lang_emb = self.get_embeddings(data) + pred_action, _, _ = self.act_model(lowdim_encodings, perception_encodings, lang_emb) + pred_action = pred_action.permute(1, 0, 2) + return pred_action.detach().cpu().numpy() + + def get_embeddings(self, data): + img_encodings, lowdim_encodings = self.obs_encode(data) + perception_encodings = torch.stack(img_encodings, dim=2) + B = perception_encodings.shape[0] + D = perception_encodings.shape[-1] + if len(lowdim_encodings) == 0: + lowdim_encodings = torch.zeros((B, 0, D), device=perception_encodings.device) + else: + lowdim_encodings = torch.stack(lowdim_encodings, dim=2) + lang_emb = self.get_task_emb(data) + # collapse frame stack dim and number of encoder dim into one dimension + perception_encodings = perception_encodings.reshape(B, -1, D) + lowdim_encodings = lowdim_encodings.reshape(B, -1, D) + + return perception_encodings, lowdim_encodings, lang_emb + +def kl_divergence(mu, logvar): + batch_size = mu.size(0) + assert batch_size != 0 + if mu.data.ndimension() == 4: + mu = mu.view(mu.size(0), mu.size(1)) + if logvar.data.ndimension() == 4: + logvar = logvar.view(logvar.size(0), logvar.size(1)) + + klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()) + total_kld = klds.sum(1).mean(0, True) + dimension_wise_kld = klds.mean(0) + mean_kld = klds.mean(1).mean(0, True) + + return total_kld, dimension_wise_kld, mean_kld \ No newline at end of file diff --git a/quest/algos/base.py b/quest/algos/base.py new file mode 100644 index 0000000..d8514a9 --- /dev/null +++ b/quest/algos/base.py @@ -0,0 +1,227 @@ +import torch +import torch.nn as nn +from collections import deque +# from quest.modules.v1 import * +import quest.utils.tensor_utils as TensorUtils +from quest.utils.utils import map_tensor_to_device +import quest.utils.obs_utils as ObsUtils +import einops + +from abc import ABC, abstractmethod + +class Policy(nn.Module, ABC): + ''' + Super class with some basic functionality and functions we expect + from all policy classes in our training loop + ''' + + def __init__(self, + image_encoder_factory, + lowdim_encoder_factory, + aug_factory, + optimizer_factory, + scheduler_factory, + embed_dim, + obs_reduction, + shape_meta, + device, + ): + super().__init__() + + self.use_augmentation = aug_factory is not None + self.obs_reduction = obs_reduction + self.shape_meta = shape_meta + self.optimizer_factory = optimizer_factory + self.scheduler_factory = scheduler_factory + self.device = device + total_obs_channels = 0 + + do_image = image_encoder_factory is not None + do_lowdim = lowdim_encoder_factory is not None + + # observation encoders + self.image_encoders = {} + if do_image and shape_meta['observation']['rgb'] is not None: + for name, shape in shape_meta["observation"]['rgb'].items(): + shape_in = list(shape) + encoder = image_encoder_factory(shape_in) + total_obs_channels += encoder.out_channels + if obs_reduction == 'stack' and encoder.out_channels != embed_dim: + encoder = nn.Sequential( + encoder, + nn.ReLU(), + nn.Linear(encoder.out_channels, embed_dim) + ) + self.image_encoders[name] = encoder + self.image_encoders = nn.ModuleDict(self.image_encoders) + + self.lowdim_encoders = {} + if do_lowdim and shape_meta['observation']['lowdim'] is not None: + for name, shape in shape_meta['observation']['lowdim'].items(): + encoder = lowdim_encoder_factory(shape) + total_obs_channels += encoder.out_channels + if obs_reduction == 'stack' and encoder.out_channels != embed_dim: + encoder = nn.Sequential( + encoder, + nn.ReLU(), + nn.Linear(encoder.out_channels, embed_dim) + ) + self.lowdim_encoders[name] = encoder + self.lowdim_encoders = nn.ModuleDict(self.lowdim_encoders) + + if obs_reduction == 'cat': + self.obs_proj = nn.Linear(total_obs_channels, embed_dim) + else: self.obs_proj = None + + if self.use_augmentation: + self.aug = aug_factory(shape_meta=shape_meta) + + if shape_meta.task.type == "onehot": + self.task_encoder = nn.Embedding( + num_embeddings=shape_meta.task.n_tasks, + embedding_dim=embed_dim + ) + else: + self.task_encoder = nn.Linear(shape_meta.task.dim, embed_dim) + + self.device = device + + @abstractmethod + def compute_loss(self, data): + raise NotImplementedError('Implement in subclass') + + def get_optimizers(self): + decay, no_decay = TensorUtils.separate_no_decay(self) + optimizers = [ + self.optimizer_factory(params=decay), + self.optimizer_factory(params=no_decay, weight_decay=0.) + ] + return optimizers + + def get_schedulers(self, optimizers): + if self.scheduler_factory is None: + return [] + else: + return [self.scheduler_factory(optimizer=optimizer) for optimizer in optimizers] + + def preprocess_input(self, data, train_mode=True): + if train_mode and self.use_augmentation: + data = self.aug(data) + for key in self.image_encoders: + for obs_key in ('obs', 'next_obs'): + if obs_key in data: + x = TensorUtils.to_float(data[obs_key][key]) + x = x / 255. + x = torch.clip(x, 0, 1) + data[obs_key][key] = x + return data + + def obs_encode(self, data, hwc=False, obs_key='obs'): + ### 1. encode image + img_encodings, lowdim_encodings = [], [] + for img_name in self.image_encoders.keys(): + x = data[obs_key][img_name] + if hwc: + x = einops.rearrange(x, 'B T H W C -> B T C H W') + B, T, C, H, W = x.shape + e = self.image_encoders[img_name]( + x.reshape(B * T, C, H, W), + ) + e = e.view(B, T, *e.shape[1:]) + img_encodings.append(e) + + # 2. add proprio info + for lowdim_name in self.lowdim_encoders.keys(): + lowdim_encodings.append(self.lowdim_encoders[lowdim_name](data[obs_key][lowdim_name])) # add (B, T, H_extra) + + if self.obs_reduction == 'cat': + encoded = img_encodings + lowdim_encodings + encoded = torch.cat(encoded, -1) # (B, T, H_all) + if self.obs_proj is not None: + obs_emb = self.obs_proj(encoded) + elif self.obs_reduction == 'stack': + encoded = img_encodings + lowdim_encodings + encoded = torch.stack(encoded, dim=2) + obs_emb = encoded + elif self.obs_reduction == 'none': + return img_encodings, lowdim_encodings + return obs_emb + + def reset(self): + return + + def get_task_emb(self, data): + if "task_emb" in data: + return self.task_encoder(data["task_emb"]) + else: + return self.task_encoder(data["task_id"]) + + def get_action(self, obs, task_id, task_emb=None): + self.eval() + for key, value in obs.items(): + if key in self.image_encoders: + value = ObsUtils.process_frame(value, channel_dim=3) + obs[key] = torch.tensor(value) + batch = {} + batch["obs"] = obs + if task_emb is not None: + batch["task_emb"] = task_emb + else: + batch["task_id"] = torch.tensor([task_id], dtype=torch.long) + batch = map_tensor_to_device(batch, self.device) + with torch.no_grad(): + action = self.sample_actions(batch) + return action + + def preprocess_dataset(self, dataset, use_tqdm=True): + return + + @abstractmethod + def sample_actions(self, obs): + raise NotImplementedError('Implement in subclass') + + +class ChunkPolicy(Policy): + ''' + Super class for policies which predict chunks of actions + ''' + def __init__(self, + action_horizon, + **kwargs): + super().__init__(**kwargs) + + self.action_horizon = action_horizon + self.action_queue = None + + + def reset(self): + self.action_queue = deque(maxlen=self.action_horizon) + + def get_action(self, obs, task_id, task_emb=None): + assert self.action_queue is not None, "you need to call policy.reset() before getting actions" + + self.eval() + if len(self.action_queue) == 0: + for key, value in obs.items(): + if key in self.image_encoders: + value = ObsUtils.process_frame(value, channel_dim=3) + elif key in self.lowdim_encoders: + value = TensorUtils.to_float(value) # from double to float + obs[key] = torch.tensor(value) + batch = {} + batch["obs"] = obs + if task_emb is not None: + batch["task_emb"] = task_emb + else: + batch["task_id"] = torch.tensor([task_id], dtype=torch.long) + batch = map_tensor_to_device(batch, self.device) + with torch.no_grad(): + actions = self.sample_actions(batch) + self.action_queue.extend(actions[:self.action_horizon]) + action = self.action_queue.popleft() + return action + + @abstractmethod + def sample_actions(self, obs): + raise NotImplementedError('Implement in subclass') + diff --git a/quest/algos/baseline_modules/act_utils/detr_vae.py b/quest/algos/baseline_modules/act_utils/detr_vae.py new file mode 100644 index 0000000..f0cd463 --- /dev/null +++ b/quest/algos/baseline_modules/act_utils/detr_vae.py @@ -0,0 +1,232 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR model and criterion classes. +""" +import torch +from torch import nn +from torch.autograd import Variable +from quest.algos.baseline_modules.act_utils.transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer + +import numpy as np + + +def reparametrize(mu, logvar): + std = logvar.div(2).exp() + eps = Variable(std.data.new(std.size()).normal_()) + return mu + std * eps + + +def get_sinusoid_encoding_table(n_position, d_hid): + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +class DETRVAE(nn.Module): + """ This is the DETR module that performs object detection """ + def __init__(self, + # backbones, + transformer, + encoder, + state_dim, + proprio_dim, + num_queries, + shape_meta, + encoder_input=('lowdim',) + ): + """ Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.num_queries = num_queries + self.transformer = transformer + self.encoder = encoder + self.encoder_input = encoder_input + hidden_dim = transformer.d_model + self.action_head = nn.Linear(hidden_dim, state_dim) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear(state_dim, hidden_dim) # project action to embedding + self.encoder_joint_proj = nn.Linear(proprio_dim, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var + + n_inputs = 0 + obs_meta = shape_meta['observation'] + if 'lowdim' in encoder_input: + if obs_meta['lowdim'] is not None: + n_inputs += len(obs_meta['lowdim']) + if 'perception' in encoder_input: + if obs_meta['rgb'] is not None: + n_inputs += len(obs_meta['rgb']) + + self.n_encoder_inputs = n_inputs + self.register_buffer('pos_table', get_sinusoid_encoding_table(1 + num_queries + n_inputs, hidden_dim)) # [CLS], qpos, a_seq + + # decoder extra parameters + self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding + n_decoder_inputs = 0 + if obs_meta['lowdim'] is not None: + n_decoder_inputs += len(obs_meta['lowdim']) + if obs_meta['rgb'] is not None: + n_decoder_inputs += len(obs_meta['rgb']) + self.additional_pos_embed = nn.Embedding(2 + n_decoder_inputs, hidden_dim) # learned position embedding for proprio and latent + + def forward(self, lowdim_encodings, perception_encodings, task_emb, actions=None, is_pad=None): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + is_training = actions is not None # train or val + bs = lowdim_encodings.shape[0] + ### Obtain latent z from action sequence + if is_training: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim) + + encoder_input = [cls_embed] + if 'lowdim' in self.encoder_input: + encoder_input.append(lowdim_encodings) + if 'perception' in self.encoder_input: + encoder_input.append(perception_encodings) + encoder_input.append(action_embed) + encoder_input = torch.cat(encoder_input, axis=1) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + cls_joint_is_pad = torch.full((bs, 1 + self.n_encoder_inputs), False).to(lowdim_encodings.device) # False: not a padding + is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + + encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + mu = latent_info[:, :self.latent_dim] + logvar = latent_info[:, self.latent_dim:] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + else: + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(lowdim_encodings.device) + latent_input = self.latent_out_proj(latent_sample) + + task_emb = task_emb.unsqueeze(0) + latent_input = latent_input.unsqueeze(0) + lowdim_encodings = lowdim_encodings.permute(1, 0, 2) + perception_encodings = perception_encodings.permute(1, 0, 2) + try: + transformer_input = torch.cat([task_emb, latent_input, lowdim_encodings, perception_encodings], dim=0) + except RuntimeError: + raise ValueError(f"something went wrong with the shapes: task_emb: {task_emb.shape}, latent_input: {latent_input.shape}, lowdim_encodings: {lowdim_encodings.shape}, perception_encodings: {perception_encodings.shape}") + + query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1) + input_pos_embed = self.additional_pos_embed.weight.unsqueeze(1).repeat(1, bs, 1) + + hs = self.transformer(transformer_input, None, query_embed, input_pos_embed)[0] + + a_hat = self.action_head(hs) + is_pad_hat = self.is_pad_head(hs) + return a_hat, is_pad_hat, [mu, logvar] + + def encode_actions(self, actions, qpos): + pass + + + +class CNNMLP(nn.Module): + def __init__(self, backbones, state_dim, camera_names): + """ Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.camera_names = camera_names + self.action_head = nn.Linear(1000, state_dim) # TODO add more + if backbones is not None: + self.backbones = nn.ModuleList(backbones) + backbone_down_projs = [] + for backbone in backbones: + down_proj = nn.Sequential( + nn.Conv2d(backbone.num_channels, 128, kernel_size=5), + nn.Conv2d(128, 64, kernel_size=5), + nn.Conv2d(64, 32, kernel_size=5) + ) + backbone_down_projs.append(down_proj) + self.backbone_down_projs = nn.ModuleList(backbone_down_projs) + + mlp_in_dim = 768 * len(backbones) + 14 + self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2) + else: + raise NotImplementedError + + def forward(self, qpos, image, env_state, actions=None): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + is_training = actions is not None # train or val + bs, _ = qpos.shape + # Image observation features and position embeddings + all_cam_features = [] + for cam_id, cam_name in enumerate(self.camera_names): + features, pos = self.backbones[cam_id](image[:, cam_id]) + features = features[0] # take the last layer feature + pos = pos[0] # not used + all_cam_features.append(self.backbone_down_projs[cam_id](features)) + # flatten everything + flattened_features = [] + for cam_feature in all_cam_features: + flattened_features.append(cam_feature.reshape([bs, -1])) + flattened_features = torch.cat(flattened_features, axis=1) # 768 each + features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14 + a_hat = self.mlp(features) + return a_hat + + +def mlp(input_dim, hidden_dim, output_dim, hidden_depth): + if hidden_depth == 0: + mods = [nn.Linear(input_dim, output_dim)] + else: + mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] + for i in range(hidden_depth - 1): + mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] + mods.append(nn.Linear(hidden_dim, output_dim)) + trunk = nn.Sequential(*mods) + return trunk + + +def build_encoder(d_model=256, nheads=8, dim_feedforward=2048, enc_layers=4, pre_norm=False, dropout=0.1): + activation = "relu" + encoder_layer = TransformerEncoderLayer(d_model, nheads, dim_feedforward, + dropout, activation, pre_norm) + encoder_norm = nn.LayerNorm(d_model) if pre_norm else None + encoder = TransformerEncoder(encoder_layer, enc_layers, encoder_norm) + return encoder + diff --git a/quest/algos/baseline_modules/act_utils/misc.py b/quest/algos/baseline_modules/act_utils/misc.py new file mode 100644 index 0000000..dfa9fb5 --- /dev/null +++ b/quest/algos/baseline_modules/act_utils/misc.py @@ -0,0 +1,468 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from packaging import version +from typing import Optional, List + +import torch +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +if version.parse(torchvision.__version__) < version.parse('0.7'): + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if version.parse(torchvision.__version__) < version.parse('0.7'): + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/quest/algos/baseline_modules/act_utils/position_encoding.py b/quest/algos/baseline_modules/act_utils/position_encoding.py new file mode 100644 index 0000000..5bec025 --- /dev/null +++ b/quest/algos/baseline_modules/act_utils/position_encoding.py @@ -0,0 +1,93 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +from quest.algos.baseline_modules.act_utils.misc import NestedTensor + +# import IPython +# e = IPython.embed + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor): + x = tensor + # mask = tensor_list.mask + # assert mask is not None + # not_mask = ~mask + + not_mask = torch.ones_like(x[0, [0]]) + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(position_embedding, hidden_dim): + N_steps = hidden_dim // 2 + if position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {position_embedding}") + + return position_embedding diff --git a/quest/algos/baseline_modules/act_utils/transformer.py b/quest/algos/baseline_modules/act_utils/transformer.py new file mode 100644 index 0000000..c70c348 --- /dev/null +++ b/quest/algos/baseline_modules/act_utils/transformer.py @@ -0,0 +1,316 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +# import IPython +# e = IPython.embed + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed):#, latent_input=None, proprio_input=None, task_emb=None, additional_pos_embed=None): + # TODO flatten only when input has H and W + # if len(src.shape) == 4: # has H and W + # # flatten NxCxHxW to HWxNxC + # src = src.flatten(2).permute(2, 0, 1) + # pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) + # # mask = mask.flatten(1) + + # additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim + # pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) + + # addition_input = torch.stack([task_emb, latent_input, proprio_input], axis=0) + # src = torch.cat([addition_input, src], axis=0) + # else: + # assert len(src.shape) == 3 + # # flatten NxHWxC to HWxNxC + # bs, hw, c = src.shape + # src = src.permute(1, 0, 2) + # pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1) + # query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + # hw, bs, c = src.shape + # query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed) + hs = hs.transpose(1, 2) + return hs + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(hidden_dim=512, dropout=0.1, nheads=8, + dim_feedforward=2048, enc_layers=6, dec_layers=6, + pre_norm=False): + return Transformer( + d_model=hidden_dim, + dropout=dropout, + nhead=nheads, + dim_feedforward=dim_feedforward, + num_encoder_layers=enc_layers, + num_decoder_layers=dec_layers, + normalize_before=pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/quest/algos/baseline_modules/bc_transformer_modules.py b/quest/algos/baseline_modules/bc_transformer_modules.py new file mode 100644 index 0000000..9a2b8a2 --- /dev/null +++ b/quest/algos/baseline_modules/bc_transformer_modules.py @@ -0,0 +1,327 @@ +import math +import numpy as np +from torch import nn +import torch +import torchvision +import torch.nn.functional as F +import torch.distributions as D + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +import quest.utils.tensor_utils as TensorUtils + + +############################################################################### +# +# Building blocks for transformers +# +############################################################################### + + +class Norm(nn.Module): + def __init__(self, dim): + super().__init__() + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + return self.norm(x) + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, head_output_size=64, dropout=0.0): + super().__init__() + + self.num_heads = num_heads + # \sqrt{d_{k}} + self.att_scale = head_output_size ** (-0.5) + self.qkv = nn.Linear(dim, num_heads * head_output_size * 3, bias=False) + + # We need to combine the output from all heads + self.output_layer = nn.Sequential( + nn.Linear(num_heads * head_output_size, dim), nn.Dropout(dropout) + ) + + def forward(self, x, mask=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = (qkv[0], qkv[1], qkv[2]) + + # q.dot(k.transpose) + attn = (q @ k.transpose(-2, -1)) * self.att_scale + if mask is not None: + mask = mask.bool() + if len(mask.shape) == 2: # (B, N) + attn = attn.masked_fill(~mask[:, None, None, :], float("-inf")) + elif len(mask.shape) == 3 and mask.shape[0] == 1: # (1, N, N) + attn = attn.masked_fill(~mask[None, :, :, :], float("-inf")) + elif ( + len(mask.shape) == 3 + ): # Consider the case where each batch has different causal mask, typically useful for MAE implementation + attn = attn.masked_fill( + ~mask[:, None, :, :].repeat(1, self.num_heads, 1, 1), float("-inf") + ) + else: + raise Exception("mask shape is not correct for attention") + attn = attn.softmax(dim=-1) + self.att_weights = attn + + # (..., num_heads, seq_len, head_output_size) + out = rearrange(torch.matmul(attn, v), "b h n d -> b n (h d)") + return self.output_layer(out) + + +class TransformerFeedForwardNN(nn.Module): + def __init__(self, dim, hidden_dim, dropout=0.0): + super().__init__() + # Remember the residual connection + layers = [ + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout), + ] + self.net = nn.Sequential(*layers) + + def forward(self, x): + return self.net(x) + + +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + +class SinusoidalPositionEncoding(nn.Module): + def __init__(self, input_size, inv_freq_factor=10, factor_ratio=None): + super().__init__() + self.input_size = input_size + self.inv_freq_factor = inv_freq_factor + channels = self.input_size + channels = int(np.ceil(channels / 2) * 2) + + inv_freq = 1.0 / ( + self.inv_freq_factor ** (torch.arange(0, channels, 2).float() / channels) + ) + self.channels = channels + self.register_buffer("inv_freq", inv_freq) + + if factor_ratio is None: + self.factor = 1.0 + else: + factor = nn.Parameter(torch.ones(1) * factor_ratio) + self.register_parameter("factor", factor) + + def forward(self, x): + pos_x = torch.arange(x.shape[1], device=x.device).type(self.inv_freq.type()) + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1) + return emb_x * self.factor + + def output_shape(self, input_shape): + return input_shape + + def output_size(self, input_size): + return input_size + + +############################################################################### +# +# Transformer Decoder (we only use transformer decoder for our policies) +# +############################################################################### + + +class TransformerDecoder(nn.Module): + def __init__( + self, + input_size, + num_layers, + num_heads, + head_output_size, + mlp_hidden_size, + dropout, + **kwargs + ): + super().__init__() + + self.layers = nn.ModuleList([]) + self.drop_path = DropPath(dropout) if dropout > 0.0 else nn.Identity() + + self.attention_output = {} + + for _ in range(num_layers): + self.layers.append( + nn.ModuleList( + [ + Norm(input_size), + Attention( + input_size, + num_heads=num_heads, + head_output_size=head_output_size, + dropout=dropout, + ), + Norm(input_size), + TransformerFeedForwardNN( + input_size, mlp_hidden_size, dropout=dropout + ), + ] + ) + ) + + self.attention_output[_] = None + self.seq_len = None + self.num_elements = None + self.mask = None + + def compute_mask(self, input_shape): + # input_shape = (:, seq_len, num_elements) + if ( + (self.num_elements is None) + or (self.seq_len is None) + or (self.num_elements != input_shape[2]) + or (self.seq_len != input_shape[1]) + ): + + self.seq_len = input_shape[1] + self.num_elements = input_shape[2] + self.original_mask = ( + torch.triu(torch.ones(self.seq_len, self.seq_len)) + - torch.eye(self.seq_len, self.seq_len) + ).to(self.device) + self.mask = 1 - self.original_mask.repeat_interleave( + self.num_elements, dim=-1 + ).repeat_interleave(self.num_elements, dim=-2).unsqueeze(0) + # (1, N, N), N = seq_len * num_elements + + def forward(self, x, mask=None): + for layer_idx, (att_norm, att, ff_norm, ff) in enumerate(self.layers): + if mask is not None: + x = x + drop_path(att(att_norm(x), mask)) + elif self.mask is not None: + x = x + drop_path(att(att_norm(x), self.mask)) + else: # no masking, just use full attention + x = x + drop_path(att(att_norm(x))) + + if not self.training: + self.attention_output[layer_idx] = att.att_weights + x = x + self.drop_path(ff(ff_norm(x))) + return x + + @property + def device(self): + return next(self.parameters()).device + + +class GMMHead(nn.Module): + def __init__( + self, + # network_kwargs + input_size, + output_size, + hidden_size=1024, + num_layers=2, + min_std=0.0001, + num_modes=5, + activation="softplus", + low_eval_noise=False, + # loss_kwargs + loss_coef=1.0, + ): + super().__init__() + self.num_modes = num_modes + self.output_size = output_size + self.min_std = min_std + + if num_layers > 0: + sizes = [input_size] + [hidden_size] * num_layers + layers = [] + for i in range(num_layers): + layers += [nn.Linear(sizes[i], sizes[i + 1]), nn.ReLU()] + layers += [nn.Linear(sizes[-2], sizes[-1])] + self.share = nn.Sequential(*layers) + else: + self.share = nn.Identity() + + self.mean_layer = nn.Linear(hidden_size, output_size * num_modes) + self.logstd_layer = nn.Linear(hidden_size, output_size * num_modes) + self.logits_layer = nn.Linear(hidden_size, num_modes) + + self.low_eval_noise = low_eval_noise + self.loss_coef = loss_coef + + if activation == "softplus": + self.actv = F.softplus + else: + self.actv = torch.exp + + def forward_fn(self, x): + # x: (B, input_size) + share = self.share(x) + means = self.mean_layer(share).view(-1, self.num_modes, self.output_size) + means = torch.tanh(means) + logits = self.logits_layer(share) + + if self.training or not self.low_eval_noise: + logstds = self.logstd_layer(share).view( + -1, self.num_modes, self.output_size + ) + stds = self.actv(logstds) + self.min_std + else: + stds = torch.ones_like(means) * 1e-4 + return means, stds, logits + + def forward(self, x): + if x.ndim == 3: + means, scales, logits = TensorUtils.time_distributed(x, self.forward_fn) + elif x.ndim < 3: + means, scales, logits = self.forward_fn(x) + + compo = D.Normal(loc=means, scale=scales) + compo = D.Independent(compo, 1) + mix = D.Categorical(logits=logits) + gmm = D.MixtureSameFamily( + mixture_distribution=mix, component_distribution=compo + ) + return gmm + + def loss_fn(self, gmm, target, reduction="mean"): + log_probs = gmm.log_prob(target) + loss = -log_probs + if reduction == "mean": + return loss.mean() * self.loss_coef + elif reduction == "none": + return loss * self.loss_coef + elif reduction == "sum": + return loss.sum() * self.loss_coef + else: + raise NotImplementedError diff --git a/quest/algos/baseline_modules/diffusion_modules.py b/quest/algos/baseline_modules/diffusion_modules.py new file mode 100644 index 0000000..8e4a4f8 --- /dev/null +++ b/quest/algos/baseline_modules/diffusion_modules.py @@ -0,0 +1,246 @@ +import math +from torch import nn +import torch +from typing import Union + + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Downsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + +class Upsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Conv1dBlock(nn.Module): + ''' + Conv1d --> GroupNorm --> Mish + ''' + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + nn.GroupNorm(n_groups, out_channels), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + + +class ConditionalResidualBlock1D(nn.Module): + def __init__(self, + in_channels, + out_channels, + cond_dim, + kernel_size=3, + n_groups=8): + super().__init__() + + self.blocks = nn.ModuleList([ + Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), + Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), + ]) + + # FiLM modulation https://arxiv.org/abs/1709.07871 + # predicts per-channel scale and bias + cond_channels = out_channels * 2 + self.out_channels = out_channels + self.cond_encoder = nn.Sequential( + nn.Mish(), + nn.Linear(cond_dim, cond_channels), + nn.Unflatten(-1, (-1, 1)) + ) + + # make sure dimensions compatible + self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \ + if in_channels != out_channels else nn.Identity() + + def forward(self, x, cond): + ''' + x : [ batch_size x in_channels x horizon ] + cond : [ batch_size x cond_dim] + + returns: + out : [ batch_size x out_channels x horizon ] + ''' + out = self.blocks[0](x) + embed = self.cond_encoder(cond) + + embed = embed.reshape( + embed.shape[0], 2, self.out_channels, 1) + scale = embed[:,0,...] + bias = embed[:,1,...] + out = scale * out + bias + + out = self.blocks[1](out) + out = out + self.residual_conv(x) + return out + + +class ConditionalUnet1D(nn.Module): + def __init__(self, + input_dim, + global_cond_dim, + diffusion_step_embed_dim=256, + down_dims=[256,512,1024], + kernel_size=3, + n_groups=8 + ): + """ + input_dim: Dim of actions. + global_cond_dim: Dim of global conditioning applied with FiLM + in addition to diffusion step embedding. This is usually obs_horizon * obs_dim + diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k + down_dims: Channel size for each UNet level. + The length of this array determines numebr of levels. + kernel_size: Conv kernel size + n_groups: Number of groups for GroupNorm + """ + + super().__init__() + all_dims = [input_dim] + list(down_dims) + start_dim = down_dims[0] + + dsed = diffusion_step_embed_dim + diffusion_step_encoder = nn.Sequential( + SinusoidalPosEmb(dsed), + nn.Linear(dsed, dsed * 4), + nn.Mish(), + nn.Linear(dsed * 4, dsed), + ) + cond_dim = dsed + global_cond_dim + + in_out = list(zip(all_dims[:-1], all_dims[1:])) + mid_dim = all_dims[-1] + self.mid_modules = nn.ModuleList([ + ConditionalResidualBlock1D( + mid_dim, mid_dim, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups + ), + ConditionalResidualBlock1D( + mid_dim, mid_dim, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups + ), + ]) + + down_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (len(in_out) - 1) + down_modules.append(nn.ModuleList([ + ConditionalResidualBlock1D( + dim_in, dim_out, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups), + ConditionalResidualBlock1D( + dim_out, dim_out, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups), + Downsample1d(dim_out) if not is_last else nn.Identity() + ])) + + up_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): + is_last = ind >= (len(in_out) - 1) + up_modules.append(nn.ModuleList([ + ConditionalResidualBlock1D( + dim_out*2, dim_in, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups), + ConditionalResidualBlock1D( + dim_in, dim_in, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups), + Upsample1d(dim_in) if not is_last else nn.Identity() + ])) + + final_conv = nn.Sequential( + Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), + nn.Conv1d(start_dim, input_dim, 1), + ) + + self.diffusion_step_encoder = diffusion_step_encoder + self.up_modules = up_modules + self.down_modules = down_modules + self.final_conv = final_conv + + print("number of parameters: {:e}".format( + sum(p.numel() for p in self.parameters())) + ) + + def forward(self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + global_cond=None): + """ + x: (B,T,input_dim) + timestep: (B,) or int, diffusion step + global_cond: (B,global_cond_dim) + output: (B,T,input_dim) + """ + # (B,T,C) + sample = sample.moveaxis(-1,-2) + # (B,C,T) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + global_feature = self.diffusion_step_encoder(timesteps) + + if global_cond is not None: + global_feature = torch.cat([ + global_feature, global_cond + ], axis=-1) + + x = sample + h = [] + for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + h.append(x) + x = downsample(x) + + for mid_module in self.mid_modules: + x = mid_module(x, global_feature) + + for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): + x = torch.cat((x, h.pop()), dim=1) + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + x = upsample(x) + + x = self.final_conv(x) + + # (B,C,T) + x = x.moveaxis(-1,-2) + # (B,T,C) + return x \ No newline at end of file diff --git a/quest/algos/baseline_modules/vector_quantize_pytorch_bet/residual_vq.py b/quest/algos/baseline_modules/vector_quantize_pytorch_bet/residual_vq.py new file mode 100644 index 0000000..fc94698 --- /dev/null +++ b/quest/algos/baseline_modules/vector_quantize_pytorch_bet/residual_vq.py @@ -0,0 +1,295 @@ +from math import ceil +from functools import partial +from itertools import zip_longest +from random import randrange + +import torch +from torch import nn +import torch.nn.functional as F +from quest.algos.baseline_modules.vector_quantize_pytorch_bet.vector_quantize_pytorch_bet import VectorQuantize + +from einops import rearrange, repeat, pack, unpack + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def round_up_multiple(num, mult): + return ceil(num / mult) * mult + +# main class + +class ResidualVQ(nn.Module): + """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """ + def __init__( + self, + *, + dim, + num_quantizers, + codebook_dim = None, + shared_codebook = False, + heads = 1, + quantize_dropout = False, + quantize_dropout_cutoff_index = 0, + quantize_dropout_multiple_of = 1, + accept_image_fmap = False, + **kwargs + ): + super().__init__() + assert heads == 1, 'residual vq is not compatible with multi-headed codes' + codebook_dim = default(codebook_dim, dim) + codebook_input_dim = codebook_dim * heads + + requires_projection = codebook_input_dim != dim + self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + + self.num_quantizers = num_quantizers + + self.accept_image_fmap = accept_image_fmap + self.layers = nn.ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)]) + + self.quantize_dropout = quantize_dropout and num_quantizers > 1 + + assert quantize_dropout_cutoff_index >= 0 + + self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index + self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4 + + if not shared_codebook: + return + + first_vq, *rest_vq = self.layers + codebook = first_vq._codebook + + for vq in rest_vq: + vq._codebook = codebook + + @property + def codebooks(self): + codebooks = [layer._codebook.embed for layer in self.layers] + codebooks = torch.stack(codebooks, dim = 0) + codebooks = rearrange(codebooks, 'q 1 c d -> q c d') + return codebooks + + def get_codes_from_indices(self, indices): + + batch, quantize_dim = indices.shape[0], indices.shape[-1] + + # may also receive indices in the shape of 'b h w q' (accept_image_fmap) + + indices, ps = pack([indices], 'b * q') + + # because of quantize dropout, one can pass in indices that are coarse + # and the network should be able to reconstruct + + if quantize_dim < self.num_quantizers: + assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations' + indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1) + + # get ready for gathering + + codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch) + gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1]) + + # take care of quantizer dropout + + mask = gather_indices == -1. + gather_indices = gather_indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later + + all_codes = codebooks.gather(2, gather_indices) # gather all codes + + # mask out any codes that were dropout-ed + + all_codes = all_codes.masked_fill(mask, 0.) + + # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension) + + all_codes, = unpack(all_codes, ps, 'q b * d') + + return all_codes + + def draw_logits_forward(self, encoding_logits): + # encoding_indices : dim1 = batch_size dim2 = 4 (number of groups) dim3 = vq dict size (header) + encoding_logits = encoding_logits.to(self.device) + bs = encoding_logits.shape[0] + quantized = torch.zeros((bs,self.codebooks.shape[-1])).to(self.device) + for q in range(encoding_logits.shape[1]): + quantized += torch.matmul(encoding_logits[:, q], self.codebooks[q]).to(self.device) + return quantized + + def forward( + self, + x, + indices = None, + return_all_codes = False, + sample_codebook_temp = None + ): + num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device + + x = self.project_in(x) + + assert not (self.accept_image_fmap and exists(indices)) + + quantized_out = 0. + residual = x + + all_losses = [] + all_indices = [] + + if return_loss: + assert not torch.any(indices == -1), 'some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss' + ce_losses = [] + + should_quantize_dropout = self.training and self.quantize_dropout and not return_loss + + # sample a layer index at which to dropout further residual quantization + # also prepare null indices and loss + + if should_quantize_dropout: + rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant) + + if quant_dropout_multiple_of != 1: + rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1 + + null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2]) + null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long) + null_loss = torch.full((1,), 0., device = device, dtype = x.dtype) + + # go through the layers + + for quantizer_index, layer in enumerate(self.layers): + + if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index: + all_indices.append(null_indices) + all_losses.append(null_loss) + continue + + layer_indices = None + if return_loss: + layer_indices = indices[..., quantizer_index] + + quantized, *rest = layer(residual, indices = layer_indices, sample_codebook_temp = sample_codebook_temp) + + residual = residual - quantized.detach() + quantized_out = quantized_out + quantized + + if return_loss: + ce_loss = rest[0] + ce_losses.append(ce_loss) + continue + + embed_indices, loss = rest + + all_indices.append(embed_indices) + all_losses.append(loss) + + # project out, if needed + + quantized_out = self.project_out(quantized_out) + + # whether to early return the cross entropy loss + + if return_loss: + return quantized_out, sum(ce_losses) + + # stack all losses and indices + + all_losses, all_indices = map(partial(torch.stack, dim = -1), (all_losses, all_indices)) + + ret = (quantized_out, all_indices, all_losses) + + if return_all_codes: + # whether to return all codes from all codebooks across layers + all_codes = self.get_codes_from_indices(all_indices) + + # will return all codes in shape (quantizer, batch, sequence length, codebook dimension) + ret = (*ret, all_codes) + + return ret + +# grouped residual vq + +class GroupedResidualVQ(nn.Module): + def __init__( + self, + *, + dim, + groups = 1, + accept_image_fmap = False, + **kwargs + ): + super().__init__() + self.dim = dim + self.groups = groups + assert (dim % groups) == 0 + dim_per_group = dim // groups + + self.accept_image_fmap = accept_image_fmap + + self.rvqs = nn.ModuleList([]) + + for _ in range(groups): + self.rvqs.append(ResidualVQ( + dim = dim_per_group, + accept_image_fmap = accept_image_fmap, + **kwargs + )) + + @property + def codebooks(self): + return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs)) + + def get_codes_from_indices(self, indices): + codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices)) + return torch.stack(codes) + + def forward( + self, + x, + indices = None, + return_all_codes = False, + sample_codebook_temp = None + ): + shape = x.shape + split_dim = 1 if self.accept_image_fmap else -1 + assert shape[split_dim] == self.dim + + # split the feature dimension into groups + + x = x.chunk(self.groups, dim = split_dim) + + indices = default(indices, tuple()) + return_ce_loss = len(indices) > 0 + assert len(indices) == 0 or len(indices) == self.groups + + forward_kwargs = dict( + return_all_codes = return_all_codes, + sample_codebook_temp = sample_codebook_temp + ) + + # invoke residual vq on each group + + out = tuple(rvq(chunk, indices = chunk_indices, **forward_kwargs) for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices)) + out = tuple(zip(*out)) + + # if returning cross entropy loss to rvq codebooks + + if return_ce_loss: + quantized, ce_losses = out + return torch.cat(quantized, dim = split_dim), sum(ce_losses) + + # otherwise, get all the zipped outputs and combine them + + quantized, all_indices, commit_losses, *maybe_all_codes = out + + quantized = torch.cat(quantized, dim = split_dim) + all_indices = torch.stack(all_indices) + commit_losses = torch.stack(commit_losses) + + ret = (quantized, all_indices, commit_losses, *maybe_all_codes) + return ret diff --git a/quest/algos/baseline_modules/vector_quantize_pytorch_bet/vector_quantize_pytorch_bet.py b/quest/algos/baseline_modules/vector_quantize_pytorch_bet/vector_quantize_pytorch_bet.py new file mode 100644 index 0000000..e6cd989 --- /dev/null +++ b/quest/algos/baseline_modules/vector_quantize_pytorch_bet/vector_quantize_pytorch_bet.py @@ -0,0 +1,1050 @@ +from functools import partial + +import torch +from torch import nn, einsum +import torch.nn.functional as F +import torch.distributed as distributed +from torch.optim import Optimizer +from torch.cuda.amp import autocast + +from einops import rearrange, repeat, reduce, pack, unpack + +from typing import Callable + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def noop(*args, **kwargs): + pass + +def identity(t): + return t + +def l2norm(t): + return F.normalize(t, p = 2, dim = -1) + +def cdist(x, y): + x2 = reduce(x ** 2, 'b n d -> b n', 'sum') + y2 = reduce(y ** 2, 'b n d -> b n', 'sum') + xy = einsum('b i d, b j d -> b i j', x, y) * -2 + return (rearrange(x2, 'b i -> b i 1') + rearrange(y2, 'b j -> b 1 j') + xy).sqrt() + +def log(t, eps = 1e-20): + return torch.log(t.clamp(min = eps)) + +def ema_inplace(old, new, decay): + is_mps = str(old.device).startswith('mps:') + + if not is_mps: + old.lerp_(new, 1 - decay) + else: + old.mul_(decay).add_(new * (1 - decay)) + +def pack_one(t, pattern): + return pack([t], pattern) + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + +def uniform_init(*shape): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + +def gumbel_noise(t): + noise = torch.zeros_like(t).uniform_(0, 1) + return -log(-log(noise)) + +def gumbel_sample( + logits, + temperature = 1., + stochastic = False, + straight_through = False, + reinmax = False, + dim = -1, + training = True +): + dtype, size = logits.dtype, logits.shape[dim] + + if training and stochastic and temperature > 0: + sampling_logits = (logits / temperature) + gumbel_noise(logits) + else: + sampling_logits = logits + + ind = sampling_logits.argmax(dim = dim) + one_hot = F.one_hot(ind, size).type(dtype) + + assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax' + + if not straight_through or temperature <= 0. or not training: + return ind, one_hot + + # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612 + # algorithm 2 + + if reinmax: + π0 = logits.softmax(dim = dim) + π1 = (one_hot + (logits / temperature).softmax(dim = dim)) / 2 + π1 = ((log(π1) - logits).detach() + logits).softmax(dim = 1) + π2 = 2 * π1 - 0.5 * π0 + one_hot = π2 - π2.detach() + one_hot + else: + π1 = (logits / temperature).softmax(dim = dim) + one_hot = one_hot + π1 - π1.detach() + + return ind, one_hot + +def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1): + denom = x.sum(dim = dim, keepdim = True) + return (x + eps) / (denom + n_categories * eps) + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + if num_samples >= num: + indices = torch.randperm(num_samples, device = device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device = device) + + return samples[indices] + +def batched_sample_vectors(samples, num): + return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0) + +def pad_shape(shape, size, dim = 0): + return [size if i == dim else s for i, s in enumerate(shape)] + +def sample_multinomial(total_count, probs): + device = probs.device + probs = probs.cpu() + + total_count = probs.new_full((), total_count) + remainder = probs.new_ones(()) + sample = torch.empty_like(probs, dtype = torch.long) + + for i, p in enumerate(probs): + s = torch.binomial(total_count, p / remainder) + sample[i] = s + total_count -= s + remainder -= p + + return sample.to(device) + +def all_gather_sizes(x, dim): + size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device) + all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())] + distributed.all_gather(all_sizes, size) + return torch.stack(all_sizes) + +def all_gather_variably_sized(x, sizes, dim = 0): + rank = distributed.get_rank() + all_x = [] + + for i, size in enumerate(sizes): + t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim)) + distributed.broadcast(t, src = i, async_op = True) + all_x.append(t) + + distributed.barrier() + return all_x + +def sample_vectors_distributed(local_samples, num): + local_samples = rearrange(local_samples, '1 ... -> ...') + + rank = distributed.get_rank() + all_num_samples = all_gather_sizes(local_samples, dim = 0) + + if rank == 0: + samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) + else: + samples_per_rank = torch.empty_like(all_num_samples) + + distributed.broadcast(samples_per_rank, src = 0) + samples_per_rank = samples_per_rank.tolist() + + local_samples = sample_vectors(local_samples, samples_per_rank[rank]) + all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0) + out = torch.cat(all_samples, dim = 0) + + return rearrange(out, '... -> 1 ...') + +def batched_bincount(x, *, minlength): + batch, dtype, device = x.shape[0], x.dtype, x.device + target = torch.zeros(batch, minlength, dtype = dtype, device = device) + values = torch.ones_like(x) + target.scatter_add_(-1, x, values) + return target + +def kmeans( + samples, + num_clusters, + num_iters = 10, + use_cosine_sim = False, + sample_fn = batched_sample_vectors, + all_reduce_fn = noop +): + num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device + + means = sample_fn(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ rearrange(means, 'h n d -> h d n') + else: + dists = -torch.cdist(samples, means, p = 2) + + buckets = torch.argmax(dists, dim = -1) + bins = batched_bincount(buckets, minlength = num_clusters) + all_reduce_fn(bins) + + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype) + + new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples) + new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1') + all_reduce_fn(new_means) + + if use_cosine_sim: + new_means = l2norm(new_means) + + means = torch.where( + rearrange(zero_mask, '... -> ... 1'), + means, + new_means + ) + + return means, bins + +def batched_embedding(indices, embeds): + batch, dim = indices.shape[1], embeds.shape[-1] + indices = repeat(indices, 'h b n -> h b n d', d = dim) + embeds = repeat(embeds, 'h c d -> h b c d', b = batch) + return embeds.gather(2, indices) + +# regularization losses + +def orthogonal_loss_fn(t): + # eq (2) from https://arxiv.org/abs/2112.00384 + h, n = t.shape[:2] + normed_codes = l2norm(t) + cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes) + return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n) + +# distance types + +class EuclideanCodebook(nn.Module): + def __init__( + self, + dim, + codebook_size, + num_codebooks = 1, + kmeans_init = False, + kmeans_iters = 10, + sync_kmeans = True, + decay = 0.8, + eps = 1e-5, + threshold_ema_dead_code = 2, + reset_cluster_size = None, + use_ddp = False, + learnable_codebook = False, + gumbel_sample = gumbel_sample, + sample_codebook_temp = 1., + ema_update = True, + affine_param = False, + sync_affine_param = False, + affine_param_batch_decay = 0.99, + affine_param_codebook_decay = 0.9 + ): + super().__init__() + self.transform_input = identity + + self.decay = decay + self.ema_update = ema_update + + init_fn = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(num_codebooks, codebook_size, dim) + + self.codebook_size = codebook_size + self.num_codebooks = num_codebooks + + self.kmeans_iters = kmeans_iters + self.eps = eps + self.threshold_ema_dead_code = threshold_ema_dead_code + self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code) + + assert callable(gumbel_sample) + self.gumbel_sample = gumbel_sample + self.sample_codebook_temp = sample_codebook_temp + + assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now' + + self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors + self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop + self.all_reduce_fn = distributed.all_reduce if use_ddp else noop + + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size)) + self.register_buffer('embed_avg', embed.clone()) + + self.learnable_codebook = learnable_codebook + if learnable_codebook: + self.embed = nn.Parameter(embed) + else: + self.register_buffer('embed', embed) + + # affine related params + + self.affine_param = affine_param + self.sync_affine_param = sync_affine_param + + if not affine_param: + return + + self.affine_param_batch_decay = affine_param_batch_decay + self.affine_param_codebook_decay = affine_param_codebook_decay + + self.register_buffer('batch_mean', None) + self.register_buffer('batch_variance', None) + + self.register_buffer('codebook_mean_needs_init', torch.Tensor([True])) + self.register_buffer('codebook_mean', torch.empty(num_codebooks, 1, dim)) + self.register_buffer('codebook_variance_needs_init', torch.Tensor([True])) + self.register_buffer('codebook_variance', torch.empty(num_codebooks, 1, dim)) + + @torch.jit.ignore + def init_embed_(self, data, mask = None): + if self.initted: + return + + if exists(mask): + c = data.shape[0] + data = rearrange(data[mask], '(c n) d -> c n d', c = c) + + embed, cluster_size = kmeans( + data, + self.codebook_size, + self.kmeans_iters, + sample_fn = self.sample_fn, + all_reduce_fn = self.kmeans_all_reduce_fn + ) + + embed_sum = embed * rearrange(cluster_size, '... -> ... 1') + + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed_sum) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + @torch.jit.ignore + def update_with_decay(self, buffer_name, new_value, decay): + old_value = getattr(self, buffer_name) + + needs_init = getattr(self, buffer_name + "_needs_init", False) + + if needs_init: + self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False])) + + if not exists(old_value) or needs_init: + self.register_buffer(buffer_name, new_value.detach()) + + return + + value = old_value * decay + new_value.detach() * (1 - decay) + self.register_buffer(buffer_name, value) + + @torch.jit.ignore + def update_affine(self, data, embed, mask = None): + assert self.affine_param + + var_fn = partial(torch.var, unbiased = False) + + # calculate codebook mean and variance + + embed = rearrange(embed, 'h ... d -> h (...) d') + + if self.training: + self.update_with_decay('codebook_mean', reduce(embed, 'h n d -> h 1 d', 'mean'), self.affine_param_codebook_decay) + self.update_with_decay('codebook_variance', reduce(embed, 'h n d -> h 1 d', var_fn), self.affine_param_codebook_decay) + + # prepare batch data, which depends on whether it has masking + + data = rearrange(data, 'h ... d -> h (...) d') + + if exists(mask): + c = data.shape[0] + data = rearrange(data[mask], '(c n) d -> c n d', c = c) + + # calculate batch mean and variance + + if not self.sync_affine_param: + self.update_with_decay('batch_mean', reduce(data, 'h n d -> h 1 d', 'mean'), self.affine_param_batch_decay) + self.update_with_decay('batch_variance', reduce(data, 'h n d -> h 1 d', var_fn), self.affine_param_batch_decay) + return + + num_vectors, device, dtype = data.shape[-2], data.device, data.dtype + + # number of vectors, for denominator + + num_vectors = torch.tensor([num_vectors], device = device, dtype = dtype) + distributed.all_reduce(num_vectors) + + # calculate distributed mean + + batch_sum = reduce(data, 'h n d -> h 1 d', 'sum') + distributed.all_reduce(batch_sum) + batch_mean = batch_sum / num_vectors + + self.update_with_decay('batch_mean', batch_mean, self.affine_param_batch_decay) + + # calculate distributed variance + + variance_numer = reduce((data - batch_mean) ** 2, 'h n d -> h 1 d', 'sum') + distributed.all_reduce(variance_numer) + batch_variance = variance_numer / num_vectors + + self.update_with_decay('batch_variance', batch_variance, self.affine_param_batch_decay) + + def replace(self, batch_samples, batch_mask): + for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))): + if not torch.any(mask): + continue + + sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item()) + sampled = rearrange(sampled, '1 ... -> ...') + + self.embed.data[ind][mask] = sampled + + self.cluster_size.data[ind][mask] = self.reset_cluster_size + self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d') + self.replace(batch_samples, batch_mask = expired_codes) + + @autocast(enabled = False) + def forward( + self, + x, + sample_codebook_temp = None, + mask = None, + freeze_codebook = False + ): + needs_codebook_dim = x.ndim < 4 + sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp) + + x = x.float() + + if needs_codebook_dim: + x = rearrange(x, '... -> 1 ...') + + dtype = x.dtype + flatten, ps = pack_one(x, 'h * d') + + if exists(mask): + mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1])) + + self.init_embed_(flatten, mask = mask) + + if self.affine_param: + self.update_affine(flatten, self.embed, mask = mask) + + embed = self.embed if self.learnable_codebook else self.embed.detach() + + if self.affine_param: + codebook_std = self.codebook_variance.clamp(min = 1e-5).sqrt() + batch_std = self.batch_variance.clamp(min = 1e-5).sqrt() + embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean + + dist = -cdist(flatten, embed) + + embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training) + + embed_ind = unpack_one(embed_ind, ps, 'h *') + + if self.training: + unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c') + quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed) + else: + quantize = batched_embedding(embed_ind, embed) + + if self.training and self.ema_update and not freeze_codebook: + + if self.affine_param: + flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean + + if exists(mask): + embed_onehot[~mask] = 0. + + cluster_size = embed_onehot.sum(dim = 1) + + self.all_reduce_fn(cluster_size) + ema_inplace(self.cluster_size.data, cluster_size, self.decay) + + embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot) + self.all_reduce_fn(embed_sum.contiguous()) + ema_inplace(self.embed_avg.data, embed_sum, self.decay) + + cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True) + + embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1') + self.embed.data.copy_(embed_normalized) + self.expire_codes_(x) + + if needs_codebook_dim: + quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind)) + + dist = unpack_one(dist, ps, 'h * d') + + return quantize, embed_ind, dist + +class CosineSimCodebook(nn.Module): + def __init__( + self, + dim, + codebook_size, + num_codebooks = 1, + kmeans_init = False, + kmeans_iters = 10, + sync_kmeans = True, + decay = 0.8, + eps = 1e-5, + threshold_ema_dead_code = 2, + reset_cluster_size = None, + use_ddp = False, + learnable_codebook = False, + gumbel_sample = gumbel_sample, + sample_codebook_temp = 1., + ema_update = True + ): + super().__init__() + self.transform_input = l2norm + + self.ema_update = ema_update + self.decay = decay + + if not kmeans_init: + embed = l2norm(uniform_init(num_codebooks, codebook_size, dim)) + else: + embed = torch.zeros(num_codebooks, codebook_size, dim) + + self.codebook_size = codebook_size + self.num_codebooks = num_codebooks + + self.kmeans_iters = kmeans_iters + self.eps = eps + self.threshold_ema_dead_code = threshold_ema_dead_code + self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code) + + assert callable(gumbel_sample) + self.gumbel_sample = gumbel_sample + self.sample_codebook_temp = sample_codebook_temp + + self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors + self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop + self.all_reduce_fn = distributed.all_reduce if use_ddp else noop + + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size)) + self.register_buffer('embed_avg', embed.clone()) + + self.learnable_codebook = learnable_codebook + if learnable_codebook: + self.embed = nn.Parameter(embed) + else: + self.register_buffer('embed', embed) + + @torch.jit.ignore + def init_embed_(self, data, mask = None): + if self.initted: + return + + if exists(mask): + c = data.shape[0] + data = rearrange(data[mask], '(c n) d -> c n d', c = c) + + embed, cluster_size = kmeans( + data, + self.codebook_size, + self.kmeans_iters, + use_cosine_sim = True, + sample_fn = self.sample_fn, + all_reduce_fn = self.kmeans_all_reduce_fn + ) + + embed_sum = embed * rearrange(cluster_size, '... -> ... 1') + + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed_sum) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + def replace(self, batch_samples, batch_mask): + batch_samples = l2norm(batch_samples) + + for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))): + if not torch.any(mask): + continue + + sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item()) + sampled = rearrange(sampled, '1 ... -> ...') + + self.embed.data[ind][mask] = sampled + self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size + self.cluster_size.data[ind][mask] = self.reset_cluster_size + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d') + self.replace(batch_samples, batch_mask = expired_codes) + + @autocast(enabled = False) + def forward( + self, + x, + sample_codebook_temp = None, + mask = None, + freeze_codebook = False + ): + needs_codebook_dim = x.ndim < 4 + sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp) + + x = x.float() + + if needs_codebook_dim: + x = rearrange(x, '... -> 1 ...') + + dtype = x.dtype + + flatten, ps = pack_one(x, 'h * d') + + if exists(mask): + mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1])) + + self.init_embed_(flatten, mask = mask) + + embed = self.embed if self.learnable_codebook else self.embed.detach() + + dist = einsum('h n d, h c d -> h n c', flatten, embed) + + embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training) + embed_ind = unpack_one(embed_ind, ps, 'h *') + + if self.training: + unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c') + quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed) + else: + quantize = batched_embedding(embed_ind, embed) + + if self.training and self.ema_update and not freeze_codebook: + if exists(mask): + embed_onehot[~mask] = 0. + + bins = embed_onehot.sum(dim = 1) + self.all_reduce_fn(bins) + + ema_inplace(self.cluster_size.data, bins, self.decay) + + embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot) + self.all_reduce_fn(embed_sum.contiguous()) + ema_inplace(self.embed_avg.data, embed_sum, self.decay) + + cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True) + + embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1') + embed_normalized = l2norm(embed_normalized) + + self.embed.data.copy_(l2norm(embed_normalized)) + self.expire_codes_(x) + + if needs_codebook_dim: + quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind)) + + dist = unpack_one(dist, ps, 'h * d') + return quantize, embed_ind, dist + +# main class + +class VectorQuantize(nn.Module): + def __init__( + self, + dim, + codebook_size, + codebook_dim = None, + heads = 1, + separate_codebook_per_head = False, + decay = 0.8, + eps = 1e-5, + freeze_codebook = False, + kmeans_init = False, + kmeans_iters = 10, + sync_kmeans = True, + use_cosine_sim = False, + threshold_ema_dead_code = 0, + channel_last = True, + accept_image_fmap = False, + commitment_weight = 1., + commitment_use_cross_entropy_loss = False, + orthogonal_reg_weight = 0., + orthogonal_reg_active_codes_only = False, + orthogonal_reg_max_codes = None, + stochastic_sample_codes = False, + sample_codebook_temp = 1., + straight_through = False, + reinmax = False, # using reinmax for improved straight-through, assuming straight through helps at all + sync_codebook = None, + sync_affine_param = False, + ema_update = True, + learnable_codebook = False, + in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook + affine_param = False, + affine_param_batch_decay = 0.99, + affine_param_codebook_decay = 0.9, + sync_update_v = 0. # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf + ): + super().__init__() + self.dim = dim + self.heads = heads + self.separate_codebook_per_head = separate_codebook_per_head + + codebook_dim = default(codebook_dim, dim) + codebook_input_dim = codebook_dim * heads + + requires_projection = codebook_input_dim != dim + self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + + self.eps = eps + self.commitment_weight = commitment_weight + self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss + + self.learnable_codebook = learnable_codebook + + has_codebook_orthogonal_loss = orthogonal_reg_weight > 0 + self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss + self.orthogonal_reg_weight = orthogonal_reg_weight + self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only + self.orthogonal_reg_max_codes = orthogonal_reg_max_codes + + assert not (ema_update and learnable_codebook), 'learnable codebook not compatible with EMA update' + + assert 0 <= sync_update_v <= 1. + assert not (sync_update_v > 0. and not learnable_codebook), 'learnable codebook must be turned on' + + self.sync_update_v = sync_update_v + + codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook + + gumbel_sample_fn = partial( + gumbel_sample, + stochastic = stochastic_sample_codes, + reinmax = reinmax, + straight_through = straight_through + ) + + if not exists(sync_codebook): + sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1 + + codebook_kwargs = dict( + dim = codebook_dim, + num_codebooks = heads if separate_codebook_per_head else 1, + codebook_size = codebook_size, + kmeans_init = kmeans_init, + kmeans_iters = kmeans_iters, + sync_kmeans = sync_kmeans, + decay = decay, + eps = eps, + threshold_ema_dead_code = threshold_ema_dead_code, + use_ddp = sync_codebook, + learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook, + sample_codebook_temp = sample_codebook_temp, + gumbel_sample = gumbel_sample_fn, + ema_update = ema_update + ) + + if affine_param: + assert not use_cosine_sim, 'affine param is only compatible with euclidean codebook' + codebook_kwargs = dict( + **codebook_kwargs, + affine_param = True, + sync_affine_param = sync_affine_param, + affine_param_batch_decay = affine_param_batch_decay, + affine_param_codebook_decay = affine_param_codebook_decay, + ) + + self._codebook = codebook_class(**codebook_kwargs) + + self.in_place_codebook_optimizer = in_place_codebook_optimizer(self._codebook.parameters()) if exists(in_place_codebook_optimizer) else None + + self.codebook_size = codebook_size + + self.accept_image_fmap = accept_image_fmap + self.channel_last = channel_last + + @property + def codebook(self): + codebook = self._codebook.embed + + if self.separate_codebook_per_head: + return codebook + + return rearrange(codebook, '1 ... -> ...') + + @codebook.setter + def codebook(self, codes): + if not self.separate_codebook_per_head: + codes = rearrange(codes, '... -> 1 ...') + + self._codebook.embed.copy_(codes) + + def get_codes_from_indices(self, indices): + codebook = self.codebook + is_multiheaded = codebook.ndim > 2 + + if not is_multiheaded: + codes = codebook[indices] + return rearrange(codes, '... h d -> ... (h d)') + + indices, ps = pack_one(indices, 'b * h') + indices = rearrange(indices, 'b n h -> b h n') + + indices = repeat(indices, 'b h n -> b h n d', d = codebook.shape[-1]) + codebook = repeat(codebook, 'h n d -> b h n d', b = indices.shape[0]) + + codes = codebook.gather(2, indices) + codes = rearrange(codes, 'b h n d -> b n (h d)') + codes = unpack_one(codes, ps, 'b * d') + return codes + + def forward( + self, + x, + indices = None, + mask = None, + sample_codebook_temp = None, + freeze_codebook = False + ): + orig_input = x + + only_one = x.ndim == 2 + + if only_one: + assert not exists(mask) + x = rearrange(x, 'b d -> b 1 d') + + shape, device, heads, is_multiheaded, codebook_size, return_loss = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size, exists(indices) + + need_transpose = not self.channel_last and not self.accept_image_fmap + should_inplace_optimize = exists(self.in_place_codebook_optimizer) + + # rearrange inputs + + if self.accept_image_fmap: + height, width = x.shape[-2:] + x = rearrange(x, 'b c h w -> b (h w) c') + + if need_transpose: + x = rearrange(x, 'b d n -> b n d') + + # project input + + x = self.project_in(x) + + # handle multi-headed separate codebooks + + if is_multiheaded: + ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d' + x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads) + + # l2norm for cosine sim, otherwise identity + + x = self._codebook.transform_input(x) + + # codebook forward kwargs + + codebook_forward_kwargs = dict( + sample_codebook_temp = sample_codebook_temp, + mask = mask, + freeze_codebook = freeze_codebook + ) + + # quantize + + quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) + + # one step in-place update + + if should_inplace_optimize and self.training and not freeze_codebook: + + if exists(mask): + loss = F.mse_loss(quantize, x.detach(), reduction = 'none') + + loss_mask = mask + if is_multiheaded: + loss_mask = repeat(mask, 'b n -> c (b h) n', c = loss.shape[0], h = loss.shape[1] // mask.shape[0]) + + loss = loss[loss_mask].mean() + + else: + loss = F.mse_loss(quantize, x.detach()) + + loss.backward() + self.in_place_codebook_optimizer.step() + self.in_place_codebook_optimizer.zero_grad() + + # quantize again + + quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) + + if self.training: + # determine code to use for commitment loss + maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity + + commit_quantize = maybe_detach(quantize) + + # straight through + + quantize = x + (quantize - x).detach() + + if self.sync_update_v > 0.: + # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf + quantize = quantize + self.sync_update_v * (quantize - quantize.detach()) + + # function for calculating cross entropy loss to distance matrix + # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss + + def calculate_ce_loss(codes): + if not is_multiheaded: + dist_einops_eq = '1 b n l -> b l n' + elif self.separate_codebook_per_head: + dist_einops_eq = 'c b n l -> b l n c' + else: + dist_einops_eq = '1 (b h) n l -> b l n h' + + ce_loss = F.cross_entropy( + rearrange(distances, dist_einops_eq, b = shape[0]), + codes, + ignore_index = -1 + ) + + return ce_loss + + # if returning cross entropy loss on codes that were passed in + + if return_loss: + return quantize, calculate_ce_loss(indices) + + # transform embedding indices + + if is_multiheaded: + if self.separate_codebook_per_head: + embed_ind = rearrange(embed_ind, 'h b n -> b n h', h = heads) + else: + embed_ind = rearrange(embed_ind, '1 (b h) n -> b n h', h = heads) + + if self.accept_image_fmap: + embed_ind = rearrange(embed_ind, 'b (h w) ... -> b h w ...', h = height, w = width) + + if only_one: + embed_ind = rearrange(embed_ind, 'b 1 -> b') + + # aggregate loss + + loss = torch.tensor([0.], device = device, requires_grad = self.training) + + if self.training: + if self.commitment_weight > 0: + if self.commitment_use_cross_entropy_loss: + if exists(mask): + ce_loss_mask = mask + if is_multiheaded: + ce_loss_mask = repeat(ce_loss_mask, 'b n -> b n h', h = heads) + + embed_ind.masked_fill_(~ce_loss_mask, -1) + + commit_loss = calculate_ce_loss(embed_ind) + else: + if exists(mask): + # with variable lengthed sequences + commit_loss = F.mse_loss(commit_quantize, x, reduction = 'none') + + loss_mask = mask + if is_multiheaded: + loss_mask = repeat(loss_mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0]) + + commit_loss = commit_loss[loss_mask].mean() + else: + commit_loss = F.mse_loss(commit_quantize, x) + + loss = loss + commit_loss * self.commitment_weight + + if self.has_codebook_orthogonal_loss: + codebook = self._codebook.embed + + # only calculate orthogonal loss for the activated codes for this batch + + if self.orthogonal_reg_active_codes_only: + assert not (is_multiheaded and self.separate_codebook_per_head), 'orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet' + unique_code_ids = torch.unique(embed_ind) + codebook = codebook[:, unique_code_ids] + + num_codes = codebook.shape[-2] + + if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: + rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes] + codebook = codebook[:, rand_ids] + + orthogonal_reg_loss = orthogonal_loss_fn(codebook) + loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight + + # handle multi-headed quantized embeddings + + if is_multiheaded: + if self.separate_codebook_per_head: + quantize = rearrange(quantize, 'h b n d -> b n (h d)', h = heads) + else: + quantize = rearrange(quantize, '1 (b h) n d -> b n (h d)', h = heads) + + # project out + + quantize = self.project_out(quantize) + + # rearrange quantized embeddings + + if need_transpose: + quantize = rearrange(quantize, 'b n d -> b d n') + + if self.accept_image_fmap: + quantize = rearrange(quantize, 'b (h w) c -> b c h w', h = height, w = width) + + if only_one: + quantize = rearrange(quantize, 'b 1 d -> b d') + + # if masking, only return quantized for where mask has True + + if exists(mask): + quantize = torch.where( + rearrange(mask, '... -> ... 1'), + quantize, + orig_input + ) + + return quantize, embed_ind, loss diff --git a/quest/algos/baseline_modules/vq_behavior_transformer/gpt.py b/quest/algos/baseline_modules/vq_behavior_transformer/gpt.py new file mode 100644 index 0000000..16f8e46 --- /dev/null +++ b/quest/algos/baseline_modules/vq_behavior_transformer/gpt.py @@ -0,0 +1,291 @@ +""" +An adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch. +Original source: https://github.com/karpathy/nanoGPT + +Original License: +MIT License + +Copyright (c) 2022 Andrej Karpathy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Original comments: +Full definition of a GPT Language Model, all of it in this single file. +References: +1) the official GPT-2 TensorFlow implementation released by OpenAI: +https://github.com/openai/gpt-2/blob/master/src/model.py +2) huggingface/transformers PyTorch implementation: +https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py +""" + +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default) +def new_gelu(x): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). + Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 + """ + return ( + 0.5 + * x + * ( + 1.0 + + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))) + ) + ) + + +class CausalSelfAttention(nn.Module): + def __init__(self, n_embd, n_head, dropout, block_size): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(n_embd, 3 * n_embd) + # output projection + self.c_proj = nn.Linear(n_embd, n_embd) + # regularization + self.attn_dropout = nn.Dropout(dropout) + self.resid_dropout = nn.Dropout(dropout) + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "bias", + torch.tril(torch.ones(block_size, block_size)).view( + 1, 1, block_size, block_size + ), + ) + self.n_head = n_head + self.n_embd = n_embd + + def forward(self, x): + ( + B, + T, + C, + ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class MLP(nn.Module): + def __init__(self, n_embd, dropout): + super().__init__() + self.c_fc = nn.Linear(n_embd, 4 * n_embd) + self.c_proj = nn.Linear(4 * n_embd, n_embd) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.c_fc(x) + x = new_gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + + +class Block(nn.Module): + def __init__(self, n_embd, n_head, dropout, block_size): + super().__init__() + self.ln_1 = nn.LayerNorm(n_embd) + self.attn = CausalSelfAttention(n_embd, n_head, dropout, block_size) + self.ln_2 = nn.LayerNorm(n_embd) + self.mlp = MLP(n_embd, dropout) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +@dataclass +class GPTConfig: + block_size: int = 1024 + input_dim: int = 256 + output_dim: int = 256 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.1 + + +class GPT(nn.Module): + def __init__(self, + block_size: int = 1024, + input_dim: int = 256, + output_dim: int = 256, + n_layer: int = 12, + n_head: int = 12, + n_embd: int = 768, + dropout: float = 0.1): + super().__init__() + assert input_dim is not None + assert output_dim is not None + assert block_size is not None + self.block_size = block_size + self.output_dim = output_dim + self.n_embd = n_embd + # self.config = config + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Linear(input_dim, n_embd), + wpe=nn.Embedding(block_size, n_embd), + drop=nn.Dropout(dropout), + h=nn.ModuleList([Block(n_embd, n_head, dropout, block_size) for _ in range(n_layer)]), + ln_f=nn.LayerNorm(n_embd), + ) + ) + self.lm_head = nn.Linear(n_embd, output_dim, bias=False) + # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper + self.apply(self._init_weights) + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * n_layer) + ) + + # report number of parameters + n_params = sum(p.numel() for p in self.parameters()) + print("number of parameters: %.2fM" % (n_params / 1e6,)) + + def forward(self, input, targets=None): + device = input.device + b, t, d = input.size() + assert ( + t <= self.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.block_size}" + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze( + 0 + ) # shape (1, t) + + # forward the GPT model itself + tok_emb = self.transformer.wte( + input + ) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe( + pos + ) # position embeddings of shape (1, t, n_embd) + x = self.transformer.drop(tok_emb + pos_emb) + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + logits = self.lm_head(x) + return logits + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + torch.nn.init.zeros_(module.bias) + torch.nn.init.ones_(module.weight) + + def crop_block_size(self, block_size): + assert block_size <= self.block_size + self.block_size = block_size + self.transformer.wpe.weight = nn.Parameter( + self.transformer.wpe.weight[:block_size] + ) + for block in self.transformer.h: + block.attn.bias = block.attn.bias[:, :, :block_size, :block_size] + + def configure_optimizers(self, weight_decay, learning_rate, betas): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear,) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name + if pn.endswith("bias"): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, ( + "parameters %s made it into both decay/no_decay sets!" + % (str(inter_params),) + ) + assert len(param_dict.keys() - union_params) == 0, ( + "parameters %s were not separated into either decay/no_decay set!" + % (str(param_dict.keys() - union_params),) + ) + + # create the pytorch optimizer object + optim_groups = [ + { + "params": [param_dict[pn] for pn in sorted(list(decay))], + "weight_decay": weight_decay, + }, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0, + }, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) + return optimizer diff --git a/quest/algos/baseline_modules/vq_behavior_transformer/utils.py b/quest/algos/baseline_modules/vq_behavior_transformer/utils.py new file mode 100644 index 0000000..a8e49a4 --- /dev/null +++ b/quest/algos/baseline_modules/vq_behavior_transformer/utils.py @@ -0,0 +1,124 @@ +import torch +from typing import Callable, List, Optional + + +class MLP(torch.nn.Sequential): + """This block implements the multi-layer perceptron (MLP) module. + Adapted for backward compatibility from the torchvision library: + https://pytorch.org/vision/0.14/generated/torchvision.ops.MLP.html + + LICENSE: + + From PyTorch: + + Copyright (c) 2016- Facebook, Inc (Adam Paszke) + Copyright (c) 2014- Facebook, Inc (Soumith Chintala) + Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) + Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) + Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) + Copyright (c) 2011-2013 NYU (Clement Farabet) + Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) + Copyright (c) 2006 Idiap Research Institute (Samy Bengio) + Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + + From Caffe2: + + Copyright (c) 2016-present, Facebook Inc. All rights reserved. + + All contributions by Facebook: + Copyright (c) 2016 Facebook Inc. + + All contributions by Google: + Copyright (c) 2015 Google Inc. + All rights reserved. + + All contributions by Yangqing Jia: + Copyright (c) 2015 Yangqing Jia + All rights reserved. + + All contributions by Kakao Brain: + Copyright 2019-2020 Kakao Brain + + All contributions by Cruise LLC: + Copyright (c) 2022 Cruise LLC. + All rights reserved. + + All contributions from Caffe: + Copyright(c) 2013, 2014, 2015, the respective contributors + All rights reserved. + + All other contributions: + Copyright(c) 2015, 2016 the respective contributors + All rights reserved. + + Caffe2 uses a copyright model similar to Caffe: each contributor holds + copyright over their contributions to Caffe2. The project versioning records + all such contribution and copyright details. If a contributor wants to further + mark their specific copyright on a particular contribution, they should + indicate their copyright solely in the commit message of the change when it is + committed. + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. + + + Args: + in_channels (int): Number of channels of the input + hidden_channels (List[int]): List of the hidden channel dimensions + norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None`` + activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` + inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place. + Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer. + bias (bool): Whether to use bias in the linear layer. Default ``True`` + dropout (float): The probability for the dropout layer. Default: 0.0 + """ + + def __init__( + self, + in_channels: int, + hidden_channels: List[int], + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + inplace: Optional[bool] = None, + bias: bool = True, + dropout: float = 0.0, + ): + params = {} if inplace is None else {"inplace": inplace} + + layers = [] + in_dim = in_channels + for hidden_dim in hidden_channels[:-1]: + layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias)) + layers.append(activation_layer(**params)) + layers.append(torch.nn.Dropout(dropout, **params)) + in_dim = hidden_dim + + layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias)) + layers.append(torch.nn.Dropout(dropout, **params)) + + super().__init__(*layers) diff --git a/quest/algos/baseline_modules/vq_behavior_transformer/vqvae.py b/quest/algos/baseline_modules/vq_behavior_transformer/vqvae.py new file mode 100644 index 0000000..7c8a0c5 --- /dev/null +++ b/quest/algos/baseline_modules/vq_behavior_transformer/vqvae.py @@ -0,0 +1,233 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import jit +import numpy as np +import einops +from quest.algos.baseline_modules.vector_quantize_pytorch_bet.residual_vq import ResidualVQ + +class EncoderMLP(nn.Module): + def __init__( + self, + input_dim, + output_dim=16, + hidden_dim=128, + layer_num=1, + last_activation=None, + ): + super(EncoderMLP, self).__init__() + layers = [] + + layers.append(nn.Linear(input_dim, hidden_dim)) + layers.append(nn.ReLU()) + for _ in range(layer_num): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + layers.append(nn.ReLU()) + + self.encoder = nn.Sequential(*layers) + self.fc = nn.Linear(hidden_dim, output_dim) + + if last_activation is not None: + self.last_layer = last_activation + else: + self.last_layer = None + self.apply(weights_init_encoder) + + def forward(self, x): + h = self.encoder(x) + state = self.fc(h) + if self.last_layer: + state = self.last_layer(state) + return state + + +class VqVae(nn.Module): + def __init__( + self, + obs_dim=60, + input_dim_h=1, # length of action chunk + input_dim_w=7, # action dim + n_latent_dims=512, + vqvae_n_embed=16, + vqvae_groups=2, + hidden_dim=128, + num_layers=1, + device="cuda", + encoder_loss_multiplier=1.0, + act_scale=1.0, + ): + super().__init__() + self.n_latent_dims = n_latent_dims + self.input_dim_h = input_dim_h + self.input_dim_w = input_dim_w + self.rep_dim = self.n_latent_dims + self.vqvae_n_embed = vqvae_n_embed + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.vqvae_lr = 1e-3 + self.vqvae_groups = vqvae_groups + self.device = device + self.encoder_loss_multiplier = encoder_loss_multiplier + self.act_scale = act_scale + + discrete_cfg = {"groups": self.vqvae_groups, "n_embed": self.vqvae_n_embed} + + self.vq_layer = ResidualVQ( + dim=self.n_latent_dims, + num_quantizers=discrete_cfg["groups"], + codebook_size=self.vqvae_n_embed, + ).to(self.device) + self.embedding_dim = self.n_latent_dims + + self.vq_layer.device = device + + if self.input_dim_h == 1: + self.encoder = EncoderMLP( + input_dim=input_dim_w, hidden_dim=self.hidden_dim, layer_num=self.num_layers, output_dim=n_latent_dims + ).to(self.device) + self.decoder = EncoderMLP( + input_dim=n_latent_dims, hidden_dim=self.hidden_dim, layer_num=self.num_layers, output_dim=input_dim_w + ).to(self.device) + else: + self.encoder = EncoderMLP( + input_dim=input_dim_w * self.input_dim_h, hidden_dim=self.hidden_dim, layer_num=self.num_layers, output_dim=n_latent_dims + ).to(self.device) + self.decoder = EncoderMLP( + input_dim=n_latent_dims, hidden_dim=self.hidden_dim, layer_num=self.num_layers, output_dim=input_dim_w * self.input_dim_h + ).to(self.device) + + # params = ( + # list(self.encoder.parameters()) + # + list(self.decoder.parameters()) + # + list(self.vq_layer.parameters()) + # ) + # self.vqvae_optimizer = torch.optim.Adam( + # params, lr=self.vqvae_lr, weight_decay=0.0001 + # ) + + # if load_dir is not None: + # try: + # state_dict = torch.load(load_dir) + # except RuntimeError: + # state_dict = torch.load(load_dir, map_location=torch.device("cpu")) + # self.load_state_dict(state_dict) + + # if eval: + # self.vq_layer.eval() + # else: + # self.vq_layer.train() + + def draw_logits_forward(self, encoding_logits): + z_embed = self.vq_layer.draw_logits_forward(encoding_logits) + return z_embed + + def draw_code_forward(self, encoding_indices): + with torch.no_grad(): + z_embed = self.vq_layer.get_codes_from_indices(encoding_indices) + z_embed = z_embed.sum(dim=0) + return z_embed + + def get_action_from_latent(self, latent): + output = self.decoder(latent) * self.act_scale + if self.input_dim_h == 1: + return einops.rearrange(output, "N (T A) -> N T A", A=self.input_dim_w) + else: + return einops.rearrange(output, "N (T A) -> N T A", A=self.input_dim_w) + + def preprocess(self, state): + if not torch.is_tensor(state): + state = get_tensor(state, self.device) + if self.input_dim_h == 1: + state = state.squeeze(-2) # state.squeeze(-1) + else: + state = einops.rearrange(state, "N T A -> N (T A)") + return state.to(self.device) + + def get_code(self, state, required_recon=False): + state = state / self.act_scale + state = self.preprocess(state) + with torch.no_grad(): + state_rep = self.encoder(state) + state_rep_shape = state_rep.shape[:-1] + state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1)) + state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat) + state_vq = state_rep_flat.view(*state_rep_shape, -1) + vq_code = vq_code.view(*state_rep_shape, -1) + vq_loss_state = torch.sum(vq_loss_state) + if required_recon: + recon_state = self.decoder(state_vq) * self.act_scale + recon_state_ae = self.decoder(state_rep) * self.act_scale + if self.input_dim_h == 1: + return state_vq, vq_code, recon_state, recon_state_ae + else: + return ( + state_vq, + vq_code, + torch.swapaxes(recon_state, -2, -1), + torch.swapaxes(recon_state_ae, -2, -1), + ) + else: + # econ_from_code = self.draw_code_forward(vq_code) + return state_vq, vq_code + + def forward(self, state): + state = state / self.act_scale + state = self.preprocess(state) + state_rep = self.encoder(state) + state_rep_shape = state_rep.shape[:-1] + state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1)) + state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat) + state_vq = state_rep_flat.view(*state_rep_shape, -1) + vq_code = vq_code.view(*state_rep_shape, -1) + vq_loss_state = torch.sum(vq_loss_state) + + dec_out = self.decoder(state_vq) + encoder_loss = (state - dec_out).abs().mean() + rep_loss = encoder_loss * self.encoder_loss_multiplier + (vq_loss_state * 5) + pp = len(torch.unique(vq_code))/1024 + + return dec_out, rep_loss, encoder_loss.clone().detach(), vq_loss_state.clone().detach(), pp + + + # def state_dict(self): + # return { + # "encoder": self.encoder.state_dict(), + # "decoder": self.decoder.state_dict(), + # "vq_embedding": self.vq_layer.state_dict(), + # } + + # def load_state_dict(self, state_dict): + # self.encoder.load_state_dict(state_dict["encoder"]) + # self.decoder.load_state_dict(state_dict["decoder"]) + # self.vq_layer.load_state_dict(state_dict["vq_embedding"]) + # self.vq_layer.eval() + + +def weights_init_encoder(m): + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + m.bias.data.fill_(0.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + assert m.weight.size(2) == m.weight.size(3) + m.weight.data.fill_(0.0) + m.bias.data.fill_(0.0) + mid = m.weight.size(2) // 2 + gain = nn.init.calculate_gain("relu") + nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) + + +# def var(tensor): +# return tensor.to(device) + + +def get_tensor(z, device): + if z is None: + return None + if z[0].dtype == np.dtype("O"): + return None + if len(z.shape) == 1: + return torch.FloatTensor(z.copy()).to(device).unsqueeze(0) + # return torch.from_numpy(z.copy()).float().to(device).unsqueeze(0) + else: + return torch.FloatTensor(z.copy()).to(device) + # return torch.from_numpy(z.copy()).float().to(device) \ No newline at end of file diff --git a/quest/algos/bc_transformer.py b/quest/algos/bc_transformer.py new file mode 100644 index 0000000..074ac99 --- /dev/null +++ b/quest/algos/bc_transformer.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import quest.utils.tensor_utils as TensorUtils +from quest.utils.utils import map_tensor_to_device +import quest.utils.obs_utils as ObsUtils +from quest.algos.base import Policy + +class BCTransformerPolicy(Policy): + def __init__( + self, + transformer_model, + policy_head, + positional_encoding, + loss_reduction, + **kwargs + ): + super().__init__(**kwargs) + self.temporal_transformer = transformer_model.to(self.device) + self.policy_head = policy_head.to(self.device) + self.temporal_position_encoding_fn = positional_encoding.to(self.device) + self.reduction = loss_reduction + + def temporal_encode(self, x): + pos_emb = self.temporal_position_encoding_fn(x) + x = x + pos_emb.unsqueeze(1) # (B, T, num_modality, E) + sh = x.shape + self.temporal_transformer.compute_mask(x.shape) + + x = TensorUtils.join_dimensions(x, 1, 2) # (B, T*num_modality, E) + x = self.temporal_transformer(x) + x = x.reshape(*sh) + return x[:, :, 0] # (B, T, E) + + def spatial_encode(self, data): + obs_emb = self.obs_encode(data) # (B, T, num_mod, E) + text_emb = self.get_task_emb(data) # (B, E) + B, T, num_mod, E = obs_emb.shape + text_emb = text_emb.view(B, 1, 1, -1).expand(-1, T, 1, -1) + x = torch.cat([text_emb, obs_emb], dim=2) # (B, T, num_mod+1, E) + return x + + def forward(self, data): + x = self.spatial_encode(data) + x = self.temporal_encode(x) + dist = self.policy_head(x) + return dist + + def compute_loss(self, data): + data = self.preprocess_input(data, train_mode=True) + dist = self.forward(data) + loss = self.policy_head.loss_fn(dist, data["actions"], self.reduction) + info = { + 'loss': loss.item(), + } + return loss, info + + def sample_actions(self, batch): + batch = self.preprocess_input(batch, train_mode=False) + x = self.spatial_encode(batch) + x = self.temporal_encode(x) + dist = self.policy_head(x[:, -1]) + action = dist.sample().cpu().numpy() + return action + diff --git a/quest/algos/bet.py b/quest/algos/bet.py new file mode 100644 index 0000000..b8399c6 --- /dev/null +++ b/quest/algos/bet.py @@ -0,0 +1,367 @@ +import logging +from enum import Enum +from pathlib import Path +from typing import Dict, Optional, Tuple +from collections import deque + +import einops +import torch +import torch.nn as nn +import torch.nn.functional as F +from quest.algos.baseline_modules.vq_behavior_transformer.utils import MLP +import quest.utils.tensor_utils as TensorUtils + + +from quest.algos.base import ChunkPolicy + +class BehaviorTransformer(ChunkPolicy): + GOAL_SPEC = Enum("GOAL_SPEC", "concat stack unconditional") + + def __init__( + self, + autoencoder, + policy_prior, + stage, + loss_fn, + offset_loss_multiplier: float = 1.0e3, + secondary_code_multiplier: float = 0.5, + frame_stack=10, + skill_block_size=5, + sequentially_select=False, + **kwargs + ): + super().__init__(**kwargs) + + self.autoencoder = autoencoder + self.policy_prior = policy_prior + self.stage = stage + + self.frame_stack = frame_stack + self.skill_block_size = skill_block_size + self.sequentially_select = sequentially_select + self._cbet_method = self.GOAL_SPEC.concat + self._offset_loss_multiplier = offset_loss_multiplier + self._secondary_code_multiplier = secondary_code_multiplier + self._criterion = loss_fn + + # For now, we assume the number of clusters is given. + self._G = self.autoencoder.vqvae_groups # G(number of groups) + self._C = self.autoencoder.vqvae_n_embed # C(number of code integers) + self._D = self.autoencoder.embedding_dim # D(embedding dims) + + if self.sequentially_select: + print("use sequantial prediction for vq dictionary!") + self._map_to_cbet_preds_bin1 = MLP( + in_channels=policy_prior.output_dim, + hidden_channels=[512, 512, self._C], + ) + self._map_to_cbet_preds_bin2 = MLP( + in_channels=policy_prior.output_dim + self._C, + hidden_channels=[512, self._C], + ) + else: + self._map_to_cbet_preds_bin = MLP( + in_channels=policy_prior.output_dim, + hidden_channels=[1024, 1024, self._G * self._C], + ) + self._map_to_cbet_preds_offset = MLP( + in_channels=policy_prior.output_dim, + hidden_channels=[ + 1024, + 1024, + self._G * self._C * (self.shape_meta.action_dim * self.skill_block_size), + ], + ) + + def compute_loss(self, data): + if self.stage == 0: + return self.compute_autoencoder_loss(data) + elif self.stage == 1: + return self.compute_prior_loss(data) + elif self.stage == 2: + return self.compute_prior_loss(data) + + def compute_autoencoder_loss(self, data): + action_input = data["actions"][:, :self.skill_block_size, :] + pred, total_loss, l1_loss, codebook_loss, pp = self.autoencoder(action_input) + info = { + 'recon_loss': l1_loss.item(), + 'codebook_loss': codebook_loss.item(), + 'pp': pp} + return total_loss, info + + def compute_prior_loss(self, data): + data = self.preprocess_input(data) + + context = self.get_context(data) + predicted_action, decoded_action, sampled_centers, logit_info = self._predict(context) + action_seq = data['actions'] + n, total_w, act_dim = action_seq.shape + act_w = self.autoencoder.input_dim_h + obs_w = total_w + 1 - act_w + output_shape = (n, obs_w, act_w, act_dim) + output = torch.empty(output_shape, device=action_seq.device) + for i in range(obs_w): + output[:, i, :, :] = action_seq[:, i : i + act_w, :] + action_seq = einops.rearrange(output, "N T W A -> (N T) W A") + NT = action_seq.shape[0] + # First, we need to find the closest cluster center for each action. + state_vq, action_bins = self.autoencoder.get_code( + action_seq + ) # action_bins: NT, G + + # Now we can compute the loss. + if action_seq.ndim == 2: + action_seq = action_seq.unsqueeze(0) + + offset_loss = torch.nn.L1Loss()(action_seq, predicted_action) + + action_diff = F.mse_loss( + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[ + :, -1, 0, : + ], + einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[ + :, -1, 0, : + ], + ) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout + action_diff_tot = F.mse_loss( + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[ + :, -1, :, : + ], + einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[ + :, -1, :, : + ], + ) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout + action_diff_mean_res1 = ( + abs( + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[ + :, -1, 0, : + ] + - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=obs_w)[ + :, -1, 0, : + ] + ) + ).mean() + action_diff_mean_res2 = ( + abs( + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[ + :, -1, 0, : + ] + - einops.rearrange( + predicted_action, "(N T) W A -> N T W A", T=obs_w + )[:, -1, 0, :] + ) + ).mean() + action_diff_max = ( + abs( + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[ + :, -1, 0, : + ] + - einops.rearrange( + predicted_action, "(N T) W A -> N T W A", T=obs_w + )[:, -1, 0, :] + ) + ).max() + + if self.sequentially_select: + cbet_logits1, gpt_output = logit_info + cbet_loss1 = self._criterion( # F.cross_entropy + cbet_logits1[:, :], + action_bins[:, 0], + ) + cbet_logits2 = self._map_to_cbet_preds_bin2( + torch.cat( + (gpt_output, F.one_hot(action_bins[:, 0], num_classes=self._C)), + axis=1, + ) + ) + cbet_loss2 = self._criterion( # F.cross_entropy + cbet_logits2[:, :], + action_bins[:, 1], + ) + else: + cbet_logits = logit_info + cbet_loss1 = self._criterion( # F.cross_entropy + cbet_logits[:, 0, :], + action_bins[:, 0], + ) + cbet_loss2 = self._criterion( # F.cross_entropy + cbet_logits[:, 1, :], + action_bins[:, 1], + ) + cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self._secondary_code_multiplier + + equal_total_code_rate = ( + torch.sum( + ( + torch.sum((action_bins == sampled_centers).int(), axis=1) == self._G + ).int() + ) + / NT + ) + equal_single_code_rate = torch.sum( + (action_bins[:, 0] == sampled_centers[:, 0]).int() + ) / (NT) + equal_single_code_rate2 = torch.sum( + (action_bins[:, 1] == sampled_centers[:, 1]).int() + ) / (NT) + + loss = cbet_loss + self._offset_loss_multiplier * offset_loss + info = { + "classification_loss": cbet_loss.detach().cpu().item(), + "offset_loss": offset_loss.detach().cpu().item(), + "total_loss": loss.detach().cpu().item(), + "equal_total_code_rate": equal_total_code_rate.item(), + "equal_single_code_rate": equal_single_code_rate.item(), + "equal_single_code_rate2": equal_single_code_rate2.item(), + "action_diff": action_diff.detach().cpu().item(), + "action_diff_tot": action_diff_tot.detach().cpu().item(), + "action_diff_mean_res1": action_diff_mean_res1.detach().cpu().item(), + "action_diff_mean_res2": action_diff_mean_res2.detach().cpu().item(), + "action_diff_max": action_diff_max.detach().cpu().item(), + } + return loss, info + + def _predict( + self, + gpt_input): + + gpt_output = self.policy_prior(gpt_input) + + # there is one task embedding vector in the context so we slice it out here + gpt_output = gpt_output[:, 1:, :] + + gpt_output = einops.rearrange(gpt_output, "N T (G C) -> (N T) (G C)", G=self._G) + + if self.sequentially_select: + cbet_logits1 = self._map_to_cbet_preds_bin1(gpt_output) + cbet_offsets = self._map_to_cbet_preds_offset(gpt_output) + cbet_offsets = einops.rearrange( + cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self._G, C=self._C + ) + cbet_probs1 = torch.softmax(cbet_logits1, dim=-1) + NT, choices = cbet_probs1.shape + G = self._G + sampled_centers1 = einops.rearrange( + torch.multinomial(cbet_probs1.view(-1, choices), num_samples=1), + "(NT) 1 -> NT", + NT=NT, + ) + cbet_logits2 = self._map_to_cbet_preds_bin2( + torch.cat( + (gpt_output, F.one_hot(sampled_centers1, num_classes=self._C)), + axis=1, + ) + ) + cbet_probs2 = torch.softmax(cbet_logits2, dim=-1) + sampled_centers2 = einops.rearrange( + torch.multinomial(cbet_probs2.view(-1, choices), num_samples=1), + "(NT) 1 -> NT", + NT=NT, + ) + sampled_centers = torch.stack( + (sampled_centers1, sampled_centers2), axis=1 + ) # NT, G + else: + cbet_logits = self._map_to_cbet_preds_bin(gpt_output) + cbet_offsets = self._map_to_cbet_preds_offset(gpt_output) + cbet_logits = einops.rearrange( + cbet_logits, "(NT) (G C) -> (NT) G C", G=self._G + ) + cbet_offsets = einops.rearrange( + cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self._G, C=self._C + ) + cbet_probs = torch.softmax(cbet_logits, dim=-1) + NT, G, choices = cbet_probs.shape + sampled_centers = einops.rearrange( + torch.multinomial(cbet_probs.view(-1, choices), num_samples=1), + "(NT G) 1 -> NT G", + NT=NT, + ) + + indices = ( + torch.arange(NT).unsqueeze(1).cuda(), + torch.arange(self._G).unsqueeze(0).cuda(), + sampled_centers, + ) + # Use advanced indexing to sample the values + sampled_offsets = cbet_offsets[indices] # NT, G, W, A(?) or NT, G, A + + sampled_offsets = sampled_offsets.sum(dim=1) + centers = self.autoencoder.draw_code_forward(sampled_centers).view( + NT, -1, self._D + ) + return_decoder_input = einops.rearrange( + centers.clone().detach(), "NT G D -> NT (G D)" + ) + decoded_action = ( + self.autoencoder.get_action_from_latent(return_decoder_input) + .clone() + .detach() + ) # NT, A + sampled_offsets = einops.rearrange(sampled_offsets, "NT (W A) -> NT W A", W=self.autoencoder.input_dim_h) + predicted_action = decoded_action + sampled_offsets + + if self.sequentially_select: + return predicted_action, decoded_action, sampled_centers, (cbet_logits1, gpt_output) + return predicted_action, decoded_action, sampled_centers, cbet_logits + + def get_optimizers(self): + if self.stage == 0: + decay, no_decay = TensorUtils.separate_no_decay(self.autoencoder) + optimizers = [ + self.optimizer_factory(params=decay), + self.optimizer_factory(params=no_decay, weight_decay=0.) + ] + return optimizers + elif self.stage == 1: + decay, no_decay = TensorUtils.separate_no_decay(self, + name_blacklist=('autoencoder',)) + optimizers = [ + self.optimizer_factory(params=decay), + self.optimizer_factory(params=no_decay, weight_decay=0.) + ] + return optimizers + elif self.stage == 2: + decay, no_decay = TensorUtils.separate_no_decay(self, + name_blacklist=('autoencoder',)) + optimizers = [ + self.optimizer_factory(params=decay), + self.optimizer_factory(params=no_decay, weight_decay=0.) + ] + return optimizers + + def sample_actions(self, data): + data = self.preprocess_input(data, train_mode=False) + + context = self.get_context(data) + # breakpoint() + predicted_act, _, _, _ = self._predict(context) + + predicted_act = einops.rearrange(predicted_act, "(N T) W A -> N T W A", T=self.frame_stack)[:, -1, :, :] + predicted_act = predicted_act.permute(1,0,2) + return predicted_act.detach().cpu().numpy() + + def get_context(self, data): + obs_emb = self.obs_encode(data) + task_emb = self.get_task_emb(data).unsqueeze(1) + context = torch.cat([task_emb, obs_emb], dim=1) + return context + + +class FocalLoss(nn.Module): + def __init__(self, gamma: float = 0, size_average: bool = True): + super(FocalLoss, self).__init__() + self.gamma = gamma + self.size_average = size_average + + def forward(self, input, target): + logpt = F.log_softmax(input, dim=-1) + logpt = logpt.gather(1, target.view(-1, 1)).view(-1) + pt = logpt.exp() + + loss = -1 * (1 - pt) ** self.gamma * logpt + if self.size_average: + return loss.mean() + else: + return loss.sum() diff --git a/quest/algos/diffusion_policy.py b/quest/algos/diffusion_policy.py new file mode 100644 index 0000000..9f83d33 --- /dev/null +++ b/quest/algos/diffusion_policy.py @@ -0,0 +1,110 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from quest.algos.baseline_modules.diffusion_modules import ConditionalUnet1D +from diffusers.training_utils import EMAModel +from quest.algos.base import ChunkPolicy + +class DiffusionPolicy(ChunkPolicy): + def __init__( + self, + diffusion_model, + **kwargs + ): + super().__init__(**kwargs) + + self.diffusion_model = diffusion_model.to(self.device) + + def compute_loss(self, data): + data = self.preprocess_input(data, train_mode=True) + cond = self.get_cond(data) + loss = self.diffusion_model(cond, data["actions"]) + info = { + 'loss': loss.item(), + } + return loss, info + + def sample_actions(self, data): + data = self.preprocess_input(data, train_mode=False) + cond = self.get_cond(data) + actions = self.diffusion_model.get_action(cond) + actions = actions.permute(1,0,2) + return actions.detach().cpu().numpy() + + def get_cond(self, data): + obs_emb = self.obs_encode(data) + obs_emb = obs_emb.reshape(obs_emb.shape[0], -1) + lang_emb = self.get_task_emb(data) + cond = torch.cat([obs_emb, lang_emb], dim=-1) + return cond + + +class DiffusionModel(nn.Module): + def __init__(self, + noise_scheduler, + action_dim, + global_cond_dim, + diffusion_step_emb_dim, + down_dims, + ema_power, + skill_block_size, + diffusion_inf_steps, + device): + super().__init__() + self.device = device + net = ConditionalUnet1D( + input_dim=action_dim, + global_cond_dim=global_cond_dim, + diffusion_step_embed_dim=diffusion_step_emb_dim, + down_dims=down_dims, + ).to(self.device) + self.ema = EMAModel( + parameters=net.parameters(), + decay=ema_power) + self.net = net + self.noise_scheduler = noise_scheduler + self.action_dim = action_dim + self.skill_block_size = skill_block_size + self.diffusion_inf_steps = diffusion_inf_steps + + def forward(self, cond, actions): + timesteps = torch.randint( + 0, self.noise_scheduler.config.num_train_timesteps, + (cond.shape[0],), device=self.device + ).long() + noise = torch.randn(actions.shape, device=self.device) + # add noise to the clean actions according to the noise magnitude at each diffusion iteration + # (this is the forward diffusion process) + noisy_actions = self.noise_scheduler.add_noise( + actions, noise, timesteps) + # predict the noise residual + noise_pred = self.net( + noisy_actions, timesteps, global_cond=cond) + loss = F.mse_loss(noise_pred, noise) + return loss + + def get_action(self, cond): + nets = self.net + noisy_action = torch.randn( + (cond.shape[0], self.skill_block_size, self.action_dim), device=self.device) + naction = noisy_action + # init scheduler + self.noise_scheduler.set_timesteps(self.diffusion_inf_steps) + + for k in self.noise_scheduler.timesteps: + # predict noise + noise_pred = nets( + sample=naction, + timestep=k, + global_cond=cond + ) + # inverse diffusion step (remove noise) + naction = self.noise_scheduler.step( + model_output=noise_pred, + timestep=k, + sample=naction + ).prev_sample + return naction + + def ema_update(self): + self.ema.step(self.net.parameters()) \ No newline at end of file diff --git a/quest/algos/quest.py b/quest/algos/quest.py new file mode 100644 index 0000000..cc2976b --- /dev/null +++ b/quest/algos/quest.py @@ -0,0 +1,120 @@ +import torch +import torch.nn.functional as F +import numpy as np +import quest.utils.tensor_utils as TensorUtils +import itertools + +from quest.algos.base import ChunkPolicy + + +class QueST(ChunkPolicy): + def __init__(self, + autoencoder, + policy_prior, + stage, + loss_fn, + l1_loss_scale, + **kwargs + ): + super().__init__(**kwargs) + self.autoencoder = autoencoder + self.policy_prior = policy_prior + self.stage = stage + + self.start_token = self.policy_prior.start_token + self.l1_loss_scale = l1_loss_scale if stage == 2 else 0 + self.codebook_size = np.array(autoencoder.fsq_level).prod() + + self.loss = loss_fn + + def get_optimizers(self): + if self.stage == 0: + decay, no_decay = TensorUtils.separate_no_decay(self.autoencoder) + optimizers = [ + self.optimizer_factory(params=decay), + self.optimizer_factory(params=no_decay, weight_decay=0.) + ] + return optimizers + elif self.stage == 1: + decay, no_decay = TensorUtils.separate_no_decay(self, + name_blacklist=('autoencoder',)) + optimizers = [ + self.optimizer_factory(params=decay), + self.optimizer_factory(params=no_decay, weight_decay=0.) + ] + return optimizers + elif self.stage == 2: + decay, no_decay = TensorUtils.separate_no_decay(self, + name_blacklist=('autoencoder',)) + decoder_decay, decoder_no_decay = TensorUtils.separate_no_decay(self.autoencoder.decoder) + optimizers = [ + self.optimizer_factory(params=itertools.chain(decay, decoder_decay)), + self.optimizer_factory(params=itertools.chain(no_decay, decoder_no_decay), weight_decay=0.) + ] + return optimizers + + def get_context(self, data): + obs_emb = self.obs_encode(data) + task_emb = self.get_task_emb(data).unsqueeze(1) + context = torch.cat([task_emb, obs_emb], dim=1) + return context + + def compute_loss(self, data): + if self.stage == 0: + return self.compute_autoencoder_loss(data) + elif self.stage == 1: + return self.compute_prior_loss(data) + elif self.stage == 2: + return self.compute_prior_loss(data) + + def compute_autoencoder_loss(self, data): + pred, pp, pp_sample, aux_loss, _ = self.autoencoder(data["actions"]) + recon_loss = self.loss(pred, data["actions"]) + if self.autoencoder.vq_type == 'vq': + loss = recon_loss + aux_loss + else: + loss = recon_loss + + info = { + 'loss': loss.item(), + 'recon_loss': recon_loss.item(), + 'aux_loss': aux_loss.sum().item(), + 'pp': pp.item(), + 'pp_sample': pp_sample.item(), + } + return loss, info + + def compute_prior_loss(self, data): + data = self.preprocess_input(data, train_mode=True) + with torch.no_grad(): + indices = self.autoencoder.get_indices(data["actions"]).long() + context = self.get_context(data) + start_tokens = (torch.ones((context.shape[0], 1), device=self.device, dtype=torch.long) * self.start_token) + x = torch.cat([start_tokens, indices[:,:-1]], dim=1) + targets = indices.clone() + logits = self.policy_prior(x, context) + prior_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + + with torch.no_grad(): + logits = logits[:,:,:self.codebook_size] + probs = torch.softmax(logits, dim=-1) + sampled_indices = torch.multinomial(probs.view(-1,logits.shape[-1]),1) + sampled_indices = sampled_indices.view(-1,logits.shape[1]) + + pred_actions = self.autoencoder.decode_actions(sampled_indices) + l1_loss = self.loss(pred_actions, data["actions"]) + total_loss = prior_loss + self.l1_loss_scale * l1_loss + info = { + 'loss': total_loss.item(), + 'nll_loss': prior_loss.item(), + 'l1_loss': l1_loss.item() + } + return total_loss, info + + def sample_actions(self, data): + data = self.preprocess_input(data, train_mode=False) + context = self.get_context(data) + sampled_indices = self.policy_prior.get_indices_top_k(context, self.codebook_size) + pred_actions = self.autoencoder.decode_actions(sampled_indices) + pred_actions = pred_actions.permute(1,0,2) + return pred_actions.detach().cpu().numpy() diff --git a/quest/algos/quest_modules/skill_gpt.py b/quest/algos/quest_modules/skill_gpt.py new file mode 100644 index 0000000..e74b5b0 --- /dev/null +++ b/quest/algos/quest_modules/skill_gpt.py @@ -0,0 +1,89 @@ +from torch import nn +import torch +from torch.nn import functional as F +from positional_encodings.torch_encodings import PositionalEncoding1D, Summer +from quest.algos.utils.mlp_proj import MLPProj + + +class SkillGPT(nn.Module): + def __init__(self, + action_dim, + start_token, + vocab_size, + block_size, + n_layer, + n_head, + n_embd, + attn_pdrop, + embd_pdrop, + beam_size, # value of k for top k sampling + temperature, # temperature for sampling + device, + ): + super().__init__() + self.action_dim = action_dim + self.start_token = start_token + self.block_size = block_size + self.n_embd = n_embd + self.beam_size = beam_size + self.temperature = temperature + self.device = device + + self.tok_emb = nn.Embedding(vocab_size+1, n_embd) + self.add_positional_emb = Summer(PositionalEncoding1D(n_embd)) + self.decoder = nn.TransformerEncoder( + nn.TransformerEncoderLayer( + d_model=n_embd, + nhead=n_head, + dim_feedforward=4*n_embd, + dropout=attn_pdrop, + activation='gelu', + batch_first=True, + norm_first=True + ), + num_layers=n_layer, + enable_nested_tensor=False, + ) + self.head = nn.Linear(n_embd, vocab_size) + self.drop = nn.Dropout(embd_pdrop) + self.lnf = nn.LayerNorm(n_embd) + + def forward(self, idx, context): + x = self.tok_emb(idx) + x = self.add_positional_emb(x) + x = torch.cat([context, x], dim=1) + x = self.drop(x) + mask = nn.Transformer.generate_square_subsequent_mask(x.size(1),x.device) + x = self.decoder(x, mask=mask, is_causal=True) + x = x[:, context.size(1):, :] + x = self.lnf(x) + logits = self.head(x) + return logits + + def get_indices_top_k(self, context, codebook_size): + x = torch.ones((context.shape[0], 1), device=self.device, dtype=torch.long) * self.start_token + for i in range(self.block_size): + if i == self.block_size-1: + logits = self.forward(x, context) + logits = logits[:,:,:codebook_size] + else: + logits = self.forward(x, context) + logits = logits[:,:,:codebook_size] + next_indices = top_k_sampling(logits[:,-1,:], self.beam_size, self.temperature) + x = torch.cat([x, next_indices], dim=1) + return x[:,1:] + +def top_k_sampling(logits, k, temperature=1.0): + # Apply temperature scaling + scaled_logits = logits / temperature + # Find the top k values and indices + top_values, top_indices = torch.topk(scaled_logits, k, dim=-1) + # Compute probabilities from top values + top_probs = torch.softmax(top_values, dim=-1) + # Sample token index from the filtered probabilities + sampled_indices = torch.multinomial(top_probs, num_samples=1, replacement=True) + # Map the sampled index back to the original logits tensor + original_indices = top_indices.gather(-1, sampled_indices) + return original_indices + + diff --git a/quest/algos/quest_modules/skill_vae.py b/quest/algos/quest_modules/skill_vae.py new file mode 100644 index 0000000..692dbe5 --- /dev/null +++ b/quest/algos/quest_modules/skill_vae.py @@ -0,0 +1,297 @@ +import numpy as np +from torch import nn +import torch +from torch.nn import functional as F +from einops.layers.torch import Rearrange +from vector_quantize_pytorch import VectorQuantize, FSQ +from positional_encodings.torch_encodings import PositionalEncoding1D, Summer + + + +############################################################################### +# +# Skill-VAE module +# +############################################################################### + +def get_fsq_level(codebook_size): + power = int(np.log2(codebook_size)) + if power == 4: # 16 + fsq_level = [5, 3] + elif power == 6: # 64 + fsq_level = [8, 8] + elif power == 8: # 256 + fsq_level = [8, 6, 5] + elif power == 9: # 512 + fsq_level = [8, 8, 8] + elif power == 10: # 1024 + fsq_level = [8, 5, 5, 5] + elif power == 11: # 2048 + fsq_level = [8, 8, 6, 5] + elif power == 12: # 4096 + fsq_level = [7, 5, 5, 5, 5] + return fsq_level + + +class SkillVAE(nn.Module): + def __init__(self, + action_dim, + encoder_dim, + decoder_dim, + + skill_block_size, + downsample_factor, + + attn_pdrop, + use_causal_encoder, + use_causal_decoder, + + encoder_heads, + encoder_layers, + decoder_heads, + decoder_layers, + + vq_type, + fsq_level, + codebook_dim, + codebook_size, + ): + super().__init__() + self.encoder_dim = encoder_dim + self.decoder_dim = decoder_dim + self.skill_block_size = skill_block_size + self.use_causal_encoder = use_causal_encoder + self.use_causal_decoder = use_causal_decoder + self.vq_type = vq_type + self.fsq_level = fsq_level + + assert int(np.log2(downsample_factor)) == np.log2(downsample_factor), 'downsample_factor must be a power of 2' + strides = [2] * int(np.log2(downsample_factor)) + [1] + kernel_sizes = [5] + [3] * int(np.log2(downsample_factor)) + + if vq_type == 'vq': + self.vq = VectorQuantize(dim=encoder_dim, codebook_dim=codebook_dim, codebook_size=codebook_size) + elif vq_type == 'fsq': + if fsq_level is None: + fsq_level = get_fsq_level(codebook_size) + self.vq = FSQ(dim=encoder_dim, levels=fsq_level) + else: + raise NotImplementedError('Unknown vq_type') + self.action_proj = nn.Linear(action_dim, encoder_dim) + self.action_head = nn.Linear(decoder_dim, action_dim) + self.conv_block = ResidualTemporalBlock( + encoder_dim, encoder_dim, kernel_size=kernel_sizes, + stride=strides, causal=use_causal_encoder) + + encoder_layer = nn.TransformerEncoderLayer(d_model=encoder_dim, + nhead=encoder_heads, + dim_feedforward=4*encoder_dim, + dropout=attn_pdrop, + activation='gelu', + batch_first=True, + norm_first=True) + self.encoder = nn.TransformerEncoder(encoder_layer, + num_layers=encoder_layers, + enable_nested_tensor=False) + decoder_layer = nn.TransformerDecoderLayer(d_model=decoder_dim, + nhead=decoder_heads, + dim_feedforward=4*decoder_dim, + dropout=attn_pdrop, + activation='gelu', + batch_first=True, + norm_first=True) + self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=decoder_layers) + self.add_positional_emb = Summer(PositionalEncoding1D(encoder_dim)) + self.fixed_positional_emb = PositionalEncoding1D(decoder_dim) + + def encode(self, act, obs_emb=None): + x = self.action_proj(act) + x = self.conv_block(x) + B, H, D = x.shape + + if obs_emb is not None: + x = torch.cat([obs_emb, x], dim=1) + x = self.add_positional_emb(x) + + if self.use_causal_encoder: + mask = nn.Transformer.generate_square_subsequent_mask(x.size(1), device=x.device) + x = self.encoder(x, mask=mask, is_causal=True) + else: + x = self.encoder(x) + + x = x[:, -H:] + + return x + + def quantize(self, z): + if self.vq_type == 'vq': + codes, indices, commitment_loss = self.vq(z) + pp = torch.tensor(torch.unique(indices).shape[0] / self.vq.codebook_size, device=z.device) + else: + codes, indices = self.vq(z) + commitment_loss = torch.tensor([0.0], device=z.device) + pp = torch.tensor(torch.unique(indices).shape[0] / self.vq.codebook_size, device=z.device) + ## pp_sample is the average number of unique indices per sequence while pp is for the whole batch + pp_sample = torch.tensor(np.mean([len(torch.unique(index_seq)) for index_seq in indices])/z.shape[1], device=z.device) + return codes, indices, pp, pp_sample, commitment_loss + + def decode(self, codes, obs_emb=None): + x = self.fixed_positional_emb(torch.zeros((codes.shape[0], self.skill_block_size, self.decoder_dim), dtype=codes.dtype, device=codes.device)) + if obs_emb is not None: + codes = torch.cat([obs_emb, codes], dim=1) + if self.use_causal_decoder: + mask = nn.Transformer.generate_square_subsequent_mask(x.size(1), device=x.device) + x = self.decoder(x, codes, tgt_mask=mask, tgt_is_causal=True) + else: + x = self.decoder(x, codes) + x = self.action_head(x) + return x + + def forward(self, act, obs_emb=None): + z = self.encode(act, obs_emb=obs_emb) + codes, _, pp, pp_sample, commitment_loss = self.quantize(z) + x = self.decode(codes, obs_emb=obs_emb) + return x, pp, pp_sample, commitment_loss, codes + + def get_indices(self, act, obs_emb=None): + z = self.encode(act, obs_emb=obs_emb) + _, indices, _, _, _ = self.quantize(z) + return indices + + def decode_actions(self, indices): + if self.vq_type == 'fsq': + codes = self.vq.indices_to_codes(indices) + else: + codes = self.vq.get_output_from_indices(indices) + x = self.decode(codes) + return x + + @property + def device(self): + return next(self.parameters()).device + + +class CausalConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dilation, stride, no_pad=False): + super(CausalConv1d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + if no_pad: + self.padding = 0 + else: + self.padding = dilation*(kernel_size-1) + self.stride = stride + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.padding, dilation=dilation, stride=stride) + + def forward(self, x): + x = self.conv(x) + last_n = (2*self.padding-self.kernel_size)//self.stride + 1 + if last_n> 0: + return x[:, :, :-last_n] + else: + return x + + +class Conv1dBlock(nn.Module): + ''' + Conv1d --> GroupNorm --> Mish + from https://github.com/jannerm/diffuser/blob/06b8e6a042e6a3312d50ed8048cba14afeab3085/diffuser/models/helpers.py#L46 + ''' + def __init__(self, inp_channels, out_channels, kernel_size, stride, n_groups=4, causal=True, no_pad=False): + super().__init__() + if causal: + conv = CausalConv1d(inp_channels, out_channels, kernel_size, dilation=1, stride=stride, no_pad=no_pad) + else: + conv = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size//2, stride=stride) + + self.block = nn.Sequential( + conv, + Rearrange('batch channels horizon -> batch channels 1 horizon'), + nn.GroupNorm(n_groups, out_channels), + Rearrange('batch channels 1 horizon -> batch channels horizon'), + nn.Mish(), + ) + def forward(self, x): + return self.block(x) + + +# TODO: delete deconv modules for final release version +class CausalDeConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dilation, stride): + super(CausalDeConv1d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride) + + def forward(self, x): + x = self.conv(x) + last_n = self.kernel_size-self.stride + if last_n> 0: + return x[:, :, :-last_n] + else: + return x + +class DeConv1dBlock(nn.Module): + ''' + Conv1d --> GroupNorm --> Mish + from https://github.com/jannerm/diffuser/blob/06b8e6a042e6a3312d50ed8048cba14afeab3085/diffuser/models/helpers.py#L46 + ''' + def __init__(self, inp_channels, out_channels, kernel_size, stride, n_groups=8, causal=True): + super().__init__() + if causal: + conv = CausalDeConv1d(inp_channels, out_channels, kernel_size, dilation=1, stride=stride) + else: + conv = nn.ConvTranspose1d(inp_channels, out_channels, kernel_size, padding=kernel_size//2, stride=stride, output_padding=stride-1) + + self.block = nn.Sequential( + conv, + Rearrange('batch channels horizon -> batch channels 1 horizon'), + nn.GroupNorm(n_groups, out_channels), + Rearrange('batch channels 1 horizon -> batch channels horizon'), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + + +class ResidualTemporalBlock(nn.Module): + def __init__(self, inp_channels, out_channels, kernel_size=[5,3], stride=[2,2], n_groups=8, causal=True, residual=False, pooling_layers=[]): + super().__init__() + self.pooling_layers = pooling_layers + self.blocks = nn.ModuleList() + for i in range(len(kernel_size)): + block = Conv1dBlock( + inp_channels if i == 0 else out_channels, + out_channels, + kernel_size[i], + stride[i], + n_groups=n_groups, + causal=causal + ) + self.blocks.append(block) + if residual: + if out_channels == inp_channels and stride[0] == 1: + self.residual_conv = nn.Identity() + else: + self.residual_conv = nn.Conv1d(inp_channels, out_channels, kernel_size=1, stride=sum(stride)) + if pooling_layers: + self.pooling = nn.AvgPool1d(kernel_size=2, stride=2) + + def forward(self, input_dict): + x = input_dict + x = torch.transpose(x, 1, 2) + out = x + layer_num = 0 + for block in self.blocks: + out = block(out) + if hasattr(self, 'pooling'): + if layer_num in self.pooling_layers: + out = self.pooling(out) + layer_num += 1 + if hasattr(self, 'residual_conv'): + out = out + self.residual_conv(x) + return torch.transpose(out, 1, 2) diff --git a/quest/algos/utils/data_augmentation.py b/quest/algos/utils/data_augmentation.py new file mode 100644 index 0000000..a032895 --- /dev/null +++ b/quest/algos/utils/data_augmentation.py @@ -0,0 +1,179 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from quest.algos.utils.obs_core import CropRandomizer +import einops + + +class IdentityAug(nn.Module): + def __init__(self, shape_meta=None, *args, **kwargs): + super().__init__() + + def forward(self, x): + return x + + +class TranslationAug(nn.Module): + """ + Utilize the random crop from robomimic. + """ + + def __init__( + self, + shape_meta, + translation, + ): + super().__init__() + + self.randomizers = {} + self.shape_meta = shape_meta + + for name, input_shape in shape_meta['observation']['rgb'].items(): + input_shape = tuple(input_shape) + + self.pad_translation = translation // 2 + pad_output_shape = ( + input_shape[0], + input_shape[1] + translation, + input_shape[2] + translation, + ) + + crop_randomizer = CropRandomizer( + input_shape=pad_output_shape, + crop_height=input_shape[1], + crop_width=input_shape[2], + ) + self.randomizers[input_shape] = crop_randomizer + + def forward(self, data): + if self.training: + + for name in self.shape_meta['observation']['rgb']: + obs_data = data['obs'] + x = obs_data[name] + + batch_size, temporal_len, img_c, img_h, img_w = x.shape + + input_shape = (img_c, img_h, img_w) + crop_randomizer = self.randomizers[input_shape] + + x = x.reshape(batch_size, temporal_len * img_c, img_h, img_w) + out = F.pad(x, pad=(self.pad_translation,) * 4, mode="replicate") + out = crop_randomizer.forward_in(out) + out = out.reshape(batch_size, temporal_len, img_c, img_h, img_w) + + obs_data[name] = out + return data + + +class ImgColorJitterAug(torch.nn.Module): + """ + Conduct color jittering augmentation outside of proposal boxes + """ + + def __init__( + self, + shape_meta, + brightness=0.3, + contrast=0.3, + saturation=0.3, + hue=0.3, + epsilon=0.05, + ): + super().__init__() + self.color_jitter = torchvision.transforms.ColorJitter( + brightness=brightness, contrast=contrast, saturation=saturation, hue=hue + ) + self.epsilon = epsilon + self.shape_meta = shape_meta + + def forward(self, data): + if self.training and np.random.rand() > self.epsilon: + for name in self.shape_meta['observation']['rgb']: + data['obs'][name] = self.color_jitter(data['obs'][name]) + return data + + +class ImgColorJitterGroupAug(torch.nn.Module): + """ + Conduct color jittering augmentation outside of proposal boxes + """ + + def __init__( + self, + shape_meta, + brightness=0.3, + contrast=0.3, + saturation=0.3, + hue=0.3, + epsilon=0.05, + ): + super().__init__() + self.color_jitter = torchvision.transforms.ColorJitter( + brightness=brightness, contrast=contrast, saturation=saturation, hue=hue + ) + self.epsilon = epsilon + self.shape_meta = shape_meta + + def forward(self, x): + raise NotImplementedError + if self.training and np.random.rand() > self.epsilon: + out = self.color_jitter(x) + else: + out = x + return out + + +class BatchWiseImgColorJitterAug(torch.nn.Module): + """ + Color jittering augmentation to individual batch. + This is to create variation in training data to combat + BatchNorm in convolution network. + """ + + def __init__( + self, + shape_meta, + brightness=0.3, + contrast=0.3, + saturation=0.3, + hue=0.3, + epsilon=0.1, + ): + super().__init__() + self.color_jitter = torchvision.transforms.ColorJitter( + brightness=brightness, contrast=contrast, saturation=saturation, hue=hue + ) + self.epsilon = epsilon + self.shape_meta = shape_meta + + def forward(self, data): + if self.training: + for name in self.shape_meta['observation']['rgb']: + x = data['obs'][name] + mask = torch.rand((x.shape[0], *(1,)*(len(x.shape)-1)), device=x.device) > self.epsilon + + jittered = self.color_jitter(x) + + out = mask * jittered + torch.logical_not(mask) * x + data['obs'][name] = out + + return data + + +class DataAugGroup(nn.Module): + """ + Add augmentation to multiple inputs + """ + + def __init__(self, aug_list, shape_meta): + super().__init__() + aug_list = [aug(shape_meta) for aug in aug_list] + self.aug_layer = nn.Sequential(*aug_list) + + def forward(self, data): + return self.aug_layer(data) + \ No newline at end of file diff --git a/quest/algos/utils/mlp_proj.py b/quest/algos/utils/mlp_proj.py new file mode 100644 index 0000000..f1c5fe6 --- /dev/null +++ b/quest/algos/utils/mlp_proj.py @@ -0,0 +1,32 @@ +import torch.nn as nn + + +class MLPProj(nn.Module): + """ + Encode any embedding + + h = f(e), where + e: embedding from some model + h: latent embedding (B, H) + """ + + def __init__(self, input_size, output_size, hidden_size=None, num_layers=1, dropout=0.0): + super().__init__() + assert num_layers >= 1, "[error] num_layers < 1" + sizes = [input_size] + [hidden_size] * (num_layers - 1) + [output_size] + layers = [] + for i in range(num_layers - 1): + layers.append(nn.Linear(sizes[i], sizes[i + 1])) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Dropout(p=dropout)) + layers.append(nn.Linear(sizes[-2], sizes[-1])) + self.projection = nn.Sequential(*layers) + self.out_channels = output_size + + def forward(self, data): + """ + data: + task_emb: (B, E) + """ + h = self.projection(data) # (B, H) + return h diff --git a/quest/algos/utils/obs_core.py b/quest/algos/utils/obs_core.py new file mode 100644 index 0000000..ddfb635 --- /dev/null +++ b/quest/algos/utils/obs_core.py @@ -0,0 +1,378 @@ +""" +Contains torch Modules for core observation processing blocks +such as encoders (e.g. EncoderCore, VisualCore, ScanCore, ...) +and randomizers (e.g. Randomizer, CropRandomizer). + +This file is taken from robomimic +""" + +import abc +import numpy as np +import textwrap +import random + +import torch +import torch.nn as nn + +import quest.utils.tensor_utils as TensorUtils +import quest.utils.obs_utils as ObsUtils + + +""" +================================================ +Observation Randomizer Networks +================================================ +""" +class Randomizer(nn.Module): + """ + Base class for randomizer networks. Each randomizer should implement the @output_shape_in, + @output_shape_out, @forward_in, and @forward_out methods. The randomizer's @forward_in + method is invoked on raw inputs, and @forward_out is invoked on processed inputs + (usually processed by a @VisualCore instance). Note that the self.training property + can be used to change the randomizer's behavior at train vs. test time. + """ + def __init__(self): + super(Randomizer, self).__init__() + + def __init_subclass__(cls, **kwargs): + """ + Hook method to automatically register all valid subclasses so we can keep track of valid observation randomizers + in a global dict. + + This global dict stores mapping from observation randomizer network name to class. + We keep track of these registries to enable automated class inference at runtime, allowing + users to simply extend our base randomizer class and refer to that class in string form + in their config, without having to manually register their class internally. + This also future-proofs us for any additional randomizer classes we would + like to add ourselves. + """ + ObsUtils.register_randomizer(cls) + + def output_shape(self, input_shape=None): + """ + This function is unused. See @output_shape_in and @output_shape_out. + """ + raise NotImplementedError + + @abc.abstractmethod + def output_shape_in(self, input_shape=None): + """ + Function to compute output shape from inputs to this module. Corresponds to + the @forward_in operation, where raw inputs (usually observation modalities) + are passed in. + + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. + + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + raise NotImplementedError + + @abc.abstractmethod + def output_shape_out(self, input_shape=None): + """ + Function to compute output shape from inputs to this module. Corresponds to + the @forward_out operation, where processed inputs (usually encoded observation + modalities) are passed in. + + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. + + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + raise NotImplementedError + + def forward_in(self, inputs): + """ + Randomize raw inputs if training. + """ + if self.training: + randomized_inputs = self._forward_in(inputs=inputs) + # if VISUALIZE_RANDOMIZER: + # num_samples_to_visualize = min(4, inputs.shape[0]) + # self._visualize(inputs, randomized_inputs, num_samples_to_visualize=num_samples_to_visualize) + return randomized_inputs + else: + return self._forward_in_eval(inputs) + + def forward_out(self, inputs): + """ + Processing for network outputs. + """ + if self.training: + return self._forward_out(inputs) + else: + return self._forward_out_eval(inputs) + + @abc.abstractmethod + def _forward_in(self, inputs): + """ + Randomize raw inputs. + """ + raise NotImplementedError + + def _forward_in_eval(self, inputs): + """ + Test-time behavior for the randomizer + """ + return inputs + + @abc.abstractmethod + def _forward_out(self, inputs): + """ + Processing for network outputs. + """ + return inputs + + def _forward_out_eval(self, inputs): + """ + Test-time behavior for the randomizer + """ + return inputs + + @abc.abstractmethod + def _visualize(self, pre_random_input, randomized_input, num_samples_to_visualize=2): + """ + Visualize the original input and the randomized input for _forward_in for debugging purposes. + """ + pass + + +class CropRandomizer(Randomizer): + """ + Randomly sample crops at input, and then average across crop features at output. + """ + def __init__( + self, + input_shape, + crop_height=76, + crop_width=76, + num_crops=1, + pos_enc=False, + ): + """ + Args: + input_shape (tuple, list): shape of input (not including batch dimension) + crop_height (int): crop height + crop_width (int): crop width + num_crops (int): number of random crops to take + pos_enc (bool): if True, add 2 channels to the output to encode the spatial + location of the cropped pixels in the source image + """ + super(CropRandomizer, self).__init__() + + assert len(input_shape) == 3 # (C, H, W) + assert crop_height < input_shape[1] + assert crop_width < input_shape[2] + + self.input_shape = input_shape + self.crop_height = crop_height + self.crop_width = crop_width + self.num_crops = num_crops + self.pos_enc = pos_enc + + def output_shape_in(self, input_shape=None): + """ + Function to compute output shape from inputs to this module. Corresponds to + the @forward_in operation, where raw inputs (usually observation modalities) + are passed in. + + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. + + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + + # outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because + # the number of crops are reshaped into the batch dimension, increasing the batch + # size from B to B * N + out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0] + return [out_c, self.crop_height, self.crop_width] + + def output_shape_out(self, input_shape=None): + """ + Function to compute output shape from inputs to this module. Corresponds to + the @forward_out operation, where processed inputs (usually encoded observation + modalities) are passed in. + + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. + + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + + # since the forward_out operation splits [B * N, ...] -> [B, N, ...] + # and then pools to result in [B, ...], only the batch dimension changes, + # and so the other dimensions retain their shape. + return list(input_shape) + + def _forward_in(self, inputs): + """ + Samples N random crops for each input in the batch, and then reshapes + inputs to [B * N, ...]. + """ + assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions + out, _ = ObsUtils.sample_random_image_crops( + images=inputs, + crop_height=self.crop_height, + crop_width=self.crop_width, + num_crops=self.num_crops, + pos_enc=self.pos_enc, + ) + # [B, N, ...] -> [B * N, ...] + return TensorUtils.join_dimensions(out, 0, 1) + + def _forward_in_eval(self, inputs): + """ + Do center crops during eval + """ + assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions + inputs = inputs.permute(*range(inputs.dim()-3), inputs.dim()-2, inputs.dim()-1, inputs.dim()-3) + out = ObsUtils.center_crop(inputs, self.crop_height, self.crop_width) + out = out.permute(*range(out.dim()-3), out.dim()-1, out.dim()-3, out.dim()-2) + return out + + def _forward_out(self, inputs): + """ + Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N + to result in shape [B, ...] to make sure the network output is consistent with + what would have happened if there were no randomization. + """ + batch_size = (inputs.shape[0] // self.num_crops) + out = TensorUtils.reshape_dimensions(inputs, begin_axis=0, end_axis=0, + target_dims=(batch_size, self.num_crops)) + return out.mean(dim=1) + + def _visualize(self, pre_random_input, randomized_input, num_samples_to_visualize=2): + batch_size = pre_random_input.shape[0] + random_sample_inds = torch.randint(0, batch_size, size=(num_samples_to_visualize,)) + pre_random_input_np = TensorUtils.to_numpy(pre_random_input)[random_sample_inds] + randomized_input = TensorUtils.reshape_dimensions( + randomized_input, + begin_axis=0, + end_axis=0, + target_dims=(batch_size, self.num_crops) + ) # [B * N, ...] -> [B, N, ...] + randomized_input_np = TensorUtils.to_numpy(randomized_input[random_sample_inds]) + + pre_random_input_np = pre_random_input_np.transpose((0, 2, 3, 1)) # [B, C, H, W] -> [B, H, W, C] + randomized_input_np = randomized_input_np.transpose((0, 1, 3, 4, 2)) # [B, N, C, H, W] -> [B, N, H, W, C] + + # visualize_image_randomizer( + # pre_random_input_np, + # randomized_input_np, + # randomizer_name='{}'.format(str(self.__class__.__name__)) + # ) + + def __repr__(self): + """Pretty print network.""" + header = '{}'.format(str(self.__class__.__name__)) + msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format( + self.input_shape, self.crop_height, self.crop_width, self.num_crops) + return msg + + +class GaussianNoiseRandomizer(Randomizer): + """ + Randomly sample gaussian noise at input, and then average across noises at output. + """ + def __init__( + self, + input_shape, + noise_mean=0.0, + noise_std=0.3, + limits=None, + num_samples=1, + ): + """ + Args: + input_shape (tuple, list): shape of input (not including batch dimension) + noise_mean (float): Mean of noise to apply + noise_std (float): Standard deviation of noise to apply + limits (None or 2-tuple): If specified, should be the (min, max) values to clamp all noisied samples to + num_samples (int): number of random color jitters to take + """ + super(GaussianNoiseRandomizer, self).__init__() + + self.input_shape = input_shape + self.noise_mean = noise_mean + self.noise_std = noise_std + self.limits = limits + self.num_samples = num_samples + + def output_shape_in(self, input_shape=None): + # outputs are same shape as inputs + return list(input_shape) + + def output_shape_out(self, input_shape=None): + # since the forward_out operation splits [B * N, ...] -> [B, N, ...] + # and then pools to result in [B, ...], only the batch dimension changes, + # and so the other dimensions retain their shape. + return list(input_shape) + + def _forward_in(self, inputs): + """ + Samples N random gaussian noises for each input in the batch, and then reshapes + inputs to [B * N, ...]. + """ + out = TensorUtils.repeat_by_expand_at(inputs, repeats=self.num_samples, dim=0) + + # Sample noise across all samples + out = torch.rand(size=out.shape) * self.noise_std + self.noise_mean + out + + # Possibly clamp + if self.limits is not None: + out = torch.clip(out, min=self.limits[0], max=self.limits[1]) + + return out + + def _forward_out(self, inputs): + """ + Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N + to result in shape [B, ...] to make sure the network output is consistent with + what would have happened if there were no randomization. + """ + batch_size = (inputs.shape[0] // self.num_samples) + out = TensorUtils.reshape_dimensions(inputs, begin_axis=0, end_axis=0, + target_dims=(batch_size, self.num_samples)) + return out.mean(dim=1) + + def _visualize(self, pre_random_input, randomized_input, num_samples_to_visualize=2): + batch_size = pre_random_input.shape[0] + random_sample_inds = torch.randint(0, batch_size, size=(num_samples_to_visualize,)) + pre_random_input_np = TensorUtils.to_numpy(pre_random_input)[random_sample_inds] + randomized_input = TensorUtils.reshape_dimensions( + randomized_input, + begin_axis=0, + end_axis=0, + target_dims=(batch_size, self.num_samples) + ) # [B * N, ...] -> [B, N, ...] + randomized_input_np = TensorUtils.to_numpy(randomized_input[random_sample_inds]) + + pre_random_input_np = pre_random_input_np.transpose((0, 2, 3, 1)) # [B, C, H, W] -> [B, H, W, C] + randomized_input_np = randomized_input_np.transpose((0, 1, 3, 4, 2)) # [B, N, C, H, W] -> [B, N, H, W, C] + + # visualize_image_randomizer( + # pre_random_input_np, + # randomized_input_np, + # randomizer_name='{}'.format(str(self.__class__.__name__)) + # ) + + def __repr__(self): + """Pretty print network.""" + header = '{}'.format(str(self.__class__.__name__)) + msg = header + f"(input_shape={self.input_shape}, noise_mean={self.noise_mean}, noise_std={self.noise_std}, " \ + f"limits={self.limits}, num_samples={self.num_samples})" + return msg diff --git a/quest/algos/utils/rgb_modules.py b/quest/algos/utils/rgb_modules.py new file mode 100644 index 0000000..8580917 --- /dev/null +++ b/quest/algos/utils/rgb_modules.py @@ -0,0 +1,373 @@ +""" +This file contains all neural modules related to encoding the spatial +information of obs_t, i.e., the abstracted knowledge of the current visual +input conditioned on the language. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from torchvision.models._utils import IntermediateLayerGetter + + +from quest.algos.baseline_modules.act_utils.misc import NestedTensor, is_main_process + + +############################################################################### +# +# Modules related to encoding visual information (can conditioned on language) +# +############################################################################### + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101] + produce nans. + + From ACT codebase + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class PatchEncoder(nn.Module): + """ + A patch encoder that does a linear projection of patches in a RGB image. + """ + + def __init__( + self, input_shape, patch_size=[16, 16], embed_size=64, no_patch_embed_bias=False + ): + super().__init__() + C, H, W = input_shape + num_patches = (H // patch_size[0] // 2) * (W // patch_size[1] // 2) + self.img_size = (H, W) + self.patch_size = patch_size + self.num_patches = num_patches + self.h, self.w = H // patch_size[0] // 2, W // patch_size[1] // 2 + + self.conv = nn.Sequential( + nn.Conv2d( + C, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False + ), + nn.BatchNorm2d( + 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True + ), + nn.ReLU(inplace=True), + ) + self.proj = nn.Conv2d( + 64, + embed_size, + kernel_size=patch_size, + stride=patch_size, + bias=False if no_patch_embed_bias else True, + ) + self.bn = nn.BatchNorm2d(embed_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.conv(x) + x = self.proj(x) + x = self.bn(x) + return x + + +class SpatialSoftmax(nn.Module): + """ + The spatial softmax layer (https://rll.berkeley.edu/dsae/dsae.pdf) + """ + + def __init__(self, in_c, in_h, in_w, num_kp=None): + super().__init__() + self._spatial_conv = nn.Conv2d(in_c, num_kp, kernel_size=1) + + pos_x, pos_y = torch.meshgrid( + torch.linspace(-1, 1, in_w).float(), + torch.linspace(-1, 1, in_h).float(), + indexing="ij", + ) + + pos_x = pos_x.reshape(1, in_w * in_h) + pos_y = pos_y.reshape(1, in_w * in_h) + self.register_buffer("pos_x", pos_x) + self.register_buffer("pos_y", pos_y) + + if num_kp is None: + self._num_kp = in_c + else: + self._num_kp = num_kp + + self._in_c = in_c + self._in_w = in_w + self._in_h = in_h + + def forward(self, x): + assert x.shape[1] == self._in_c + assert x.shape[2] == self._in_h + assert x.shape[3] == self._in_w + + h = x + if self._num_kp != self._in_c: + h = self._spatial_conv(h) + h = h.contiguous().view(-1, self._in_h * self._in_w) + + attention = F.softmax(h, dim=-1) + keypoint_x = ( + (self.pos_x * attention).sum(1, keepdims=True).view(-1, self._num_kp) + ) + keypoint_y = ( + (self.pos_y * attention).sum(1, keepdims=True).view(-1, self._num_kp) + ) + keypoints = torch.cat([keypoint_x, keypoint_y], dim=1) + return keypoints + + +class SpatialProjection(nn.Module): + def __init__(self, input_shape, out_dim): + super().__init__() + + assert ( + len(input_shape) == 3 + ), "[error] spatial projection: input shape is not a 3-tuple" + in_c, in_h, in_w = input_shape + num_kp = out_dim // 2 + self.out_dim = out_dim + self.spatial_softmax = SpatialSoftmax(in_c, in_h, in_w, num_kp=num_kp) + self.projection = nn.Linear(num_kp * 2, out_dim) + + def forward(self, x): + out = self.spatial_softmax(x) + out = self.projection(out) + return out + + def output_shape(self, input_shape): + return input_shape[:-3] + (self.out_dim,) + + +class ResnetEncoder(nn.Module): + """ + A Resnet-18-based encoder for mapping an image to a latent vector + + Encode (f) an image into a latent vector. + + y = f(x), where + x: (B, C, H, W) + y: (B, H_out) + + Args: + input_shape: (C, H, W), the shape of the image + output_size: H_out, the latent vector size + pretrained: whether use pretrained resnet + freeze: whether freeze the pretrained resnet + remove_layer_num: remove the top # layers + no_stride: do not use striding + """ + + def __init__( + self, + input_shape, + output_size, + pretrained=False, + freeze=False, + remove_layer_num=2, + no_stride=False, + language_dim=768, + language_fusion="film", + do_projection=True, + ): + + super().__init__() + + ### 1. encode input (images) using convolutional layers + assert remove_layer_num <= 5, "[error] please only remove <=5 layers" + weights = torchvision.models.ResNet18_Weights if pretrained else None + layers = list(torchvision.models.resnet18(weights=weights).children())[ + :-remove_layer_num + ] + self.remove_layer_num = remove_layer_num + + assert ( + len(input_shape) == 3 + ), "[error] input shape of resnet should be (C, H, W)" + + in_channels = input_shape[0] + if in_channels != 3: # has eye_in_hand, increase channel size + conv0 = nn.Conv2d( + in_channels=in_channels, + out_channels=64, + kernel_size=(7, 7), + stride=(2, 2), + padding=(3, 3), + bias=False, + ) + layers[0] = conv0 + + self.no_stride = no_stride + if self.no_stride: + layers[0].stride = (1, 1) + layers[3].stride = 1 + + self.resnet18_base = nn.Sequential(*layers[:4]) + self.block_1 = layers[4][0] + self.block_2 = layers[4][1] + self.block_3 = layers[5][0] + self.block_4 = layers[5][1] + + self.language_fusion = language_fusion + if language_fusion != "none": + self.lang_proj1 = nn.Linear(language_dim, 64 * 2) + self.lang_proj2 = nn.Linear(language_dim, 64 * 2) + self.lang_proj3 = nn.Linear(language_dim, 128 * 2) + self.lang_proj4 = nn.Linear(language_dim, 128 * 2) + + if freeze: + if in_channels != 3: + raise Exception( + "[error] cannot freeze pretrained " + + "resnet with the extra eye_in_hand input" + ) + for param in self.resnet18_embeddings.parameters(): + param.requires_grad = False + + if pretrained: + self.normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + else: + self.normalizer = nn.Identity() + + + x = torch.zeros(1, *input_shape) + y = self.block_4( + self.block_3(self.block_2(self.block_1(self.resnet18_base(x)))) + ) + output_shape = y.shape # compute the out dim + if do_projection: + ### 2. project the encoded input to a latent space + self.projection_layer = SpatialProjection(output_shape[1:], output_size) + self.out_channels = self.projection_layer(y).shape[1] + else: + self.projection_layer = None + self.out_channels = y.shape[-1] + + def forward(self, x, langs=None): + x = self.normalizer(x) + h = self.resnet18_base(x) + + h = self.block_1(h) + if langs is not None and self.language_fusion != "none": # FiLM layer + B, C, H, W = h.shape + beta, gamma = torch.split( + self.lang_proj1(langs).reshape(B, C * 2, 1, 1), [C, C], 1 + ) + h = (1 + gamma) * h + beta + + h = self.block_2(h) + if langs is not None and self.language_fusion != "none": # FiLM layer + B, C, H, W = h.shape + beta, gamma = torch.split( + self.lang_proj2(langs).reshape(B, C * 2, 1, 1), [C, C], 1 + ) + h = (1 + gamma) * h + beta + + h = self.block_3(h) + if langs is not None and self.language_fusion != "none": # FiLM layer + B, C, H, W = h.shape + beta, gamma = torch.split( + self.lang_proj3(langs).reshape(B, C * 2, 1, 1), [C, C], 1 + ) + h = (1 + gamma) * h + beta + + h = self.block_4(h) + if langs is not None and self.language_fusion != "none": # FiLM layer + B, C, H, W = h.shape + beta, gamma = torch.split( + self.lang_proj4(langs).reshape(B, C * 2, 1, 1), [C, C], 1 + ) + h = (1 + gamma) * h + beta + + if self.projection_layer is not None: + h = self.projection_layer(h) + + return h + + +class DINOEncoder(nn.Module): + + def __init__( + self, + input_shape, + output_size, + pretrained=True, + ): + super().__init__() + assert ( + len(input_shape) == 3 + ), "[error] input shape of resnet should be (C, H, W)" + + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + self.preprocess = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize(224, interpolation=3), + torchvision.transforms.Normalize(mean=mean, std=std), + ] + ) + self.dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') + self.mlp_block = nn.Sequential( + nn.Linear(384, 64), + nn.GELU(), + nn.Dropout(0.1), + nn.Linear(64, 1), + nn.Dropout(0.1), + ) + self.projection = nn.Linear(384, output_size) + self.output_shape = output_size + + if pretrained: + for param in self.dino.parameters(): + param.requires_grad = False + + def forward(self, x, langs=None): + x = self.preprocess(x) + x = self.dino(x,is_training=True)['x_norm_patchtokens'] + mask = self.mlp_block(x).permute(0, 2, 1) + mask = F.softmax(mask, dim=-1) + x = torch.einsum('...si,...id->...sd', mask, x) + x = self.projection(x) + return x + + def output_shape(self, input_shape, shape_meta): + return self.output_shape + + diff --git a/quest/algos/utils/task_encoders.py b/quest/algos/utils/task_encoders.py new file mode 100644 index 0000000..79ed906 --- /dev/null +++ b/quest/algos/utils/task_encoders.py @@ -0,0 +1,8 @@ +import torch.nn as nn + +class TaskEmbeddingEncoder(nn.Module): + def __init__(self, n_tasks, embed_dim): + self.task_encodings = nn.Embedding(n_tasks, embed_dim) + + def forward(self, task_id): + return self.task_encodings[task_id] \ No newline at end of file diff --git a/quest/env_runner/libero_runner.py b/quest/env_runner/libero_runner.py new file mode 100644 index 0000000..da82ce3 --- /dev/null +++ b/quest/env_runner/libero_runner.py @@ -0,0 +1,158 @@ +import numpy as np +import gc +import quest.utils.libero_utils as lu +import quest.utils.obs_utils as ObsUtils +import wandb +from tqdm import tqdm +import multiprocessing + +class LiberoRunner(): + def __init__(self, + env_factory, + benchmark_name, + mode, # all or few + rollouts_per_env, + num_parallel_envs, + max_episode_length, + frame_stack=1, + fps=10, + debug=False, + task_embedding_format='clip', + ): + self.env_factory = env_factory + self.benchmark_name = benchmark_name + self.benchmark = lu.get_benchmark(benchmark_name)() + descriptions = [self.benchmark.get_task(i).language for i in range(self.benchmark.n_tasks)] + task_embs = lu.get_task_embs(task_embedding_format, descriptions) + self.benchmark.set_task_embs(task_embs) + self.env_names = self.benchmark.get_task_names() + + self.mode = mode + self.rollouts_per_env = rollouts_per_env + self.num_parallel_envs = num_parallel_envs + self.frame_stack = frame_stack + if num_parallel_envs>1: + if multiprocessing.get_start_method(allow_none=True) != "spawn": + multiprocessing.set_start_method("spawn", force=True) + self.max_episode_length = max_episode_length + self.fps = fps + + def run(self, policy, n_video=0, do_tqdm=False, save_video_fn=None): + env_names = self.env_names + successes, per_env_any_success, rewards = [], [], [] + per_env_success_rates, per_env_rewards = {}, {} + videos = {} + for env_name in tqdm(env_names, disable=not do_tqdm): + + any_success = False + env_succs, env_rews, env_video = [], [], [] + rollouts = self.run_policy_in_env(env_name, policy, render=n_video > 0) + for i, (success, total_reward, episode) in enumerate(rollouts): + any_success = any_success or success + successes.append(success) + env_succs.append(success) + env_rews.append(total_reward) + rewards.append(total_reward) + + if i < n_video: + if save_video_fn is not None: + video_hwc = np.array(episode['render']) + video_chw = video_hwc.transpose((0, 3, 1, 2)) + save_video_fn(video_chw, env_name, i) + else: + env_video.extend(episode['render']) + + per_env_success_rates[env_name] = np.mean(env_succs) + per_env_rewards[env_name] = np.mean(env_rews) + per_env_any_success.append(any_success) + + if len(env_video) > 0: + video_hwc = np.array(env_video) + video_chw = video_hwc.transpose((0, 3, 1, 2)) + videos[env_name] = wandb.Video(video_chw, fps=self.fps) + + output = {} + output['rollout'] = { + 'overall_success_rate': np.mean(successes), + 'overall_average_reward': np.mean(rewards), + 'environments_solved': int(np.sum(per_env_any_success)), + } + output['rollout_success_rate'] = {} + for env_name in env_names: + output['rollout_success_rate'][env_name] = per_env_success_rates[env_name] + if len(videos) > 0: + output['rollout_videos'] = {} + for env_name in videos: + + output['rollout_videos'][env_name] = videos[env_name] + + return output + + def run_policy_in_env(self, env_name, policy, render=False): + env_id = self.env_names.index(env_name) + env_num = min(self.num_parallel_envs, self.rollouts_per_env) + env_fn = lambda: lu.LiberoFrameStack(self.env_factory(env_id, self.benchmark), self.frame_stack) + env = lu.LiberoVectorWrapper(env_fn, self.num_parallel_envs) + + all_init_states = self.benchmark.get_task_init_states(env_id) + count = 0 + eval_loop_num = (self.rollouts_per_env+self.num_parallel_envs-1)//self.num_parallel_envs + + while count < eval_loop_num: + indices = np.arange(count * env_num, (count + 1) * env_num) % all_init_states.shape[0] + init_states_ = all_init_states[indices] + success, total_reward, episode = self.run_episode(env, + env_name, + policy, + init_states_, + env_num, + render) + count += 1 + for k in range(env_num): + episode_k = {key: value[:,k] for key, value in episode.items()} + yield success[k], total_reward[k], episode_k + env._env.close() + gc.collect() + del env + # TODO: envs are not being closed properly hence getting EGL error + + def run_episode(self, env, env_name, policy, init_states_, env_num, render=False): + obs, info = env.reset(init_states=init_states_) + + if hasattr(policy, 'get_action'): + policy.reset() + policy_object = policy + policy = lambda obs, task_id, task_emb: policy_object.get_action(obs, task_id, task_emb) + + success, total_reward = [False]*env_num, [0]*env_num + + episode = {key: [value[:,-1]] for key, value in obs.items()} + episode['actions'] = [] + if render: + episode['render'] = [env.render()] + + task_id = self.env_names.index(env_name) + task_emb = self.benchmark.get_task_emb(task_id).repeat(env_num, 1) + steps = 0 + while steps < self.max_episode_length: + action = policy(obs, task_id, task_emb) + # action = env.action_space.sample() + action = np.clip(action, env.action_space.low, env.action_space.high) + next_obs, reward, terminated, truncated, info = env.step(action) + total_reward += reward + obs = next_obs + for key, value in obs.items(): + episode[key].append(value[:,-1]) + episode['actions'].append(action) + if render: + episode['render'].append(env.render()) + + for k in range(env_num): + success[k] = success[k] or terminated[k] + + if all(success): + break + steps += 1 + + episode = {key: np.array(value) for key, value in episode.items()} + return success, total_reward, episode \ No newline at end of file diff --git a/quest/env_runner/metaworld_runner.py b/quest/env_runner/metaworld_runner.py new file mode 100644 index 0000000..96d0f6b --- /dev/null +++ b/quest/env_runner/metaworld_runner.py @@ -0,0 +1,161 @@ +import numpy as np + +import quest.utils.metaworld_utils as mu +import wandb +from tqdm import tqdm + + +class MetaWorldRunner(): + def __init__(self, + env_factory, + benchmark_name, + mode, # train or test + rollouts_per_env, + fps=10, + debug=False, + random_task=False, + ): + self.env_factory = env_factory + self.benchmark_name = benchmark_name + self.benchmark = mu.get_benchmark(benchmark_name) if not debug else None + self.mode = mode + self.rollouts_per_env = rollouts_per_env + self.fps = fps + self.random_task = random_task + + + def run(self, policy, n_video=0, do_tqdm=False, save_video_fn=None): + # print + env_names = mu.get_env_names(self.benchmark_name, self.mode) + successes, per_env_any_success, rewards = [], [], [] + per_env_success_rates, per_env_rewards = {}, {} + videos = {} + for env_name in tqdm(env_names, disable=not do_tqdm): + + any_success = False + env_succs, env_rews, env_video = [], [], [] + rollouts = self.run_policy_in_env(env_name, policy, render=n_video > 0) + for i, (success, total_reward, episode) in enumerate(rollouts): + any_success = any_success or success + successes.append(success) + env_succs.append(success) + env_rews.append(total_reward) + rewards.append(total_reward) + + if i < n_video: + if save_video_fn is not None: + video_hwc = np.array(episode['render']) + video_chw = video_hwc.transpose((0, 3, 1, 2)) + save_video_fn(video_chw, env_name, i) + else: + env_video.extend(episode['render']) + + per_env_success_rates[env_name] = np.mean(env_succs) + per_env_rewards[env_name] = np.mean(env_rews) + per_env_any_success.append(any_success) + + if len(env_video) > 0: + video_hwc = np.array(env_video) + video_chw = video_hwc.transpose((0, 3, 1, 2)) + videos[env_name] = wandb.Video(video_chw, fps=self.fps) + + # output['rollout'] = {} + output = {} + output['rollout'] = { + 'overall_success_rate': np.mean(successes), + 'overall_average_reward': np.mean(rewards), + 'environments_solved': int(np.sum(per_env_any_success)), + } + output['rollout_success_rate'] = {} + for env_name in env_names: + output['rollout_success_rate'][env_name] = per_env_success_rates[env_name] + # This metric isn't that useful + # output[f'rollout_detail/average_reward_{env_name}'] = per_env_rewards[env_name] + if len(videos) > 0: + output['rollout_videos'] = {} + for env_name in videos: + + output['rollout_videos'][env_name] = videos[env_name] + + return output + + + def run_policy_in_env(self, env_name, policy, render=False): + env = self.env_factory(env_name=env_name) + tasks = mu.get_tasks(self.benchmark, self.mode) + + env_tasks = [task for task in tasks if task.env_name == env_name] + count = 0 + while count < self.rollouts_per_env: + if len(env_tasks) > 0: + if self.random_task: + task_ind = np.random.randint(len(env_tasks)) + task = env_tasks[task_ind] + else: + task = env_tasks[count % len(env_tasks)] + env.set_task(task) + + success, total_reward, episode = self.run_episode(env, + env_name, + policy, + render) + count += 1 + yield success, total_reward, episode + + env.close() + del env + + + def run_episode(self, env, env_name, policy, render=False): + obs, _ = env.reset() + # breakpoint() + if hasattr(policy, 'get_action'): + policy.reset() + policy_object = policy + # breakpoint() + policy = lambda obs, task_id: policy_object.get_action(obs, task_id) + + done, success, total_reward = False, False, 0 + + episode = {key: [value[-1]] for key, value in obs.items()} + episode['actions'] = [] + episode['terminated'] = [] + episode['truncated'] = [] + episode['reward'] = [] + episode['success'] = [] + if render: + episode['render'] = [env.render()] + + task_id = mu.get_index(env_name) + + count = 0 + + while not done: + obs = {k: np.expand_dims(v, 0) for k, v in obs.items()} + action = policy(obs, task_id).squeeze() + # action = env.action_space.sample() + action = np.clip(action, env.action_space.low, env.action_space.high) + next_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + total_reward += reward + obs = next_obs + + for key, value in obs.items(): + episode[key].append(value[-1]) + episode['actions'].append(action) + episode['terminated'].append(terminated) + episode['truncated'].append(truncated) + episode['reward'].append(reward) + episode['success'].append(info['success']) + if int(info["success"]) == 1: + success = True + if render: + episode['render'].append(env.render()) + + count += 1 + # if count > 50: + # break + + episode = {key: np.array(value) for key, value in episode.items()} + return success, total_reward, episode + \ No newline at end of file diff --git a/quest/utils/dataset.py b/quest/utils/dataset.py new file mode 100644 index 0000000..c42780f --- /dev/null +++ b/quest/utils/dataset.py @@ -0,0 +1,673 @@ +""" +This file contains Dataset classes that are used by torch dataloaders +to fetch batches from hdf5 files. + +This file is adapted from Robomimic +https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/utils/dataset.py +""" +import os +import h5py +import numpy as np +from copy import deepcopy +from contextlib import contextmanager + +import torch.utils.data + +import quest.utils.tensor_utils as TensorUtils +import quest.utils.obs_utils as ObsUtils +from tqdm import tqdm + + +class SequenceDataset(torch.utils.data.Dataset): + def __init__( + self, + hdf5_path, + obs_keys, + dataset_keys, + frame_stack=1, + seq_length=1, + obs_seq_length=1, + lowdim_obs_seq_length=None, + pad_frame_stack=True, + pad_seq_length=True, + get_pad_mask=False, + goal_mode=None, + hdf5_cache_mode=None, + hdf5_use_swmr=True, + hdf5_normalize_obs=False, + filter_by_attribute=None, + load_next_obs=True, + few_demos=None, + n_demos=None, + ): + """ + Dataset class for fetching sequences of experience. + Length of the fetched sequence is equal to (@frame_stack - 1 + @seq_length) + + Args: + hdf5_path (str): path to hdf5 + + obs_keys (tuple, list): keys to observation items (image, object, etc) to be fetched from the dataset + + dataset_keys (tuple, list): keys to dataset items (actions, rewards, etc) to be fetched from the dataset + + frame_stack (int): numbers of stacked frames to fetch. Defaults to 1 (single frame). + + seq_length (int): length of sequences to sample. Defaults to 1 (single frame). + + pad_frame_stack (int): whether to pad sequence for frame stacking at the beginning of a demo. This + ensures that partial frame stacks are observed, such as (s_0, s_0, s_0, s_1). Otherwise, the + first frame stacked observation would be (s_0, s_1, s_2, s_3). + + pad_seq_length (int): whether to pad sequence for sequence fetching at the end of a demo. This + ensures that partial sequences at the end of a demonstration are observed, such as + (s_{T-1}, s_{T}, s_{T}, s_{T}). Otherwise, the last sequence provided would be + (s_{T-3}, s_{T-2}, s_{T-1}, s_{T}). + + get_pad_mask (bool): if True, also provide padding masks as part of the batch. This can be + useful for masking loss functions on padded parts of the data. + + goal_mode (str): either "last" or None. Defaults to None, which is to not fetch goals + + hdf5_cache_mode (str): one of ["all", "low_dim", or None]. Set to "all" to cache entire hdf5 + in memory - this is by far the fastest for data loading. Set to "low_dim" to cache all + non-image data. Set to None to use no caching - in this case, every batch sample is + retrieved via file i/o. You should almost never set this to None, even for large + image datasets. + + hdf5_use_swmr (bool): whether to use swmr feature when opening the hdf5 file. This ensures + that multiple Dataset instances can all access the same hdf5 file without problems. + + hdf5_normalize_obs (bool): if True, normalize observations by computing the mean observation + and std of each observation (in each dimension and modality), and normalizing to unit + mean and variance in each dimension. + + filter_by_attribute (str): if provided, use the provided filter key to look up a subset of + demonstrations to load + + load_next_obs (bool): whether to load next_obs from the dataset + """ + super(SequenceDataset, self).__init__() + self.hdf5_path = os.path.expanduser(hdf5_path) + self.hdf5_use_swmr = hdf5_use_swmr + self.hdf5_normalize_obs = hdf5_normalize_obs + self._hdf5_file = None + + assert hdf5_cache_mode in ["all", "low_dim", None] + self.hdf5_cache_mode = hdf5_cache_mode + + self.load_next_obs = load_next_obs + self.filter_by_attribute = filter_by_attribute + + # get all keys that needs to be fetched + self.obs_keys = tuple(obs_keys) + self.dataset_keys = tuple(dataset_keys) + + self.n_frame_stack = frame_stack + assert self.n_frame_stack >= 1 + + self.seq_length = seq_length + assert self.seq_length >= 1 + self.obs_seq_length = obs_seq_length + assert self.obs_seq_length >= 1 + self.lowdim_obs_seq_length = lowdim_obs_seq_length + + self.goal_mode = goal_mode + if self.goal_mode is not None: + assert self.goal_mode in ["last"] + if not self.load_next_obs: + assert self.goal_mode != "last" # we use last next_obs as goal + + self.pad_seq_length = pad_seq_length + self.pad_frame_stack = pad_frame_stack + self.get_pad_mask = get_pad_mask + + self.few_demos = few_demos + + self.load_demo_info(filter_by_attribute=self.filter_by_attribute, demos=self.few_demos, n_demos=n_demos) + + # maybe prepare for observation normalization + self.obs_normalization_stats = None + if self.hdf5_normalize_obs: + self.obs_normalization_stats = self.normalize_obs() + + # maybe store dataset in memory for fast access + if self.hdf5_cache_mode in ["all", "low_dim"]: + obs_keys_in_memory = self.obs_keys + if self.hdf5_cache_mode == "low_dim": + # only store low-dim observations + obs_keys_in_memory = [] + for k in self.obs_keys: + if ObsUtils.key_is_obs_modality(k, "low_dim"): + obs_keys_in_memory.append(k) + self.obs_keys_in_memory = obs_keys_in_memory + + self.hdf5_cache = self.load_dataset_in_memory( + demo_list=self.demos, + hdf5_file=self.hdf5_file, + obs_keys=self.obs_keys_in_memory, + dataset_keys=self.dataset_keys, + load_next_obs=self.load_next_obs + ) + + if self.hdf5_cache_mode == "all": + # cache getitem calls for even more speedup. We don't do this for + # "low-dim" since image observations require calls to getitem anyways. + print("SequenceDataset: caching get_item calls...") + self.getitem_cache = [self.get_item(i) for i in tqdm(range(len(self)))] + + # don't need the previous cache anymore + del self.hdf5_cache + self.hdf5_cache = None + else: + self.hdf5_cache = None + + self.close_and_delete_hdf5_handle() + + def load_demo_info(self, filter_by_attribute=None, demos=None, n_demos=None): + """ + Args: + filter_by_attribute (str): if provided, use the provided filter key + to select a subset of demonstration trajectories to load + + demos (list): list of demonstration keys to load from the hdf5 file. If + omitted, all demos in the file (or under the @filter_by_attribute + filter key) are used. + """ + # filter demo trajectory by mask + if demos is not None: + self.demos = demos + elif filter_by_attribute is not None: + self.demos = [elem.decode("utf-8") for elem in np.array(self.hdf5_file["mask/{}".format(filter_by_attribute)][:])] + else: + self.demos = list(self.hdf5_file["data"].keys()) + + if n_demos is not None: + assert n_demos <= len(self.demos), 'asking for more demos than available in the dataset' + self.demos = self.demos[:n_demos] + + # sort demo keys + inds = np.argsort([int(elem[5:]) for elem in self.demos]) + self.demos = [self.demos[i] for i in inds] + + self.n_demos = len(self.demos) + + # keep internal index maps to know which transitions belong to which demos + self._index_to_demo_id = dict() # maps every index to a demo id + self._demo_id_to_start_indices = dict() # gives start index per demo id + self._demo_id_to_demo_length = dict() + + # determine index mapping + self.total_num_sequences = 0 + for ep in self.demos: + demo_length = self.hdf5_file["data/{}".format(ep)].attrs["num_samples"] + self._demo_id_to_start_indices[ep] = self.total_num_sequences + self._demo_id_to_demo_length[ep] = demo_length + + num_sequences = demo_length + # determine actual number of sequences taking into account whether to pad for frame_stack and seq_length + if not self.pad_frame_stack: + num_sequences -= (self.n_frame_stack - 1) + if not self.pad_seq_length: + num_sequences -= (self.seq_length - 1) + + if self.pad_seq_length: + assert demo_length >= 1 # sequence needs to have at least one sample + num_sequences = max(num_sequences, 1) + else: + assert num_sequences >= 1 # assume demo_length >= (self.n_frame_stack - 1 + self.seq_length) + + for _ in range(num_sequences): + self._index_to_demo_id[self.total_num_sequences] = ep + self.total_num_sequences += 1 + + @property + def hdf5_file(self): + """ + This property allows for a lazy hdf5 file open. + """ + if self._hdf5_file is None: + self._hdf5_file = h5py.File(self.hdf5_path, 'r', swmr=self.hdf5_use_swmr, libver='latest') + return self._hdf5_file + + def close_and_delete_hdf5_handle(self): + """ + Maybe close the file handle. + """ + if self._hdf5_file is not None: + self._hdf5_file.close() + self._hdf5_file = None + + @contextmanager + def hdf5_file_opened(self): + """ + Convenient context manager to open the file on entering the scope + and then close it on leaving. + """ + should_close = self._hdf5_file is None + yield self.hdf5_file + if should_close: + self.close_and_delete_hdf5_handle() + + def __del__(self): + self.close_and_delete_hdf5_handle() + + def __repr__(self): + """ + Pretty print the class and important attributes on a call to `print`. + """ + msg = str(self.__class__.__name__) + msg += " (\n\tpath={}\n\tobs_keys={}\n\tseq_length={}\n\tfilter_key={}\n\tframe_stack={}\n" + msg += "\tpad_seq_length={}\n\tpad_frame_stack={}\n\tgoal_mode={}\n" + msg += "\tcache_mode={}\n" + msg += "\tnum_demos={}\n\tnum_sequences={}\n)" + filter_key_str = self.filter_by_attribute if self.filter_by_attribute is not None else "none" + goal_mode_str = self.goal_mode if self.goal_mode is not None else "none" + cache_mode_str = self.hdf5_cache_mode if self.hdf5_cache_mode is not None else "none" + msg = msg.format(self.hdf5_path, self.obs_keys, self.seq_length, filter_key_str, self.n_frame_stack, + self.pad_seq_length, self.pad_frame_stack, goal_mode_str, cache_mode_str, + self.n_demos, self.total_num_sequences) + return msg + + def __len__(self): + """ + Ensure that the torch dataloader will do a complete pass through all sequences in + the dataset before starting a new iteration. + """ + return self.total_num_sequences + + def load_dataset_in_memory(self, demo_list, hdf5_file, obs_keys, dataset_keys, load_next_obs): + """ + Loads the hdf5 dataset into memory, preserving the structure of the file. Note that this + differs from `self.getitem_cache`, which, if active, actually caches the outputs of the + `getitem` operation. + + Args: + demo_list (list): list of demo keys, e.g., 'demo_0' + hdf5_file (h5py.File): file handle to the hdf5 dataset. + obs_keys (list, tuple): observation keys to fetch, e.g., 'images' + dataset_keys (list, tuple): dataset keys to fetch, e.g., 'actions' + load_next_obs (bool): whether to load next_obs from the dataset + + Returns: + all_data (dict): dictionary of loaded data. + """ + all_data = dict() + # print("SequenceDataset: loading dataset into memory...") + for ep in tqdm(demo_list, disable=True): + all_data[ep] = {} + all_data[ep]["attrs"] = {} + all_data[ep]["attrs"]["num_samples"] = hdf5_file["data/{}".format(ep)].attrs["num_samples"] + # get obs + all_data[ep]["obs"] = {k: hdf5_file["data/{}/obs/{}".format(ep, k)][()].astype('float32') for k in obs_keys} + # if load_next_obs: + # all_data[ep]["next_obs"] = {k: hdf5_file["data/{}/next_obs/{}".format(ep, k)][()].astype('float32') for k in obs_keys} + # get other dataset keys + for k in dataset_keys: + if k in hdf5_file["data/{}".format(ep)]: + all_data[ep][k] = hdf5_file["data/{}/{}".format(ep, k)][()].astype('float32') + else: + all_data[ep][k] = np.zeros((all_data[ep]["attrs"]["num_samples"], 1), dtype=np.float32) + + if "model_file" in hdf5_file["data/{}".format(ep)].attrs: + all_data[ep]["attrs"]["model_file"] = hdf5_file["data/{}".format(ep)].attrs["model_file"] + + return all_data + + def normalize_obs(self): + """ + Computes a dataset-wide mean and standard deviation for the observations + (per dimension and per obs key) and returns it. + """ + def _compute_traj_stats(traj_obs_dict): + """ + Helper function to compute statistics over a single trajectory of observations. + """ + traj_stats = { k : {} for k in traj_obs_dict } + for k in traj_obs_dict: + traj_stats[k]["n"] = traj_obs_dict[k].shape[0] + traj_stats[k]["mean"] = traj_obs_dict[k].mean(axis=0, keepdims=True) # [1, ...] + traj_stats[k]["sqdiff"] = ((traj_obs_dict[k] - traj_stats[k]["mean"]) ** 2).sum(axis=0, keepdims=True) # [1, ...] + return traj_stats + + def _aggregate_traj_stats(traj_stats_a, traj_stats_b): + """ + Helper function to aggregate trajectory statistics. + See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + for more information. + """ + merged_stats = {} + for k in traj_stats_a: + n_a, avg_a, M2_a = traj_stats_a[k]["n"], traj_stats_a[k]["mean"], traj_stats_a[k]["sqdiff"] + n_b, avg_b, M2_b = traj_stats_b[k]["n"], traj_stats_b[k]["mean"], traj_stats_b[k]["sqdiff"] + n = n_a + n_b + mean = (n_a * avg_a + n_b * avg_b) / n + delta = (avg_b - avg_a) + M2 = M2_a + M2_b + (delta ** 2) * (n_a * n_b) / n + merged_stats[k] = dict(n=n, mean=mean, sqdiff=M2) + return merged_stats + + # Run through all trajectories. For each one, compute minimal observation statistics, and then aggregate + # with the previous statistics. + ep = self.demos[0] + obs_traj = {k: self.hdf5_file["data/{}/obs/{}".format(ep, k)][()].astype('float32') for k in self.obs_keys} + obs_traj = ObsUtils.process_obs_dict(obs_traj) + merged_stats = _compute_traj_stats(obs_traj) + print("SequenceDataset: normalizing observations...") + for ep in tqdm(self.demos[1:]): + obs_traj = {k: self.hdf5_file["data/{}/obs/{}".format(ep, k)][()].astype('float32') for k in self.obs_keys} + obs_traj = ObsUtils.process_obs_dict(obs_traj) + traj_stats = _compute_traj_stats(obs_traj) + merged_stats = _aggregate_traj_stats(merged_stats, traj_stats) + + obs_normalization_stats = { k : {} for k in merged_stats } + for k in merged_stats: + # note we add a small tolerance of 1e-3 for std + obs_normalization_stats[k]["mean"] = merged_stats[k]["mean"] + obs_normalization_stats[k]["std"] = np.sqrt(merged_stats[k]["sqdiff"] / merged_stats[k]["n"]) + 1e-3 + return obs_normalization_stats + + def get_obs_normalization_stats(self): + """ + Returns dictionary of mean and std for each observation key if using + observation normalization, otherwise None. + + Returns: + obs_normalization_stats (dict): a dictionary for observation + normalization. This maps observation keys to dicts + with a "mean" and "std" of shape (1, ...) where ... is the default + shape for the observation. + """ + assert self.hdf5_normalize_obs, "not using observation normalization!" + return deepcopy(self.obs_normalization_stats) + + def get_dataset_for_ep(self, ep, key): + """ + Helper utility to get a dataset for a specific demonstration. + Takes into account whether the dataset has been loaded into memory. + """ + + # check if this key should be in memory + key_should_be_in_memory = (self.hdf5_cache_mode in ["all", "low_dim"]) + if key_should_be_in_memory: + # if key is an observation, it may not be in memory + if '/' in key: + key1, key2 = key.split('/') + assert(key1 in ['obs', 'next_obs']) + if key2 not in self.obs_keys_in_memory: + key_should_be_in_memory = False + + if key_should_be_in_memory: + # read cache + if '/' in key: + key1, key2 = key.split('/') + assert(key1 in ['obs', 'next_obs']) + ret = self.hdf5_cache[ep][key1][key2] + else: + ret = self.hdf5_cache[ep][key] + else: + # read from file + hd5key = "data/{}/{}".format(ep, key) + ret = self.hdf5_file[hd5key] + return ret + + def __getitem__(self, index): + """ + Fetch dataset sequence @index (inferred through internal index map), using the getitem_cache if available. + """ + if self.hdf5_cache_mode == "all": + return self.getitem_cache[index] + return self.get_item(index) + + def get_item(self, index): + """ + Main implementation of getitem when not using cache. + """ + + demo_id = self._index_to_demo_id[index] + demo_start_index = self._demo_id_to_start_indices[demo_id] + demo_length = self._demo_id_to_demo_length[demo_id] + + # start at offset index if not padding for frame stacking + demo_index_offset = 0 if self.pad_frame_stack else (self.n_frame_stack - 1) + index_in_demo = index - demo_start_index + demo_index_offset + + # end at offset index if not padding for seq length + demo_length_offset = 0 if self.pad_seq_length else (self.seq_length - 1) + end_index_in_demo = demo_length - demo_length_offset + + meta = self.get_dataset_sequence_from_demo( + demo_id, + index_in_demo=index_in_demo, + keys=self.dataset_keys, + seq_length=self.seq_length + ) + + # determine goal index + goal_index = None + if self.goal_mode == "last": + goal_index = end_index_in_demo - 1 + + + # print(high_dim_keys) + # print(low_dim_keys) + + # self.get_obs_sequence_from_demo(demo_id, index_in_demo=index_in_demo, keys=self.obs_keys, num_frames_to_stack=self.n_frame_stack - 1, seq_length=self.obs_seq_length, prefix="obs") + + if self.lowdim_obs_seq_length is None: + meta["obs"] = self.get_obs_sequence_from_demo( + demo_id, + index_in_demo=index_in_demo, + keys=self.obs_keys, + num_frames_to_stack=self.n_frame_stack - 1, + seq_length=self.obs_seq_length, + prefix="obs" + ) + else: + high_dim_keys = [key for key in self.obs_keys if not ObsUtils.key_is_obs_modality(key, "low_dim")] + low_dim_keys = [key for key in self.obs_keys if ObsUtils.key_is_obs_modality(key, "low_dim")] + + meta["obs"] = self.get_obs_sequence_from_demo( + demo_id, + index_in_demo=index_in_demo, + keys=high_dim_keys, + num_frames_to_stack=self.n_frame_stack - 1, + seq_length=self.obs_seq_length, + prefix="obs" + ) + meta["obs"].update(self.get_obs_sequence_from_demo( + demo_id, + index_in_demo=index_in_demo, + keys=low_dim_keys, + num_frames_to_stack=self.n_frame_stack, + seq_length=self.lowdim_obs_seq_length, + prefix="obs" + )) + + if self.hdf5_normalize_obs: + meta["obs"] = ObsUtils.normalize_obs(meta["obs"], obs_normalization_stats=self.obs_normalization_stats) + + # print(meta['obs']) + # print(self.hdf5_normalize_obs) + # breakpoint() + + if self.load_next_obs: + # meta["next_obs"] = self.get_obs_sequence_from_demo( + # demo_id, + # index_in_demo=index_in_demo, + # keys=self.obs_keys, + # num_frames_to_stack=self.n_frame_stack - 1, + # seq_length=self.obs_seq_length, + # prefix="next_obs" + # ) + + # I'm redefining next obs to be the observation after the sequence + + next_obs_start = min(index_in_demo + self.seq_length, demo_length - 1) + meta["next_obs"] = self.get_obs_sequence_from_demo( + demo_id, + index_in_demo=next_obs_start, + keys=self.obs_keys, + num_frames_to_stack=self.n_frame_stack - 1, + seq_length=1, + prefix="obs" + ) + + if self.hdf5_normalize_obs: + meta["next_obs"] = ObsUtils.normalize_obs(meta["next_obs"], obs_normalization_stats=self.obs_normalization_stats) + + if goal_index is not None: + goal = self.get_obs_sequence_from_demo( + demo_id, + index_in_demo=goal_index, + keys=self.obs_keys, + num_frames_to_stack=0, + seq_length=1, + prefix="next_obs", + ) + if self.hdf5_normalize_obs: + goal = ObsUtils.normalize_obs(goal, obs_normalization_stats=self.obs_normalization_stats) + meta["goal_obs"] = {k: goal[k][0] for k in goal} # remove sequence dimension for goal + + return meta + + def get_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1): + """ + Extract a (sub)sequence of data items from a demo given the @keys of the items. + + Args: + demo_id (str): id of the demo, e.g., demo_0 + index_in_demo (int): beginning index of the sequence wrt the demo + keys (tuple): list of keys to extract + num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range + seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range + + Returns: + a dictionary of extracted items. + """ + assert num_frames_to_stack >= 0 + assert seq_length >= 1 + + demo_length = self._demo_id_to_demo_length[demo_id] + assert index_in_demo < demo_length + + # determine begin and end of sequence + seq_begin_index = max(0, index_in_demo - num_frames_to_stack) + seq_end_index = min(demo_length, index_in_demo + seq_length) + + # determine sequence padding + seq_begin_pad = max(0, num_frames_to_stack - index_in_demo) # pad for frame stacking + seq_end_pad = max(0, index_in_demo + seq_length - demo_length) # pad for sequence length + + # make sure we are not padding if specified. + if not self.pad_frame_stack: + assert seq_begin_pad == 0 + if not self.pad_seq_length: + assert seq_end_pad == 0 + + # fetch observation from the dataset file + seq = dict() + for k in keys: + data = self.get_dataset_for_ep(demo_id, k) + seq[k] = data[seq_begin_index: seq_end_index] + + seq = TensorUtils.pad_sequence(seq, padding=(seq_begin_pad, seq_end_pad), pad_same=True) + pad_mask = np.array([0] * seq_begin_pad + [1] * (seq_end_index - seq_begin_index) + [0] * seq_end_pad) + pad_mask = pad_mask[:, None].astype(bool) + + return seq, pad_mask + + def get_obs_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1, prefix="obs"): + """ + Extract a (sub)sequence of observation items from a demo given the @keys of the items. + + Args: + demo_id (str): id of the demo, e.g., demo_0 + index_in_demo (int): beginning index of the sequence wrt the demo + keys (tuple): list of keys to extract + num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range + seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range + prefix (str): one of "obs", "next_obs" + + Returns: + a dictionary of extracted items. + """ + obs, pad_mask = self.get_sequence_from_demo( + demo_id, + index_in_demo=index_in_demo, + keys=tuple('{}/{}'.format(prefix, k) for k in keys), + num_frames_to_stack=num_frames_to_stack, + seq_length=seq_length, + ) + obs = {k.split('/')[1]: obs[k] for k in obs} # strip the prefix + if self.get_pad_mask: + obs["pad_mask"] = pad_mask + + # prepare image observations from dataset + return ObsUtils.process_obs_dict(obs) + + def get_dataset_sequence_from_demo(self, demo_id, index_in_demo, keys, seq_length=1): + """ + Extract a (sub)sequence of dataset items from a demo given the @keys of the items (e.g., states, actions). + + Args: + demo_id (str): id of the demo, e.g., demo_0 + index_in_demo (int): beginning index of the sequence wrt the demo + keys (tuple): list of keys to extract + seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range + + Returns: + a dictionary of extracted items. + """ + data, pad_mask = self.get_sequence_from_demo( + demo_id, + index_in_demo=index_in_demo, + keys=keys, + num_frames_to_stack=0, # don't frame stack for meta keys + seq_length=seq_length, + ) + if self.get_pad_mask: + data["pad_mask"] = pad_mask + return data + + def get_trajectory_at_index(self, index): + """ + Method provided as a utility to get an entire trajectory, given + the corresponding @index. + """ + demo_id = self.demos[index] + demo_length = self._demo_id_to_demo_length[demo_id] + + meta = self.get_dataset_sequence_from_demo( + demo_id, + index_in_demo=0, + keys=self.dataset_keys, + seq_length=demo_length + ) + meta["obs"] = self.get_obs_sequence_from_demo( + demo_id, + index_in_demo=0, + keys=self.obs_keys, + seq_length=demo_length + ) + if self.load_next_obs: + meta["next_obs"] = self.get_obs_sequence_from_demo( + demo_id, + index_in_demo=0, + keys=self.obs_keys, + seq_length=demo_length, + prefix="next_obs" + ) + + meta["ep"] = demo_id + return meta + + def get_dataset_sampler(self): + """ + Return instance of torch.utils.data.Sampler or None. Allows + for dataset to define custom sampling logic, such as + re-weighting the probability of samples being drawn. + See the `train` function in scripts/train.py, and torch + `DataLoader` documentation, for more info. + """ + return None diff --git a/quest/utils/file_utils.py b/quest/utils/file_utils.py new file mode 100644 index 0000000..721062d --- /dev/null +++ b/quest/utils/file_utils.py @@ -0,0 +1,70 @@ +""" +A collection of utility functions for working with files, such as reading metadata from +demonstration datasets, loading model checkpoints, or downloading dataset files. + +This file is adopted from robomimic +https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/utils/file_utils.py +""" +import os +import h5py +from collections import OrderedDict + +import quest.utils.obs_utils as ObsUtils + + +def get_shape_metadata_from_dataset(dataset_path, all_obs_keys=None, verbose=False): + """ + Retrieves shape metadata from dataset. + + Args: + dataset_path (str): path to dataset + all_obs_keys (list): list of all modalities used by the model. If not provided, all modalities + present in the file are used. + verbose (bool): if True, include print statements + + Returns: + shape_meta (dict): shape metadata. Contains the following keys: + + :`'ac_dim'`: action space dimension + :`'all_shapes'`: dictionary that maps observation key string to shape + :`'all_obs_keys'`: list of all observation modalities used + :`'use_images'`: bool, whether or not image modalities are present + :`'use_depths'`: bool, whether or not depth modalities are present + """ + + shape_meta = {} + + # read demo file for some metadata + dataset_path = os.path.expanduser(dataset_path) + f = h5py.File(dataset_path, "r") + demo_id = list(f["data"].keys())[0] + demo = f["data/{}".format(demo_id)] + + # action dimension + shape_meta['ac_dim'] = f["data/{}/actions".format(demo_id)].shape[1] + + # observation dimensions + all_shapes = OrderedDict() + + if all_obs_keys is None: + # use all modalities present in the file + all_obs_keys = [k for k in demo["obs"]] + + for k in sorted(all_obs_keys): + initial_shape = demo["obs/{}".format(k)].shape[1:] + if verbose: + print("obs key {} with shape {}".format(k, initial_shape)) + # Store processed shape for each obs key + all_shapes[k] = ObsUtils.get_processed_shape( + obs_modality=ObsUtils.OBS_KEYS_TO_MODALITIES[k], + input_shape=initial_shape, + ) + + f.close() + + shape_meta['all_shapes'] = all_shapes + shape_meta['all_obs_keys'] = all_obs_keys + shape_meta['use_images'] = ObsUtils.has_modality("rgb", all_obs_keys) + shape_meta['use_depths'] = ObsUtils.has_modality("depth", all_obs_keys) + + return shape_meta diff --git a/quest/utils/frame_stack.py b/quest/utils/frame_stack.py new file mode 100644 index 0000000..9108d98 --- /dev/null +++ b/quest/utils/frame_stack.py @@ -0,0 +1,179 @@ +""" +The purpose of this file is to fix the silly gymnasium frame stack wrapper behavior I complain about here +https://github.com/Farama-Foundation/Gymnasium/issues/1085 +""" + +from __future__ import annotations + +from collections import deque +from copy import deepcopy +from typing import Any, Final, SupportsFloat + +import numpy as np + +import gymnasium as gym +from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType +from gymnasium.vector.utils import batch_space, concatenate, create_empty_array +from gymnasium.wrappers.utils import create_zero_array + + +class FrameStackObservationFixed( + gym.Wrapper[WrapperObsType, ActType, ObsType, ActType], + gym.utils.RecordConstructorArgs, +): + """Stacks the observations from the last ``N`` time steps in a rolling manner. + + For example, if the number of stacks is 4, then the returned observation contains + the most recent 4 observations. For environment 'Pendulum-v1', the original observation + is an array with shape [3], so if we stack 4 observations, the processed observation + has shape [4, 3]. + + Users have options for the padded observation used: + + * "reset" (default) - The reset value is repeated + * "zero" - A "zero"-like instance of the observation space + * custom - An instance of the observation space + + No vector version of the wrapper exists. + + Example: + >>> import gymnasium as gym + >>> from gymnasium.wrappers import FrameStackObservation + >>> env = gym.make("CarRacing-v2") + >>> env = FrameStackObservation(env, stack_size=4) + >>> env.observation_space + Box(0, 255, (4, 96, 96, 3), uint8) + >>> obs, _ = env.reset() + >>> obs.shape + (4, 96, 96, 3) + + Example with different padding observations: + >>> env = gym.make("CartPole-v1") + >>> env.reset(seed=123) + (array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {}) + >>> stacked_env = FrameStackObservation(env, 3) # the default is padding_type="reset" + >>> stacked_env.reset(seed=123) + (array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], + [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], + [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]], + dtype=float32), {}) + + + >>> stacked_env = FrameStackObservation(env, 3, padding_type="zero") + >>> stacked_env.reset(seed=123) + (array([[ 0. , 0. , 0. , 0. ], + [ 0. , 0. , 0. , 0. ], + [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]], + dtype=float32), {}) + >>> stacked_env = FrameStackObservation(env, 3, padding_type=np.array([1, -1, 0, 2], dtype=np.float32)) + >>> stacked_env.reset(seed=123) + (array([[ 1. , -1. , 0. , 2. ], + [ 1. , -1. , 0. , 2. ], + [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]], + dtype=float32), {}) + + Change logs: + * v0.15.0 - Initially add as ``FrameStack`` with support for lz4 + * v1.0.0 - Rename to ``FrameStackObservation`` and remove lz4 and ``LazyFrame`` support + along with adding the ``padding_type`` parameter + + """ + + def __init__( + self, + env: gym.Env[ObsType, ActType], + stack_size: int, + *, + padding_type: str | ObsType = "reset", + ): + """Observation wrapper that stacks the observations in a rolling manner. + + Args: + env: The environment to apply the wrapper + stack_size: The number of frames to stack. + padding_type: The padding type to use when stacking the observations, options: "reset", "zero", custom obs + """ + gym.utils.RecordConstructorArgs.__init__( + self, stack_size=stack_size, padding_type=padding_type + ) + gym.Wrapper.__init__(self, env) + + if not np.issubdtype(type(stack_size), np.integer): + raise TypeError( + f"The stack_size is expected to be an integer, actual type: {type(stack_size)}" + ) + # if not 1 < stack_size: + # raise ValueError( + # f"The stack_size needs to be greater than one, actual value: {stack_size}" + # ) + if isinstance(padding_type, str) and ( + padding_type == "reset" or padding_type == "zero" + ): + self.padding_value: ObsType = create_zero_array(env.observation_space) + elif padding_type in env.observation_space: + self.padding_value = padding_type + padding_type = "_custom" + else: + if isinstance(padding_type, str): + raise ValueError( # we are guessing that the user just entered the "reset" or "zero" wrong + f"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: {padding_type!r}" + ) + else: + raise ValueError( + f"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: {padding_type!r} not an instance of env observation ({env.observation_space})" + ) + + self.observation_space = batch_space(env.observation_space, n=stack_size) + self.stack_size: Final[int] = stack_size + self.padding_type: Final[str] = padding_type + + self.obs_queue = deque( + [self.padding_value for _ in range(self.stack_size)], maxlen=self.stack_size + ) + self.stacked_obs = create_empty_array(env.observation_space, n=self.stack_size) + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Steps through the environment, appending the observation to the frame buffer. + + Args: + action: The action to step through the environment with + + Returns: + Stacked observations, reward, terminated, truncated, and info from the environment + """ + obs, reward, terminated, truncated, info = self.env.step(action) + self.obs_queue.append(obs) + + updated_obs = deepcopy( + concatenate(self.env.observation_space, self.obs_queue, self.stacked_obs) + ) + return updated_obs, reward, terminated, truncated, info + + def reset( + self, *, seed: int | None = None, + options: dict[str, Any] | None = None, + **kwargs + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Reset the environment, returning the stacked observation and info. + + Args: + seed: The environment seed + options: The reset options + + Returns: + The stacked observations and info + """ + obs, info = self.env.reset(seed=seed, options=options, **kwargs) + + if self.padding_type == "reset": + self.padding_value = obs + for _ in range(self.stack_size - 1): + self.obs_queue.append(self.padding_value) + self.obs_queue.append(obs) + + updated_obs = deepcopy( + concatenate(self.env.observation_space, self.obs_queue, self.stacked_obs) + ) + return updated_obs, info diff --git a/quest/utils/libero_utils.py b/quest/utils/libero_utils.py new file mode 100644 index 0000000..e4adb64 --- /dev/null +++ b/quest/utils/libero_utils.py @@ -0,0 +1,358 @@ +import copy + +# import gym.spaces +# import gym.wrappers +import gymnasium +from collections import OrderedDict, deque +import os +import numpy as np +import quest.utils.file_utils as FileUtils +import quest.utils.obs_utils as ObsUtils +import quest.utils.utils as utils +from PIL import Image +from quest.utils.dataset import SequenceDataset +from torch.utils.data import Dataset +from quest.utils.frame_stack import FrameStackObservationFixed +import torch +import torch.nn as nn +from torch.utils.data import ConcatDataset +# import gym +os.environ["TOKENIZERS_PARALLELISM"] = "false" +from libero.libero.benchmark import get_benchmark +from transformers import AutoModel, AutoTokenizer, logging +from hydra.utils import to_absolute_path +import time +from libero.libero import get_libero_path +from libero.libero.envs import OffScreenRenderEnv, SubprocVectorEnv, DummyVectorEnv +from libero.libero.utils.time_utils import Timer +import multiprocessing +import math +import matplotlib.pyplot as plt +import robosuite.utils.transform_utils as T +import h5py +from gymnasium.vector.utils import batch_space +from tqdm import trange +np.set_printoptions(suppress=True) + + +class LiberoVectorWrapper(gymnasium.Env): + def __init__(self, + env_factory, + env_num): + env_creation, count = False, 0 + while not env_creation and count < 5: + try: + if env_num == 1: + env = DummyVectorEnv([env_factory]) + else: + env = SubprocVectorEnv([env_factory for _ in range(env_num)]) + env_creation = True + except Exception as e: + print(e) + time.sleep(5) + count += 1 + if count >= 5: + raise Exception("Failed to create environment") + self._env = env + self.action_space = batch_space(self._env.action_space[0], env_num) + self.observation_space = batch_space(self._env.observation_space[0], env_num) + + def reset(self, init_states, *args, **kwargs): + obs, info = self._env.reset(*args, **kwargs) + obs = self.process_obs(obs) + self._env.set_init_state(init_states) + return obs, info + + def step(self, *args, **kwargs): + obs, reward, terminated, truncated, info = self._env.step(*args, **kwargs) + obs = self.process_obs(obs) + return obs, reward, terminated, truncated, info + + def render(self, *args, **kwargs): + return self._env.render(*args, **kwargs) + + def process_obs(self, obs): + """LIBERO vectorization wrapper does not handle dict obs well""" + obs_out = {key: [] for key in obs[0]} + for env_obs in obs: + for key in obs_out: + obs_out[key].append(env_obs[key]) + for key in obs_out: + obs_out[key] = np.array(obs_out[key]) + return obs_out + + +class LiberoFrameStack(FrameStackObservationFixed): + def set_init_state(self, *args, **kwargs): + return self.env.set_init_state(*args, **kwargs) + + +class LiberoWrapper(gymnasium.Env): + def __init__(self, + task_id, + benchmark, + shape_meta, + obs_key_mapping, + img_height=128, + img_width=128, + cameras=('agentview', 'robot0_eye_in_hand'), + device="cuda",): + self.img_width = img_width + self.img_height = img_height + obs_meta = shape_meta['observation'] + self.rgb_outputs = list(obs_meta['rgb']) + self.lowdim_outputs = list(obs_meta['lowdim']) + self.cameras = cameras + self.obs_key_mapping = obs_key_mapping + + self.device = device + env_args = { + "bddl_file_name": benchmark.get_task_bddl_file_path(task_id), + "camera_heights": img_height, + "camera_widths": img_width, + 'camera_names': cameras + } + + env = OffScreenRenderEnv(**env_args) + self.env = env + + obs_space_dict = {} + for key in self.rgb_outputs: + obs_space_dict[key] = gymnasium.spaces.Box( + low=0, + high=255, + shape=(img_height, img_width, 3), + dtype=np.uint8 + ) + for key in self.lowdim_outputs: + obs_space_dict[key] = gymnasium.spaces.Box( + low=-np.inf, + high=np.inf, + shape=(obs_meta['lowdim'][key],), + dtype=np.float32 + ) + self.observation_space = gymnasium.spaces.Dict(obs_space_dict) + self.action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32) + self.render_out = None + + def reset(self, init_states=None, **kwargs): + self.env.reset() + if init_states is not None: + raw_obs = self.env.set_init_state(init_states) + # dummy actions [ 7] all zeros for initial physics simulation (as in the original LIBERO code) + dummy = np.zeros((7,)) + for _ in range(5): + raw_obs, _, _, _ = self.env.step(dummy) + return self.make_obs(raw_obs), {} + + def step(self, action): + raw_obs, reward, truncated, info = self.env.step(action) + obs = self.make_obs(raw_obs) + info['success'] = self.env.check_success() + terminated = info['success'] + return obs, reward, terminated, truncated, info + + def set_init_state(self, *args, **kwargs): + self.env.set_init_state(*args, **kwargs) + + def make_obs(self, raw_obs): + obs = {} + self.render_out = raw_obs[f'{self.cameras[0]}_image'][::-1] + + for key in self.rgb_outputs: + obs[key] = raw_obs[self.obs_key_mapping[key]] + + for key in self.lowdim_outputs: + obs[key] = raw_obs[self.obs_key_mapping[key]] + + return obs + + def render(self, *args, **kwargs): + return self.render_out + +def build_dataset(data_prefix, + suite_name, + benchmark_name, + mode, + seq_len, + frame_stack, + shape_meta, + n_demos, + extra_obs_modality=None, + obs_seq_len=1, + load_obs=True, + task_embedding_format="clip", + ): + benchmark = get_benchmark(benchmark_name)() + n_tasks = benchmark.n_tasks + few_shot_demos = [1, 5, 10, 20, 45] if mode == 'fewshot' else None + few_shot_demos_list = [f"demo_{i}" for i in few_shot_demos] if few_shot_demos is not None else None + + manip_datasets = [] + descriptions = [] + # for key, value in shape_meta + obs_modality = { + 'rgb': list(shape_meta['observation']['rgb'].keys()), + 'low_dim': list(shape_meta['observation']['lowdim'].keys()), + } + if extra_obs_modality is not None: + for key in extra_obs_modality: + obs_modality[key] = obs_modality[key] + extra_obs_modality[key] + # breakpoint() + ObsUtils.initialize_obs_utils_with_obs_specs({"obs": obs_modality}) + for i in trange(n_tasks): + task_i_dataset = get_dataset( + dataset_path=os.path.join( + data_prefix, suite_name, benchmark.get_task_demonstration(i) + ), + obs_modality=obs_modality, + seq_len=seq_len, + obs_seq_len=obs_seq_len, + frame_stack=frame_stack, + load_obs=load_obs, + few_demos = few_shot_demos_list, + n_demos=n_demos, + ) + task_description = benchmark.get_task(i).language + descriptions.append(task_description) + manip_datasets.append(task_i_dataset) + task_embs = get_task_embs(task_embedding_format, descriptions) + benchmark.set_task_embs(task_embs) + datasets = [ + SequenceVLDataset(ds, emb, i) for i,(ds, emb) in enumerate(zip(manip_datasets, task_embs)) + ] + n_demos = [data.n_demos for data in datasets] + n_sequences = [data.total_num_sequences for data in datasets] + concat_dataset = ConcatDataset(datasets) + print("\n=================== Benchmark Information ===================") + print(f" Name: {benchmark.name}") + print(f" # Tasks: {n_tasks}") + print(" # demonstrations: " + " ".join(f"({x})" for x in n_demos)) + print(" # sequences: " + " ".join(f"({x})" for x in n_sequences)) + print("=======================================================================\n") + return concat_dataset + +def get_dataset( + dataset_path, + obs_modality, + seq_len=1, + obs_seq_len=1, + frame_stack=1, + filter_key=None, + hdf5_cache_mode="low_dim", + load_obs=True, + few_demos=None, + n_demos=None, + ): + all_obs_keys = [] + for modality_name, modality_list in obs_modality.items(): + all_obs_keys += modality_list + shape_meta = FileUtils.get_shape_metadata_from_dataset( + dataset_path=dataset_path, all_obs_keys=all_obs_keys, verbose=False + ) + seq_len = seq_len + filter_key = filter_key + if load_obs: + obs_keys = shape_meta["all_obs_keys"] + else: + obs_keys = [] + dataset = SequenceDataset( + hdf5_path=dataset_path, + obs_keys=obs_keys, + dataset_keys=["actions"], + load_next_obs=False, + frame_stack=frame_stack, + seq_length=seq_len, # length-10 temporal sequences + obs_seq_length=obs_seq_len, + pad_frame_stack=True, + pad_seq_length=True, # pad last obs per trajectory to ensure all sequences are sampled + get_pad_mask=False, + goal_mode=None, + hdf5_cache_mode=hdf5_cache_mode, # cache dataset in memory to avoid repeated file i/o + hdf5_use_swmr=False, + hdf5_normalize_obs=None, + filter_by_attribute=filter_key, # can optionally provide a filter key here + few_demos=few_demos, + n_demos=n_demos, + ) + return dataset + +class SequenceVLDataset(Dataset): + def __init__(self, sequence_dataset, task_emb, task_id): + self.sequence_dataset = sequence_dataset + self.task_emb = task_emb + self.task_id = task_id + self.n_demos = self.sequence_dataset.n_demos + self.total_num_sequences = self.sequence_dataset.total_num_sequences + + def __len__(self): + return len(self.sequence_dataset) + + def __getitem__(self, idx): + return_dict = self.sequence_dataset.__getitem__(idx) + return_dict["task_emb"] = self.task_emb + return_dict["task_id"] = self.task_id + return return_dict + +def get_task_embs(task_embedding_format, descriptions): + logging.set_verbosity_error() + if task_embedding_format == "bert": + tz = AutoTokenizer.from_pretrained( + "bert-base-cased", cache_dir=to_absolute_path("./bert") + ) + model = AutoModel.from_pretrained( + "bert-base-cased", cache_dir=to_absolute_path("./bert") + ) + tokens = tz( + text=descriptions, # the sentence to be encoded + add_special_tokens=True, # Add [CLS] and [SEP] + max_length=25, # maximum length of a sentence + padding="max_length", + return_attention_mask=True, # Generate the attention mask + return_tensors="pt", # ask the function to return PyTorch tensors + ) + masks = tokens["attention_mask"] + input_ids = tokens["input_ids"] + task_embs = model(tokens["input_ids"], tokens["attention_mask"])[ + "pooler_output" + ].detach() + elif task_embedding_format == "gpt2": + tz = AutoTokenizer.from_pretrained("gpt2") + tz.pad_token = tz.eos_token + model = AutoModel.from_pretrained("gpt2") + tokens = tz( + text=descriptions, # the sentence to be encoded + add_special_tokens=True, # Add [CLS] and [SEP] + max_length=25, # maximum length of a sentence + padding="max_length", + return_attention_mask=True, # Generate the attention mask + return_tensors="pt", # ask the function to return PyTorch tensors + ) + task_embs = model(**tokens)["last_hidden_state"].detach()[:, -1] + elif task_embedding_format == "clip": + tz = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", clean_up_tokenization_spaces=True) + model = AutoModel.from_pretrained("openai/clip-vit-base-patch32") + tokens = tz( + text=descriptions, # the sentence to be encoded + add_special_tokens=True, # Add [CLS] and [SEP] + max_length=25, # maximum length of a sentence + padding="max_length", + return_attention_mask=True, # Generate the attention mask + return_tensors="pt", # ask the function to return PyTorch tensors + ) + task_embs = model.get_text_features(**tokens).detach() + elif task_embedding_format == "roberta": + tz = AutoTokenizer.from_pretrained("roberta-base") + tz.pad_token = tz.eos_token + model = AutoModel.from_pretrained("roberta-base") + tokens = tz( + text=descriptions, # the sentence to be encoded + add_special_tokens=True, # Add [CLS] and [SEP] + max_length=25, # maximum length of a sentence + padding="max_length", + return_attention_mask=True, # Generate the attention mask + return_tensors="pt", # ask the function to return PyTorch tensors + ) + task_embs = model(**tokens)["pooler_output"].detach() + return task_embs + diff --git a/quest/utils/logger.py b/quest/utils/logger.py new file mode 100644 index 0000000..7a9b8f5 --- /dev/null +++ b/quest/utils/logger.py @@ -0,0 +1,43 @@ +import wandb +import numpy as np + +class Logger: + """ + The purpose of this simple logger is to log intermittently and log average values since the last log + """ + def __init__(self, log_interval): + self.log_interval = log_interval + self.data = None + + def update(self, info, step): + info = flatten_dict(info) + if self.data is None: + self.data = {key: [] for key in info} + + for key in info: + self.data[key].append(info[key]) + + if step % self.log_interval == 0: + means = {key: np.mean(value) for key, value in self.data.items()} + self.log(means, step) + self.data = None + + def log(self, info, step): + info_flat = flatten_dict(info) + wandb.log(info_flat, step=step) + + +def flatten_dict(in_dict): + """ + The purpose of this is to flatten dictionaries because as of writing wandb handling nested dicts is broken :( + https://community.wandb.ai/t/the-wandb-log-function-does-not-treat-nested-dict-as-it-describes-in-the-document/3330 + """ + + out_dict = {} + for key, value in in_dict.items(): + if type(value) is dict: + for inner_key, inner_value in value.items(): + out_dict[f'{key}/{inner_key}'] = inner_value + else: + out_dict[key] = value + return out_dict \ No newline at end of file diff --git a/quest/utils/metaworld_utils.py b/quest/utils/metaworld_utils.py new file mode 100644 index 0000000..60c6443 --- /dev/null +++ b/quest/utils/metaworld_utils.py @@ -0,0 +1,513 @@ +import copy +from collections import OrderedDict + +import numpy as np +import quest.utils.file_utils as FileUtils +import quest.utils.obs_utils as ObsUtils +from PIL import Image +from quest.utils.dataset import SequenceDataset +from torch.utils.data import Dataset +from quest.utils.frame_stack import FrameStackObservationFixed +import torch +import torch.nn as nn +import gymnasium +from gymnasium.envs.mujoco.mujoco_rendering import OffScreenViewer +import math +import mujoco +import os +from torch.utils.data import ConcatDataset +import metaworld + +from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE + +from pyinstrument import Profiler + + +from metaworld.policies import * + + +_policies = OrderedDict( + [ + ("assembly-v2", SawyerAssemblyV2Policy), + ("basketball-v2", SawyerBasketballV2Policy), + ("bin-picking-v2", SawyerBinPickingV2Policy), + ("box-close-v2", SawyerBoxCloseV2Policy), + ("button-press-topdown-v2", SawyerButtonPressTopdownV2Policy), + ("button-press-topdown-wall-v2", SawyerButtonPressTopdownWallV2Policy), + ("button-press-v2", SawyerButtonPressV2Policy), + ("button-press-wall-v2", SawyerButtonPressWallV2Policy), + ("coffee-button-v2", SawyerCoffeeButtonV2Policy), + ("coffee-pull-v2", SawyerCoffeePullV2Policy), + ("coffee-push-v2", SawyerCoffeePushV2Policy), + ("dial-turn-v2", SawyerDialTurnV2Policy), + ("disassemble-v2", SawyerDisassembleV2Policy), + ("door-close-v2", SawyerDoorCloseV2Policy), + ("door-lock-v2", SawyerDoorLockV2Policy), + ("door-open-v2", SawyerDoorOpenV2Policy), + ("door-unlock-v2", SawyerDoorUnlockV2Policy), + ("drawer-close-v2", SawyerDrawerCloseV2Policy), + ("drawer-open-v2", SawyerDrawerOpenV2Policy), + ("faucet-close-v2", SawyerFaucetCloseV2Policy), + ("faucet-open-v2", SawyerFaucetOpenV2Policy), + ("hammer-v2", SawyerHammerV2Policy), + ("hand-insert-v2", SawyerHandInsertV2Policy), + ("handle-press-side-v2", SawyerHandlePressSideV2Policy), + ("handle-press-v2", SawyerHandlePressV2Policy), + ("handle-pull-v2", SawyerHandlePullV2Policy), + ("handle-pull-side-v2", SawyerHandlePullSideV2Policy), + ("peg-insert-side-v2", SawyerPegInsertionSideV2Policy), + ("lever-pull-v2", SawyerLeverPullV2Policy), + ("peg-unplug-side-v2", SawyerPegUnplugSideV2Policy), + ("pick-out-of-hole-v2", SawyerPickOutOfHoleV2Policy), + ("pick-place-v2", SawyerPickPlaceV2Policy), + ("pick-place-wall-v2", SawyerPickPlaceWallV2Policy), + ("plate-slide-back-side-v2", SawyerPlateSlideBackSideV2Policy), + ("plate-slide-back-v2", SawyerPlateSlideBackV2Policy), + ("plate-slide-side-v2", SawyerPlateSlideSideV2Policy), + ("plate-slide-v2", SawyerPlateSlideV2Policy), + ("reach-v2", SawyerReachV2Policy), + ("reach-wall-v2", SawyerReachWallV2Policy), + ("push-back-v2", SawyerPushBackV2Policy), + ("push-v2", SawyerPushV2Policy), + ("push-wall-v2", SawyerPushWallV2Policy), + ("shelf-place-v2", SawyerShelfPlaceV2Policy), + ("soccer-v2", SawyerSoccerV2Policy), + ("stick-pull-v2", SawyerStickPullV2Policy), + ("stick-push-v2", SawyerStickPushV2Policy), + ("sweep-into-v2", SawyerSweepIntoV2Policy), + ("sweep-v2", SawyerSweepV2Policy), + ("window-close-v2", SawyerWindowCloseV2Policy), + ("window-open-v2", SawyerWindowOpenV2Policy), + ] +) +_env_names = list(_policies) + +classes = { + 'ML45': { + 'train': ['assembly-v2', + 'basketball-v2', + 'button-press-topdown-v2', + 'button-press-topdown-wall-v2', + 'button-press-v2', + 'button-press-wall-v2', + 'coffee-button-v2', + 'coffee-pull-v2', + 'coffee-push-v2', + 'dial-turn-v2', + 'disassemble-v2', + 'door-close-v2', + 'door-open-v2', + 'drawer-close-v2', + 'drawer-open-v2', + 'faucet-close-v2', + 'faucet-open-v2', + 'hammer-v2', + 'handle-press-side-v2', + 'handle-press-v2', + 'handle-pull-side-v2', + 'handle-pull-v2', + 'lever-pull-v2', + 'peg-insert-side-v2', + 'peg-unplug-side-v2', + 'pick-out-of-hole-v2', + 'pick-place-v2', + 'pick-place-wall-v2', + 'plate-slide-back-side-v2', + 'plate-slide-back-v2', + 'plate-slide-side-v2', + 'plate-slide-v2', + 'push-back-v2', + 'push-v2', + 'push-wall-v2', + 'reach-v2', + 'reach-wall-v2', + 'shelf-place-v2', + 'soccer-v2', + 'stick-pull-v2', + 'stick-push-v2', + 'sweep-into-v2', + 'sweep-v2', + 'window-close-v2', + 'window-open-v2'], + 'test': ['bin-picking-v2', + 'box-close-v2', + 'door-lock-v2', + 'door-unlock-v2', + 'hand-insert-v2'] + }, + 'ML45_PRISE': { + 'train': [ + 'assembly-v2', + 'basketball-v2', + 'bin-picking-v2', + 'button-press-topdown-v2', + 'button-press-topdown-wall-v2', + 'button-press-v2', + 'button-press-wall-v2', + 'coffee-button-v2', + 'coffee-pull-v2', + 'coffee-push-v2', + 'dial-turn-v2', + 'door-close-v2', + 'door-lock-v2', + 'door-open-v2', + 'door-unlock-v2', + 'drawer-close-v2', + 'drawer-open-v2', + 'faucet-close-v2', + 'faucet-open-v2', + 'hammer-v2', + 'handle-press-side-v2', + 'handle-press-v2', + 'handle-pull-side-v2', + 'handle-pull-v2', + 'lever-pull-v2', + 'peg-insert-side-v2', + 'peg-unplug-side-v2', + 'pick-out-of-hole-v2', + 'pick-place-v2', + 'plate-slide-back-side-v2', + 'plate-slide-back-v2', + 'plate-slide-side-v2', + 'plate-slide-v2', + 'push-back-v2', + 'push-v2', + 'push-wall-v2', + 'reach-v2', + 'reach-wall-v2', + 'shelf-place-v2', + 'soccer-v2', + 'stick-push-v2', + 'sweep-into-v2', + 'sweep-v2', + 'window-close-v2', + 'window-open-v2'], + 'test': [ + 'box-close-v2', + 'disassemble-v2', + 'hand-insert-v2', + 'pick-place-wall-v2', + 'stick-pull-v2', + ] + + }, + 'MT50': { + 'train': list(_env_names), + 'test': [] + } +} + +def get_index(env_name): + return _env_names.index(env_name) + +def get_expert(): + env_experts = { + env_name: _policies[env_name]() for env_name in _policies + } + + def expert(obs, task_id): + obs_gt = obs['obs_gt'].squeeze() + return env_experts[_env_names[task_id]].get_action(obs_gt) + + return expert + +def get_env_expert(env_name): + return _policies[env_name]() + +def get_benchmark(benchmark_name): + benchmarks = { + 'ML1': metaworld.ML1, + 'ML10': metaworld.ML10, + 'ML45': metaworld.ML45, + 'MT50': metaworld.MT50, + 'ML45_PRISE': ML45PRISEBenchmark, + } + return benchmarks[benchmark_name]() + +def get_env_names(benchmark=None, mode=None): + if benchmark is None: + return list(_env_names) + + if type(benchmark) is str: + return classes[benchmark][mode] + else: + env_names = list(benchmark.train_classes \ + if mode == 'train' else benchmark.test_classes) + env_names.sort() + return env_names + +def get_tasks(benchmark, mode): + if benchmark is None: + return [] + return benchmark.train_tasks if mode == 'train' else benchmark.test_tasks + + +class MetaWorldFrameStack(FrameStackObservationFixed): + def __init__(self, + env_name, + env_factory, + num_stack, + ): + self.num_stack = num_stack + + env = env_factory(env_name) + super().__init__(env, num_stack) + + def set_task(self, task): + self.env.set_task(task) + + +class MetaWorldWrapper(gymnasium.Wrapper): + def __init__(self, + env_name: str, + shape_meta, + img_height: int = 128, + img_width: int = 128, + cameras=('corner2',), + env_kwargs=None,): + if env_kwargs is None: + env_kwargs = {} + env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[f'{env_name}-goal-observable'](**env_kwargs) + env._freeze_rand_vec = False + super().__init__(env) + # I don't know why + self.env.model.cam_pos[2] = [0.75, 0.075, 0.7] + + self.img_width = img_width + self.img_height = img_height + obs_meta = shape_meta['observation'] + self.rgb_outputs = list(obs_meta['rgb']) + self.lowdim_outputs = list(obs_meta['lowdim']) + + self.cameras = cameras + self.viewer = OffScreenViewer( + env.model, + env.data, + img_width, + img_height, + env.mujoco_renderer.max_geom, + env.mujoco_renderer._vopt, + ) + + obs_space_dict = {} + for key in self.rgb_outputs: + obs_space_dict[key] = gymnasium.spaces.Box( + low=0, + high=255, + shape=(img_height, img_width, 3), + dtype=np.uint8 + ) + for key in self.lowdim_outputs: + obs_space_dict[key] = gymnasium.spaces.Box( + low=-np.inf, + high=np.inf, + shape=(obs_meta['lowdim'][key],), + dtype=np.float32 + ) + self.observation_space = gymnasium.spaces.Dict(obs_space_dict) + + def step(self, action): + obs_gt, reward, terminated, truncated, info = super().step(action) + obs_gt = obs_gt.astype(np.float32) + info['obs_gt'] = obs_gt + + next_obs = self.make_obs(obs_gt) + + terminated = info['success'] == 1 + return next_obs, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + obs_gt, info = super().reset() + obs_gt = obs_gt.astype(np.float32) + info['obs_gt'] = obs_gt + + obs = self.make_obs(obs_gt) + + return obs, info + + def make_obs(self, obs_gt): + obs = {} + obs['robot_states'] = np.concatenate((obs_gt[:4],obs_gt[18:22])) + obs['obs_gt'] = obs_gt + + image_dict = {} + for camera_name in self.cameras: + image_obs = self.render(camera_name=camera_name, mode='all') + image_dict[camera_name] = image_obs + for key in self.rgb_outputs: + obs[key] = image_dict[f'{key[:-4]}2'][::-1] # since generated dataset at the time had corner key instead of corner2 + + return obs + + def render(self, camera_name=None, mode='rgb_array'): + if camera_name is None: + camera_name = self.cameras[0] + cam_id = mujoco.mj_name2id(self.env.model, + mujoco.mjtObj.mjOBJ_CAMERA, + camera_name) + + return self.viewer.render( + render_mode=mode, + camera_id=cam_id + ) + + def set_task(self, task): + self.env.set_task(task) + self.env._partially_observable = False + + def seed(self, seed): + self.env.seed(seed) + + def close(self): + self.viewer.close() + + +def build_dataset(data_prefix, + suite_name, + benchmark_name, + mode, + seq_len, + frame_stack, + shape_meta, + extra_obs_modality=None, + obs_seq_len=1, + lowdim_obs_seq_len=None, + load_obs=True, + n_demos=None, + load_next_obs=False, + dataset_keys=('actions',) + ): + task_names = get_env_names(benchmark_name, mode) + n_tasks = len(task_names) + datasets = [] + + obs_modality = { + 'rgb': list(shape_meta['observation']['rgb'].keys()), + 'low_dim': list(shape_meta['observation']['lowdim'].keys()) + } + if extra_obs_modality is not None: + for key in extra_obs_modality: + obs_modality[key] = obs_modality[key] + extra_obs_modality[key] + + ObsUtils.initialize_obs_utils_with_obs_specs({"obs": obs_modality}) + for task_name in task_names: + # currently we assume tasks from same benchmark have the same shape_meta + task_i_dataset = get_task_dataset( + dataset_path=os.path.join( + data_prefix, + suite_name, + benchmark_name, + mode, + f"{task_name}.hdf5" + ), + obs_modality=obs_modality, + seq_len=seq_len, + obs_seq_len=obs_seq_len, + lowdim_obs_seq_len=lowdim_obs_seq_len, + load_obs=load_obs, + frame_stack=frame_stack, + n_demos=n_demos, + load_next_obs=load_next_obs, + dataset_keys=dataset_keys + ) + # loaded_datasets.append(task_i_dataset) + task_id = get_index(task_name) + datasets.append(SequenceVLDataset(task_i_dataset, task_id)) + n_demos = [dataset.n_demos for dataset in datasets] + n_sequences = [dataset.total_num_sequences for dataset in datasets] + concat_dataset = ConcatDataset(datasets) + print("\n=================== Benchmark Information ===================") + print(f" Name: MetaWorld") + print(f" # Tasks: {n_tasks}") + print(" # demonstrations: " + " ".join(f"({x})" for x in n_demos)) + print(" # sequences: " + " ".join(f"({x})" for x in n_sequences)) + print("=======================================================================\n") + + return concat_dataset + + +def get_task_dataset( + dataset_path, + obs_modality, + seq_len=1, + obs_seq_len=1, + lowdim_obs_seq_len=None, + frame_stack=1, + filter_key=None, + hdf5_cache_mode="low_dim", + few_demos=None, + load_obs=True, + n_demos=None, + load_next_obs=False, + dataset_keys=None, +): + all_obs_keys = [] + for modality_name, modality_list in obs_modality.items(): + all_obs_keys += modality_list + shape_meta = FileUtils.get_shape_metadata_from_dataset( + dataset_path=dataset_path, all_obs_keys=all_obs_keys, verbose=False + ) + seq_len = seq_len + filter_key = filter_key + if load_obs: + obs_keys = shape_meta["all_obs_keys"] + else: + obs_keys = [] + + if dataset_keys is None: + dataset_keys = ['actions',] + dataset = SequenceDataset( + hdf5_path=dataset_path, + obs_keys=obs_keys, + dataset_keys=dataset_keys, + load_next_obs=load_next_obs, + frame_stack=frame_stack, + seq_length=seq_len, # length-10 temporal sequences + obs_seq_length=obs_seq_len, + lowdim_obs_seq_length=lowdim_obs_seq_len, + pad_frame_stack=True, + pad_seq_length=True, # pad last obs per trajectory to ensure all sequences are sampled + get_pad_mask=False, + goal_mode=None, + hdf5_cache_mode=hdf5_cache_mode, # cache dataset in memory to avoid repeated file i/o + hdf5_use_swmr=False, + hdf5_normalize_obs=None, + filter_by_attribute=filter_key, # can optionally provide a filter key here + few_demos=few_demos, + n_demos=n_demos, + ) + return dataset + + +class SequenceVLDataset(Dataset): + # Note: task_id should be a string + def __init__(self, sequence_dataset, task_id): + self.sequence_dataset = sequence_dataset + self.task_id = task_id + self.n_demos = self.sequence_dataset.n_demos + self.total_num_sequences = self.sequence_dataset.total_num_sequences + + def __len__(self): + return len(self.sequence_dataset) + + def __getitem__(self, idx): + return_dict = self.sequence_dataset.__getitem__(idx) + return_dict["task_id"] = self.task_id + return return_dict + + +class ML45PRISEBenchmark(object): + def __init__(self): + benchmark = metaworld.ML45() + all_classes = dict(benchmark.train_classes) + all_classes.update(benchmark.test_classes) + self.train_classes = {name: all_classes[name] for name in classes['ML45_PRISE']['train']} + self.test_classes = {name: all_classes[name] for name in classes['ML45_PRISE']['test']} + + self.train_tasks = [] + self.test_tasks = [] + for task in benchmark.train_tasks + benchmark.test_tasks: + if task.env_name in classes['ML45_PRISE']['train']: + self.train_tasks.append(task) + else: + self.test_tasks.append(task) diff --git a/quest/utils/obs_utils.py b/quest/utils/obs_utils.py new file mode 100644 index 0000000..b8dfadc --- /dev/null +++ b/quest/utils/obs_utils.py @@ -0,0 +1,993 @@ +""" +A collection of utilities for working with observation dictionaries and +different kinds of modalities such as images. + +This file is adopted from Robomimic +https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/utils/obs_utils.py +""" +import numpy as np +from copy import deepcopy +from collections import OrderedDict + +import torch +import torch.nn.functional as F + +import quest.utils.tensor_utils as TU + +# MACRO FOR VALID IMAGE CHANNEL SIZES +VALID_IMAGE_CHANNEL_DIMS = {1, 3} # depth, rgb + +# DO NOT MODIFY THIS! +# This keeps track of observation types (modalities) - and is populated on call to @initialize_obs_utils_with_obs_specs. +# This will be a dictionary that maps observation modality (e.g. low_dim, rgb) to a list of observation +# keys under that observation modality. +OBS_MODALITIES_TO_KEYS = None + +# DO NOT MODIFY THIS! +# This keeps track of observation types (modalities) - and is populated on call to @initialize_obs_utils_with_obs_specs. +# This will be a dictionary that maps observation keys to their corresponding observation modality +# (e.g. low_dim, rgb) +OBS_KEYS_TO_MODALITIES = None + +# DO NOT MODIFY THIS +# This holds the default encoder kwargs that will be used if none are passed at runtime for any given network +DEFAULT_ENCODER_KWARGS = None + +# DO NOT MODIFY THIS +# This holds the registered observation modality classes +OBS_MODALITY_CLASSES = {} + +# DO NOT MODIFY THIS +# This global dict stores mapping from observation encoder / randomizer network name to class. +# We keep track of these registries to enable automated class inference at runtime, allowing +# users to simply extend our base encoder / randomizer class and refer to that class in string form +# in their config, without having to manually register their class internally. +# This also future-proofs us for any additional encoder / randomizer classes we would +# like to add ourselves. +OBS_ENCODER_CORES = {"None": None} # Include default None +OBS_RANDOMIZERS = {"None": None} # Include default None + + +def register_obs_key(target_class): + assert target_class not in OBS_MODALITY_CLASSES, f"Already registered modality {target_class}!" + OBS_MODALITY_CLASSES[target_class.name] = target_class + + +def register_encoder_core(target_class): + assert target_class not in OBS_ENCODER_CORES, f"Already registered obs encoder core {target_class}!" + OBS_ENCODER_CORES[target_class.__name__] = target_class + + +def register_randomizer(target_class): + assert target_class not in OBS_RANDOMIZERS, f"Already registered obs randomizer {target_class}!" + OBS_RANDOMIZERS[target_class.__name__] = target_class + + +class ObservationKeyToModalityDict(dict): + """ + Custom dictionary class with the sole additional purpose of automatically registering new "keys" at runtime + without breaking. This is mainly for backwards compatibility, where certain keys such as "latent", "actions", etc. + are used automatically by certain models (e.g.: VAEs) but were never specified by the user externally in their + config. Thus, this dictionary will automatically handle those keys by implicitly associating them with the low_dim + modality. + """ + def __getitem__(self, item): + # If a key doesn't already exist, warn the user and add default mapping + if item not in self.keys(): + print(f"ObservationKeyToModalityDict: {item} not found," + f" adding {item} to mapping with assumed low_dim modality!") + self.__setitem__(item, "low_dim") + return super(ObservationKeyToModalityDict, self).__getitem__(item) + + +def obs_encoder_kwargs_from_config(obs_encoder_config): + """ + Generate a set of args used to create visual backbones for networks + from the observation encoder config. + + Args: + obs_encoder_config (Config): Config object containing relevant encoder information. Should be equivalent to + config.observation.encoder + + Returns: + dict: Processed encoder kwargs + """ + # Loop over each obs modality + # Unlock encoder config + obs_encoder_config.unlock() + for obs_modality, encoder_kwargs in obs_encoder_config.items(): + # First run some sanity checks and store the classes + for cls_name, cores in zip(("core", "obs_randomizer"), (OBS_ENCODER_CORES, OBS_RANDOMIZERS)): + # Make sure the requested encoder for each obs_modality exists + cfg_cls = encoder_kwargs[f"{cls_name}_class"] + if cfg_cls is not None: + assert cfg_cls in cores, f"No {cls_name} class with name {cfg_cls} found, must register this class before" \ + f"creating model!" + # encoder_kwargs[f"{cls_name}_class"] = cores[cfg_cls] + + # Process core and randomizer kwargs + encoder_kwargs.core_kwargs = dict() if encoder_kwargs.core_kwargs is None else \ + deepcopy(encoder_kwargs.core_kwargs) + encoder_kwargs.obs_randomizer_kwargs = dict() if encoder_kwargs.obs_randomizer_kwargs is None else \ + deepcopy(encoder_kwargs.obs_randomizer_kwargs) + + # Re-lock keys + obs_encoder_config.lock() + + return dict(obs_encoder_config) + + +def initialize_obs_modality_mapping_from_dict(modality_mapping): + """ + This function is an alternative to @initialize_obs_utils_with_obs_specs, that allows manually setting of modalities. + NOTE: Only one of these should be called at runtime -- not both! (Note that all training scripts that use a config) + automatically handle obs modality mapping, so using this function is usually unnecessary) + + Args: + modality_mapping (dict): Maps modality string names (e.g.: rgb, low_dim, etc.) to a list of observation + keys that should belong to that modality + """ + global OBS_KEYS_TO_MODALITIES, OBS_MODALITIES_TO_KEYS + + OBS_KEYS_TO_MODALITIES = ObservationKeyToModalityDict() + OBS_MODALITIES_TO_KEYS = dict() + + for mod, keys in modality_mapping.items(): + OBS_MODALITIES_TO_KEYS[mod] = deepcopy(keys) + OBS_KEYS_TO_MODALITIES.update({k: mod for k in keys}) + + +def initialize_obs_utils_with_obs_specs(obs_modality_specs): + """ + This function should be called before using any observation key-specific + functions in this file, in order to make sure that all utility + functions are aware of the observation modalities (e.g. which ones + are low-dimensional, which ones are rgb, etc.). + + It constructs two dictionaries: (1) that map observation modality (e.g. low_dim, rgb) to + a list of observation keys under that modality, and (2) that maps the inverse, specific + observation keys to their corresponding observation modality. + + Input should be a nested dictionary (or list of such dicts) with the following structure: + + obs_variant (str): + obs_modality (str): observation keys (list) + ... + ... + + Example: + { + "obs": { + "low_dim": ["robot0_eef_pos", "robot0_eef_quat"], + "rgb": ["agentview_image", "robot0_eye_in_hand"], + } + "goal": { + "low_dim": ["robot0_eef_pos"], + "rgb": ["agentview_image"] + } + } + + In the example, raw observations consist of low-dim and rgb modalities, with + the robot end effector pose under low-dim, and the agentview and wrist camera + images under rgb, while goal observations also consist of low-dim and rgb modalities, + with a subset of the raw observation keys per modality. + + Args: + obs_modality_specs (dict or list): A nested dictionary (see docstring above for an example) + or a list of nested dictionaries. Accepting a list as input makes it convenient for + situations where multiple modules may each have their own modality spec. + """ + global OBS_KEYS_TO_MODALITIES, OBS_MODALITIES_TO_KEYS + + OBS_KEYS_TO_MODALITIES = ObservationKeyToModalityDict() + + # accept one or more spec dictionaries - if it's just one, account for this + if isinstance(obs_modality_specs, dict): + obs_modality_spec_list = [obs_modality_specs] + else: + obs_modality_spec_list = obs_modality_specs + + # iterates over observation specs + obs_modality_mapping = {} + for obs_modality_spec in obs_modality_spec_list: + # iterates over observation variants (e.g. observations, goals, subgoals) + for obs_modalities in obs_modality_spec.values(): + for obs_modality, obs_keys in obs_modalities.items(): + # add all keys for each obs modality to the corresponding list in obs_modality_mapping + if obs_modality not in obs_modality_mapping: + obs_modality_mapping[obs_modality] = [] + obs_modality_mapping[obs_modality] += obs_keys + # loop over each modality, and add to global dict if it doesn't exist yet + for obs_key in obs_keys: + if obs_key not in OBS_KEYS_TO_MODALITIES: + OBS_KEYS_TO_MODALITIES[obs_key] = obs_modality + # otherwise, run sanity check to make sure we don't have conflicting, duplicate entries + else: + assert OBS_KEYS_TO_MODALITIES[obs_key] == obs_modality, \ + f"Cannot register obs key {obs_key} with modality {obs_modality}; " \ + f"already exists with corresponding modality {OBS_KEYS_TO_MODALITIES[obs_key]}" + + # remove duplicate entries and store in global mapping + OBS_MODALITIES_TO_KEYS = { obs_modality : list(set(obs_modality_mapping[obs_modality])) for obs_modality in obs_modality_mapping } + + print("\n============= Initialized Observation Utils with Obs Spec =============\n") + for obs_modality, obs_keys in OBS_MODALITIES_TO_KEYS.items(): + print("using obs modality: {} with keys: {}".format(obs_modality, obs_keys)) + + +def initialize_default_obs_encoder(obs_encoder_config): + """ + Initializes the default observation encoder kwarg information to be used by all networks if no values are manually + specified at runtime. + + Args: + obs_encoder_config (Config): Observation encoder config to use. + Should be equivalent to config.observation.encoder + """ + global DEFAULT_ENCODER_KWARGS + DEFAULT_ENCODER_KWARGS = obs_encoder_kwargs_from_config(obs_encoder_config) + + +def initialize_obs_utils_with_config(config): + """ + Utility function to parse config and call @initialize_obs_utils_with_obs_specs and + @initialize_default_obs_encoder_kwargs with the correct arguments. + + Args: + config (BaseConfig instance): config object + """ + if config.algo_name == "hbc": + obs_modality_specs = [ + config.observation.planner.modalities, + config.observation.actor.modalities, + ] + obs_encoder_config = config.observation.actor.encoder + elif config.algo_name == "iris": + obs_modality_specs = [ + config.observation.value_planner.planner.modalities, + config.observation.value_planner.value.modalities, + config.observation.actor.modalities, + ] + obs_encoder_config = config.observation.actor.encoder + else: + obs_modality_specs = [config.observation.modalities] + obs_encoder_config = config.observation.encoder + initialize_obs_utils_with_obs_specs(obs_modality_specs=obs_modality_specs) + initialize_default_obs_encoder(obs_encoder_config=obs_encoder_config) + + +def key_is_obs_modality(key, obs_modality): + """ + Check if observation key corresponds to modality @obs_modality. + + Args: + key (str): obs key name to check + obs_modality (str): observation modality - e.g.: "low_dim", "rgb" + """ + assert OBS_KEYS_TO_MODALITIES is not None, "error: must call ObsUtils.initialize_obs_utils_with_obs_config first" + return OBS_KEYS_TO_MODALITIES[key] == obs_modality + + +def center_crop(im, t_h, t_w): + """ + Takes a center crop of an image. + + Args: + im (np.array or torch.Tensor): image of shape (..., height, width, channel) + t_h (int): height of crop + t_w (int): width of crop + + Returns: + im (np.array or torch.Tensor): center cropped image + """ + assert(im.shape[-3] >= t_h and im.shape[-2] >= t_w) + assert(im.shape[-1] in [1, 3]) + crop_h = int((im.shape[-3] - t_h) / 2) + crop_w = int((im.shape[-2] - t_w) / 2) + return im[..., crop_h:crop_h + t_h, crop_w:crop_w + t_w, :] + + +def batch_image_hwc_to_chw(im): + """ + Channel swap for images - useful for preparing images for + torch training. + + Args: + im (np.array or torch.Tensor): image of shape (batch, height, width, channel) + or (height, width, channel) + + Returns: + im (np.array or torch.Tensor): image of shape (batch, channel, height, width) + or (channel, height, width) + """ + start_dims = np.arange(len(im.shape) - 3).tolist() + s = start_dims[-1] if len(start_dims) > 0 else -1 + if isinstance(im, np.ndarray): + return im.transpose(start_dims + [s + 3, s + 1, s + 2]) + else: + return im.permute(start_dims + [s + 3, s + 1, s + 2]) + + +def batch_image_chw_to_hwc(im): + """ + Inverse of channel swap in @batch_image_hwc_to_chw. + + Args: + im (np.array or torch.Tensor): image of shape (batch, channel, height, width) + or (channel, height, width) + + Returns: + im (np.array or torch.Tensor): image of shape (batch, height, width, channel) + or (height, width, channel) + """ + start_dims = np.arange(len(im.shape) - 3).tolist() + s = start_dims[-1] if len(start_dims) > 0 else -1 + if isinstance(im, np.ndarray): + return im.transpose(start_dims + [s + 2, s + 3, s + 1]) + else: + return im.permute(start_dims + [s + 2, s + 3, s + 1]) + + +def process_obs(obs, obs_modality=None, obs_key=None): + """ + Process observation @obs corresponding to @obs_modality modality (or implicitly inferred from @obs_key) + to prepare for network input. + + Note that either obs_modality OR obs_key must be specified! + + If both are specified, obs_key will override obs_modality + + Args: + obs (np.array or torch.Tensor): Observation to process. Leading batch dimension is optional + obs_modality (str): Observation modality (e.g.: depth, image, low_dim, etc.) + obs_key (str): Name of observation from which to infer @obs_modality + + Returns: + processed_obs (np.array or torch.Tensor): processed observation + """ + assert obs_modality is not None or obs_key is not None, "Either obs_modality or obs_key must be specified!" + if obs_key is not None: + obs_modality = OBS_KEYS_TO_MODALITIES[obs_key] + return OBS_MODALITY_CLASSES[obs_modality].process_obs(obs) + + +def process_obs_dict(obs_dict): + """ + Process observations in observation dictionary to prepare for network input. + + Args: + obs_dict (dict): dictionary mapping observation keys to np.array or + torch.Tensor. Leading batch dimensions are optional. + + Returns: + new_dict (dict): dictionary where observation keys have been processed by their corresponding processors + """ + return { k : process_obs(obs=obs, obs_key=k) for k, obs in obs_dict.items() } # shallow copy + + +def process_frame(frame, channel_dim, scale=None): + """ + Given frame fetched from dataset, process for network input. Converts array + to float (from uint8), normalizes pixels from range [0, @scale] to [0, 1], and channel swaps + from (H, W, C) to (C, H, W). + + Args: + frame (np.array or torch.Tensor): frame array + channel_dim (int): Number of channels to sanity check for + scale (float or None): Value to normalize inputs by + + Returns: + processed_frame (np.array or torch.Tensor): processed frame + """ + # Channel size should either be 3 (RGB) or 1 (depth) + assert (frame.shape[-1] == channel_dim) + + # frame = TU.to_float(frame) + # if scale is not None: + # frame = frame / scale + # frame = frame.clip(0.0, 1.0) + frame = batch_image_hwc_to_chw(frame) + + return frame + + +def unprocess_obs(obs, obs_modality=None, obs_key=None): + """ + Prepare observation @obs corresponding to @obs_modality modality (or implicitly inferred from @obs_key) + to prepare for deployment. + + Note that either obs_modality OR obs_key must be specified! + + If both are specified, obs_key will override obs_modality + + Args: + obs (np.array or torch.Tensor): Observation to unprocess. Leading batch dimension is optional + obs_modality (str): Observation modality (e.g.: depth, image, low_dim, etc.) + obs_key (str): Name of observation from which to infer @obs_modality + + Returns: + unprocessed_obs (np.array or torch.Tensor): unprocessed observation + """ + assert obs_modality is not None or obs_key is not None, "Either obs_modality or obs_key must be specified!" + if obs_key is not None: + obs_modality = OBS_KEYS_TO_MODALITIES[obs_key] + return OBS_MODALITY_CLASSES[obs_modality].unprocess_obs(obs) + + +def unprocess_obs_dict(obs_dict): + """ + Prepare processed observation dictionary for saving to dataset. Inverse of + @process_obs. + + Args: + obs_dict (dict): dictionary mapping observation keys to np.array or + torch.Tensor. Leading batch dimensions are optional. + + Returns: + new_dict (dict): dictionary where observation keys have been unprocessed by + their respective unprocessor methods + """ + return { k : unprocess_obs(obs=obs, obs_key=k) for k, obs in obs_dict.items() } # shallow copy + + +def unprocess_frame(frame, channel_dim, scale): + """ + Given frame prepared for network input, prepare for saving to dataset. + Inverse of @process_frame. + + Args: + frame (np.array or torch.Tensor): frame array + channel_dim (int): What channel dimension should be (used for sanity check) + scale (float or None): Scaling factor to apply during denormalization + + Returns: + unprocessed_frame (np.array or torch.Tensor): frame passed through + inverse operation of @process_frame + """ + assert frame.shape[-3] == channel_dim # check for channel dimension + frame = batch_image_chw_to_hwc(frame) + if scale is not None: + frame = scale * frame + return frame + + +def get_processed_shape(obs_modality, input_shape): + """ + Given observation modality @obs_modality and expected inputs of shape @input_shape (excluding batch dimension), return the + expected processed observation shape resulting from process_{obs_modality}. + + Args: + obs_modality (str): Observation modality to use (e.g.: low_dim, rgb, depth, etc...) + input_shape (list of int): Expected input dimensions, excluding the batch dimension + + Returns: + list of int: expected processed input shape + """ + return list(process_obs(obs=np.zeros(input_shape), obs_modality=obs_modality).shape) + + +def normalize_obs(obs_dict, obs_normalization_stats): + """ + Normalize observations using the provided "mean" and "std" entries + for each observation key. The observation dictionary will be + modified in-place. + + Args: + obs_dict (dict): dictionary mapping observation key to np.array or + torch.Tensor. Can have any number of leading batch dimensions. + + obs_normalization_stats (dict): this should map observation keys to dicts + with a "mean" and "std" of shape (1, ...) where ... is the default + shape for the observation. + + Returns: + obs_dict (dict): obs dict with normalized observation arrays + """ + + # ensure we have statistics for each modality key in the observation + assert set(obs_dict.keys()).issubset(obs_normalization_stats) + + for m in obs_dict: + # get rid of extra dimension - we will pad for broadcasting later + mean = obs_normalization_stats[m]["mean"][0] + std = obs_normalization_stats[m]["std"][0] + + # shape consistency checks + m_num_dims = len(mean.shape) + shape_len_diff = len(obs_dict[m].shape) - m_num_dims + assert shape_len_diff >= 0, "shape length mismatch in @normalize_obs" + assert obs_dict[m].shape[-m_num_dims:] == mean.shape, "shape mismatch in @normalize_obs" + + # Obs can have one or more leading batch dims - prepare for broadcasting. + # + # As an example, if the obs has shape [B, T, D] and our mean / std stats are shape [D] + # then we should pad the stats to shape [1, 1, D]. + reshape_padding = tuple([1] * shape_len_diff) + mean = mean.reshape(reshape_padding + tuple(mean.shape)) + std = std.reshape(reshape_padding + tuple(std.shape)) + + obs_dict[m] = (obs_dict[m] - mean) / std + + return obs_dict + + +def has_modality(modality, obs_keys): + """ + Returns True if @modality is present in the list of observation keys @obs_keys. + + Args: + modality (str): modality to check for, e.g.: rgb, depth, etc. + obs_keys (list): list of observation keys + """ + for k in obs_keys: + if key_is_obs_modality(k, obs_modality=modality): + return True + return False + + +def repeat_and_stack_observation(obs_dict, n): + """ + Given an observation dictionary and a desired repeat value @n, + this function will return a new observation dictionary where + each modality is repeated @n times and the copies are + stacked in the first dimension. + + For example, if a batch of 3 observations comes in, and n is 2, + the output will look like [ob1; ob1; ob2; ob2; ob3; ob3] in + each modality. + + Args: + obs_dict (dict): dictionary mapping observation key to np.array or + torch.Tensor. Leading batch dimensions are optional. + + n (int): number to repeat by + + Returns: + repeat_obs_dict (dict): repeated obs dict + """ + return TU.repeat_by_expand_at(obs_dict, repeats=n, dim=0) + + +def crop_image_from_indices(images, crop_indices, crop_height, crop_width): + """ + Crops images at the locations specified by @crop_indices. Crops will be + taken across all channels. + + Args: + images (torch.Tensor): batch of images of shape [..., C, H, W] + + crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where + N is the number of crops to take per image and each entry corresponds + to the pixel height and width of where to take the crop. Note that + the indices can also be of shape [..., 2] if only 1 crop should + be taken per image. Leading dimensions must be consistent with + @images argument. Each index specifies the top left of the crop. + Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where + H and W are the height and width of @images and CH and CW are + @crop_height and @crop_width. + + crop_height (int): height of crop to take + + crop_width (int): width of crop to take + + Returns: + crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width] + """ + + # make sure length of input shapes is consistent + assert crop_indices.shape[-1] == 2 + ndim_im_shape = len(images.shape) + ndim_indices_shape = len(crop_indices.shape) + assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2) + + # maybe pad so that @crop_indices is shape [..., N, 2] + is_padded = False + if ndim_im_shape == ndim_indices_shape + 2: + crop_indices = crop_indices.unsqueeze(-2) + is_padded = True + + # make sure leading dimensions between images and indices are consistent + assert images.shape[:-3] == crop_indices.shape[:-2] + + device = images.device + image_c, image_h, image_w = images.shape[-3:] + num_crops = crop_indices.shape[-2] + + # make sure @crop_indices are in valid range + assert (crop_indices[..., 0] >= 0).all().item() + assert (crop_indices[..., 0] < (image_h - crop_height)).all().item() + assert (crop_indices[..., 1] >= 0).all().item() + assert (crop_indices[..., 1] < (image_w - crop_width)).all().item() + + # convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window. + + # 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW] + crop_ind_grid_h = torch.arange(crop_height, device=device) + crop_ind_grid_h = TU.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1) + # 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW] + crop_ind_grid_w = torch.arange(crop_width, device=device) + crop_ind_grid_w = TU.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0) + # combine into shape [CH, CW, 2] + crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1) + + # Add above grid with the offset index of each sampled crop to get 2d indices for each crop. + # After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2] + # shape array that tells us which pixels from the corresponding source image to grab. + grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2] + all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape) + + # For using @torch.gather, convert to flat indices from 2D indices, and also + # repeat across the channel dimension. To get flat index of each pixel to grab for + # each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind + all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1] # shape [..., N, CH, CW] + all_crop_inds = TU.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3) # shape [..., N, C, CH, CW] + all_crop_inds = TU.flatten(all_crop_inds, begin_axis=-2) # shape [..., N, C, CH * CW] + + # Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds + images_to_crop = TU.unsqueeze_expand_at(images, size=num_crops, dim=-4) + images_to_crop = TU.flatten(images_to_crop, begin_axis=-2) + crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds) + # [..., N, C, CH * CW] -> [..., N, C, CH, CW] + reshape_axis = len(crops.shape) - 1 + crops = TU.reshape_dimensions(crops, begin_axis=reshape_axis, end_axis=reshape_axis, + target_dims=(crop_height, crop_width)) + + if is_padded: + # undo padding -> [..., C, CH, CW] + crops = crops.squeeze(-4) + return crops + + +def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False): + """ + For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from + @images. + + Args: + images (torch.Tensor): batch of images of shape [..., C, H, W] + + crop_height (int): height of crop to take + + crop_width (int): width of crop to take + + num_crops (n): number of crops to sample + + pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial + encoding of the original source pixel locations. This means that the + output crops will contain information about where in the source image + it was sampled from. + + Returns: + crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width) + if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width) + + crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2) + """ + device = images.device + # maybe add 2 channels of spatial encoding to the source image + source_im = images + if pos_enc: + # spatial encoding [y, x] in [0, 1] + h, w = source_im.shape[-2:] + pos_y, pos_x = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing='ij') + pos_y = pos_y.float() / float(h) + pos_x = pos_x.float() / float(w) + position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W] + + # unsqueeze and expand to match leading dimensions -> shape [..., C, H, W] + leading_shape = source_im.shape[:-3] + position_enc = position_enc[(None,) * len(leading_shape)] + position_enc = position_enc.expand(*leading_shape, -1, -1, -1) + + # concat across channel dimension with input + source_im = torch.cat((source_im, position_enc), dim=-3) + + # make sure sample boundaries ensure crops are fully within the images + image_c, image_h, image_w = source_im.shape[-3:] + max_sample_h = image_h - crop_height + max_sample_w = image_w - crop_width + + # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W]. + # Each gets @num_crops samples - typically this will just be the batch dimension (B), so + # we will sample [B, N] indices, but this supports having more than one leading dimension, + # or possibly no leading dimension. + # + # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints + crop_inds_h = (max_sample_h * torch.rand((*source_im.shape[:-3], num_crops), device=device)).long() + crop_inds_w = (max_sample_w * torch.rand((*source_im.shape[:-3], num_crops), device=device)).long() + crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2] + + crops = crop_image_from_indices( + images=source_im, + crop_indices=crop_inds, + crop_height=crop_height, + crop_width=crop_width, + ) + + return crops, crop_inds + + +class Modality: + """ + Observation Modality class to encapsulate necessary functions needed to + process observations of this modality + """ + # observation keys to associate with this modality + keys = set() + + # Custom processing function that should prepare raw observations of this modality for training + _custom_obs_processor = None + + # Custom unprocessing function that should prepare observations of this modality used during training for deployment + _custom_obs_unprocessor = None + + # Name of this modality -- must be set by subclass! + name = None + + def __init_subclass__(cls, **kwargs): + """ + Hook method to automatically register all valid subclasses so we can keep track of valid modalities + """ + assert cls.name is not None, f"Name of modality {cls.__name__} must be specified!" + register_obs_key(cls) + + @classmethod + def set_keys(cls, keys): + """ + Sets the observation keys associated with this modality. + + Args: + keys (list or set): observation keys to associate with this modality + """ + cls.keys = {k for k in keys} + + @classmethod + def add_keys(cls, keys): + """ + Adds the observation @keys associated with this modality to the current set of keys. + + Args: + keys (list or set): observation keys to add to associate with this modality + """ + for key in keys: + cls.keys.add(key) + + @classmethod + def set_obs_processor(cls, processor=None): + """ + Sets the processor for this observation modality. If @processor is set to None, then + the obs processor will use the default one (self.process_obs(...)). Otherwise, @processor + should be a function to process this corresponding observation modality. + + Args: + processor (function or None): If not None, should be function that takes in either a + np.array or torch.Tensor and output the processed array / tensor. If None, will reset + to the default processor (self.process_obs(...)) + """ + cls._custom_obs_processor = processor + + @classmethod + def set_obs_unprocessor(cls, unprocessor=None): + """ + Sets the unprocessor for this observation modality. If @unprocessor is set to None, then + the obs unprocessor will use the default one (self.unprocess_obs(...)). Otherwise, @unprocessor + should be a function to process this corresponding observation modality. + + Args: + unprocessor (function or None): If not None, should be function that takes in either a + np.array or torch.Tensor and output the unprocessed array / tensor. If None, will reset + to the default unprocessor (self.unprocess_obs(...)) + """ + cls._custom_obs_unprocessor = unprocessor + + @classmethod + def _default_obs_processor(cls, obs): + """ + Default processing function for this obs modality. + + Note that this function is overridden by self.custom_obs_processor (a function with identical inputs / outputs) + if it is not None. + + Args: + obs (np.array or torch.Tensor): raw observation, which may include a leading batch dimension + + Returns: + np.array or torch.Tensor: processed observation + """ + raise NotImplementedError + + @classmethod + def _default_obs_unprocessor(cls, obs): + """ + Default unprocessing function for this obs modality. + + Note that this function is overridden by self.custom_obs_unprocessor + (a function with identical inputs / outputs) if it is not None. + + Args: + obs (np.array or torch.Tensor): processed observation, which may include a leading batch dimension + + Returns: + np.array or torch.Tensor: unprocessed observation + """ + raise NotImplementedError + + @classmethod + def process_obs(cls, obs): + """ + Prepares an observation @obs of this modality for network input. + + Args: + obs (np.array or torch.Tensor): raw observation, which may include a leading batch dimension + + Returns: + np.array or torch.Tensor: processed observation + """ + processor = cls._custom_obs_processor if \ + cls._custom_obs_processor is not None else cls._default_obs_processor + return processor(obs) + + @classmethod + def unprocess_obs(cls, obs): + """ + Prepares an observation @obs of this modality for deployment. + + Args: + obs (np.array or torch.Tensor): processed observation, which may include a leading batch dimension + + Returns: + np.array or torch.Tensor: unprocessed observation + """ + unprocessor = cls._custom_obs_unprocessor if \ + cls._custom_obs_unprocessor is not None else cls._default_obs_unprocessor + return unprocessor(obs) + + @classmethod + def process_obs_from_dict(cls, obs_dict, inplace=True): + """ + Receives a dictionary of keyword mapped observations @obs_dict, and processes the observations with keys + corresponding to this modality. A copy will be made of the received dictionary unless @inplace is True + + Args: + obs_dict (dict): Dictionary mapping observation keys to observations + inplace (bool): If True, will modify @obs_dict in place, otherwise, will create a copy + + Returns: + dict: observation dictionary with processed observations corresponding to this modality + """ + if inplace: + obs_dict = deepcopy(obs_dict) + # Loop over all keys and process the ones corresponding to this modality + for key, obs in obs_dict.values(): + if key in cls.keys: + obs_dict[key] = cls.process_obs(obs) + + return obs_dict + + +class ImageModality(Modality): + """ + Modality for RGB image observations + """ + name = "rgb" + + @classmethod + def _default_obs_processor(cls, obs): + """ + Given image fetched from dataset, process for network input. Converts array + to float (from uint8), normalizes pixels from range [0, 255] to [0, 1], and channel swaps + from (H, W, C) to (C, H, W). + + Args: + obs (np.array or torch.Tensor): image array + + Returns: + processed_obs (np.array or torch.Tensor): processed image + """ + return process_frame(frame=obs, channel_dim=3, scale=255.) + + @classmethod + def _default_obs_unprocessor(cls, obs): + """ + Given image prepared for network input, prepare for saving to dataset. + Inverse of @process_frame. + + Args: + obs (np.array or torch.Tensor): image array + + Returns: + unprocessed_obs (np.array or torch.Tensor): image passed through + inverse operation of @process_frame + """ + return TU.to_uint8(unprocess_frame(frame=obs, channel_dim=3, scale=255.)) + + +class DepthModality(Modality): + """ + Modality for depth observations + """ + name = "depth" + + @classmethod + def _default_obs_processor(cls, obs): + """ + Given depth fetched from dataset, process for network input. Converts array + to float (from uint8), normalizes pixels from range [0, 1] to [0, 1], and channel swaps + from (H, W, C) to (C, H, W). + + Args: + obs (np.array or torch.Tensor): depth array + + Returns: + processed_obs (np.array or torch.Tensor): processed depth + """ + return process_frame(frame=obs, channel_dim=1, scale=1.) + + @classmethod + def _default_obs_unprocessor(cls, obs): + """ + Given depth prepared for network input, prepare for saving to dataset. + Inverse of @process_depth. + + Args: + obs (np.array or torch.Tensor): depth array + + Returns: + unprocessed_obs (np.array or torch.Tensor): depth passed through + inverse operation of @process_depth + """ + return unprocess_frame(frame=obs, channel_dim=1, scale=1.) + + +class ScanModality(Modality): + """ + Modality for scan observations + """ + name = "scan" + + @classmethod + def _default_obs_processor(cls, obs): + # Channel swaps ([...,] L, C) --> ([...,] C, L) + + # First, add extra dimension at 2nd to last index to treat this as a frame + shape = obs.shape + new_shape = [*shape[:-2], 1, *shape[-2:]] + obs = obs.reshape(new_shape) + + # Convert shape + obs = batch_image_hwc_to_chw(obs) + + # Remove extra dimension (it's the second from last dimension) + obs = obs.squeeze(-2) + return obs + + @classmethod + def _default_obs_unprocessor(cls, obs): + # Channel swaps ([B,] C, L) --> ([B,] L, C) + + # First, add extra dimension at 1st index to treat this as a frame + shape = obs.shape + new_shape = [*shape[:-2], 1, *shape[-2:]] + obs = obs.reshape(new_shape) + + # Convert shape + obs = batch_image_chw_to_hwc(obs) + + # Remove extra dimension (it's the second from last dimension) + obs = obs.squeeze(-2) + return obs + + +class LowDimModality(Modality): + """ + Modality for low dimensional observations + """ + name = "low_dim" + + @classmethod + def _default_obs_processor(cls, obs): + return obs + + @classmethod + def _default_obs_unprocessor(cls, obs): + return obs diff --git a/quest/utils/tensor_utils.py b/quest/utils/tensor_utils.py new file mode 100644 index 0000000..3a35a15 --- /dev/null +++ b/quest/utils/tensor_utils.py @@ -0,0 +1,1054 @@ +""" +A collection of utilities for working with nested tensor structures consisting +of numpy arrays and torch tensors. + +This file is adopted from Robomimic +https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/utils/tensor_utils.py +""" +import collections +import numpy as np +import torch +import torch.nn as nn + + +def separate_no_decay(module, + name_blacklist=None, + blacklist_weight_modules = ( + nn.LayerNorm, + nn.Embedding, + nn.BatchNorm2d, + nn.GroupNorm)): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + + Modified from VQBeT codebase so that you don't need to provide a whitelist + https://github.com/jayLEE0301/vq_bet_official/blob/09d4851288ca5deaaa1ab367a208e520f8ee9a84/vq_behavior_transformer/gpt.py#L230 + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + if name_blacklist is None: + name_blacklist = [] + + whitelist_classes = set() + for mn, m in module.named_modules(): + # This skips modules whose names include words from the name blacklist + bl = False + for name in name_blacklist: + if name in mn: + bl = True + # print(name) + break + if bl: + continue + + for pn, p in m.named_parameters(): + fpn = f"{mn}.{pn}" if mn else pn # full param name + if '.' in pn: + break + + if pn.endswith("bias"): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + elif pn.endswith("weight"): + whitelist_classes.add(type(m)) + decay.add(fpn) + else: + no_decay.add(fpn) + + # validate that we considered every parameter + if len(name_blacklist) > 0: + old_param_dict = {pn: p for pn, p in module.named_parameters()} + param_dict = {} + for pn, p in old_param_dict.items(): + bl = False + for name in name_blacklist: + if name in pn: + bl = True + if not bl: + param_dict[pn] = p + else: + param_dict = {pn: p for pn, p in module.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + + # fuck = print + # pain = list(param_dict.keys()) + # pain.sort() + # for me in pain: + # fuck(me) + + assert len(inter_params) == 0, ( + "parameters %s made it into both decay/no_decay sets!" + % (str(inter_params),) + ) + assert len(param_dict.keys() - union_params) == 0, ( + "parameters %s were not separated into either decay/no_decay set!" + % (str(param_dict.keys() - union_params),) + ) + + # breakpoint() + decay = [param_dict[pn] for pn in sorted(list(decay))] + no_decay = [param_dict[pn] for pn in sorted(list(no_decay))] + + return decay, no_decay + + +def recursive_dict_list_tuple_apply(x, type_func_dict): + """ + Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of + {data_type: function_to_apply}. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + type_func_dict (dict): a mapping from data types to the functions to be + applied for each data type. + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + assert(list not in type_func_dict) + assert(tuple not in type_func_dict) + assert(dict not in type_func_dict) + + if isinstance(x, (dict, collections.OrderedDict)): + new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else dict() + for k, v in x.items(): + new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict) + return new_x + elif isinstance(x, (list, tuple)): + ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x] + if isinstance(x, tuple): + ret = tuple(ret) + return ret + else: + for t, f in type_func_dict.items(): + if isinstance(x, t): + return f(x) + else: + raise NotImplementedError( + 'Cannot handle data type %s' % str(type(x))) + + +def map_tensor(x, func): + """ + Apply function @func to torch.Tensor objects in a nested dictionary or + list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + func (function): function to apply to each tensor + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: func, + type(None): lambda x: x, + } + ) + + +def map_ndarray(x, func): + """ + Apply function @func to np.ndarray objects in a nested dictionary or + list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + func (function): function to apply to each array + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + np.ndarray: func, + type(None): lambda x: x, + } + ) + + +def map_tensor_ndarray(x, tensor_func, ndarray_func): + """ + Apply function @tensor_func to torch.Tensor objects and @ndarray_func to + np.ndarray objects in a nested dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + tensor_func (function): function to apply to each tensor + ndarray_Func (function): function to apply to each array + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: tensor_func, + np.ndarray: ndarray_func, + type(None): lambda x: x, + } + ) + + +def clone(x): + """ + Clones all torch tensors and numpy arrays in nested dictionary or list + or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.clone(), + np.ndarray: lambda x: x.copy(), + type(None): lambda x: x, + } + ) + + +def detach(x): + """ + Detaches all torch tensors in nested dictionary or list + or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.detach(), + } + ) + + +def to_batch(x): + """ + Introduces a leading batch dimension of 1 for all torch tensors and numpy + arrays in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x[None, ...], + np.ndarray: lambda x: x[None, ...], + type(None): lambda x: x, + } + ) + + +def to_sequence(x): + """ + Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy + arrays in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x[:, None, ...], + np.ndarray: lambda x: x[:, None, ...], + type(None): lambda x: x, + } + ) + + +def index_at_time(x, ind): + """ + Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in + nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + ind (int): index + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x[:, ind, ...], + np.ndarray: lambda x: x[:, ind, ...], + type(None): lambda x: x, + } + ) + + +def unsqueeze(x, dim): + """ + Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays + in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + dim (int): dimension + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.unsqueeze(dim=dim), + np.ndarray: lambda x: np.expand_dims(x, axis=dim), + type(None): lambda x: x, + } + ) + + +def contiguous(x): + """ + Makes all torch tensors and numpy arrays contiguous in nested dictionary or + list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.contiguous(), + np.ndarray: lambda x: np.ascontiguousarray(x), + type(None): lambda x: x, + } + ) + + +def to_device(x, device): + """ + Sends all torch tensors in nested dictionary or list or tuple to device + @device, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + device (torch.Device): device to send tensors to + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, d=device: x.to(d), + type(None): lambda x: x, + } + ) + + +def to_tensor(x): + """ + Converts all numpy arrays in nested dictionary or list or tuple to + torch tensors (and leaves existing torch Tensors as-is), and returns + a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x, + np.ndarray: lambda x: torch.from_numpy(x), + type(None): lambda x: x, + } + ) + + +def to_numpy(x): + """ + Converts all torch tensors in nested dictionary or list or tuple to + numpy (and leaves existing numpy arrays as-is), and returns + a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + def f(tensor): + if tensor.is_cuda: + return tensor.detach().cpu().numpy() + else: + return tensor.detach().numpy() + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: f, + np.ndarray: lambda x: x, + type(None): lambda x: x, + } + ) + + +def to_list(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to a list, and returns a new nested structure. Useful for + json encoding. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + def f(tensor): + if tensor.is_cuda: + return tensor.detach().cpu().numpy().tolist() + else: + return tensor.detach().numpy().tolist() + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: f, + np.ndarray: lambda x: x.tolist(), + type(None): lambda x: x, + } + ) + + +def to_float(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to float type entries, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.float(), + np.ndarray: lambda x: x.astype(np.float32), + type(None): lambda x: x, + } + ) + + +def to_uint8(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to uint8 type entries, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.byte(), + np.ndarray: lambda x: x.astype(np.uint8), + type(None): lambda x: x, + } + ) + + +def to_torch(x, device): + """ + Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to + torch tensors on device @device and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + device (torch.Device): device to send tensors to + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return to_device(to_float(to_tensor(x)), device) + + +def to_one_hot_single(tensor, num_class): + """ + Convert tensor to one-hot representation, assuming a certain number of total class labels. + + Args: + tensor (torch.Tensor): tensor containing integer labels + num_class (int): number of classes + + Returns: + x (torch.Tensor): tensor containing one-hot representation of labels + """ + x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device) + x.scatter_(-1, tensor.unsqueeze(-1), 1) + return x + + +def to_one_hot(tensor, num_class): + """ + Convert all tensors in nested dictionary or list or tuple to one-hot representation, + assuming a certain number of total class labels. + + Args: + tensor (dict or list or tuple): a possibly nested dictionary or list or tuple + num_class (int): number of classes + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc)) + + +def flatten_single(x, begin_axis=1): + """ + Flatten a tensor in all dimensions from @begin_axis onwards. + + Args: + x (torch.Tensor): tensor to flatten + begin_axis (int): which axis to flatten from + + Returns: + y (torch.Tensor): flattened tensor + """ + fixed_size = x.size()[:begin_axis] + _s = list(fixed_size) + [-1] + return x.reshape(*_s) + + +def flatten(x, begin_axis=1): + """ + Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): which axis to flatten from + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b), + } + ) + + +def reshape_dimensions_single(x, begin_axis, end_axis, target_dims): + """ + Reshape selected dimensions in a tensor to a target dimension. + + Args: + x (torch.Tensor): tensor to reshape + begin_axis (int): begin dimension + end_axis (int): end dimension (inclusive) + target_dims (tuple or list): target shape for the range of dimensions + (@begin_axis, @end_axis) + + Returns: + y (torch.Tensor): reshaped tensor + """ + assert(begin_axis <= end_axis) + assert(begin_axis >= 0) + assert(end_axis < len(x.shape)) + assert(isinstance(target_dims, (tuple, list))) + s = x.shape + final_s = [] + for i in range(len(s)): + if i == begin_axis: + final_s.extend(target_dims) + elif i < begin_axis or i > end_axis: + final_s.append(s[i]) + return x.reshape(*final_s) + + +def reshape_dimensions(x, begin_axis, end_axis, target_dims): + """ + Reshape selected dimensions for all tensors in nested dictionary or list or tuple + to a target dimension. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): begin dimension + end_axis (int): end dimension (inclusive) + target_dims (tuple or list): target shape for the range of dimensions + (@begin_axis, @end_axis) + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=t), + np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=t), + type(None): lambda x: x, + } + ) + + +def join_dimensions(x, begin_axis, end_axis): + """ + Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for + all tensors in nested dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): begin dimension + end_axis (int): end dimension + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=[-1]), + np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=[-1]), + type(None): lambda x: x, + } + ) + + +def expand_at_single(x, size, dim): + """ + Expand a tensor at a single dimension @dim by @size + + Args: + x (torch.Tensor): input tensor + size (int): size to expand + dim (int): dimension to expand + + Returns: + y (torch.Tensor): expanded tensor + """ + assert dim < x.ndimension() + assert x.shape[dim] == 1 + expand_dims = [-1] * x.ndimension() + expand_dims[dim] = size + return x.expand(*expand_dims) + + +def expand_at(x, size, dim): + """ + Expand all tensors in nested dictionary or list or tuple at a single + dimension @dim by @size. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size to expand + dim (int): dimension to expand + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d)) + + +def unsqueeze_expand_at(x, size, dim): + """ + Unsqueeze and expand a tensor at a dimension @dim by @size. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size to expand + dim (int): dimension to unsqueeze and expand + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + x = unsqueeze(x, dim) + return expand_at(x, size, dim) + + +def repeat_by_expand_at(x, repeats, dim): + """ + Repeat a dimension by combining expand and reshape operations. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + repeats (int): number of times to repeat the target dimension + dim (int): dimension to repeat on + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + x = unsqueeze_expand_at(x, repeats, dim + 1) + return join_dimensions(x, dim, dim + 1) + + +def named_reduce_single(x, reduction, dim): + """ + Reduce tensor at a dimension by named reduction functions. + + Args: + x (torch.Tensor): tensor to be reduced + reduction (str): one of ["sum", "max", "mean", "flatten"] + dim (int): dimension to be reduced (or begin axis for flatten) + + Returns: + y (torch.Tensor): reduced tensor + """ + assert x.ndimension() > dim + assert reduction in ["sum", "max", "mean", "flatten"] + if reduction == "flatten": + x = flatten(x, begin_axis=dim) + elif reduction == "max": + x = torch.max(x, dim=dim)[0] # [B, D] + elif reduction == "sum": + x = torch.sum(x, dim=dim) + else: + x = torch.mean(x, dim=dim) + return x + + +def named_reduce(x, reduction, dim): + """ + Reduces all tensors in nested dictionary or list or tuple at a dimension + using a named reduction function. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + reduction (str): one of ["sum", "max", "mean", "flatten"] + dim (int): dimension to be reduced (or begin axis for flatten) + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d)) + + +def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices): + """ + This function indexes out a target dimension of a tensor in a structured way, + by allowing a different value to be selected for each member of a flat index + tensor (@indices) corresponding to a source dimension. This can be interpreted + as moving along the source dimension, using the corresponding index value + in @indices to select values for all other dimensions outside of the + source and target dimensions. A common use case is to gather values + in target dimension 1 for each batch member (target dimension 0). + + Args: + x (torch.Tensor): tensor to gather values for + target_dim (int): dimension to gather values along + source_dim (int): dimension to hold constant and use for gathering values + from the other dimensions + indices (torch.Tensor): flat index tensor with same shape as tensor @x along + @source_dim + + Returns: + y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out + """ + assert len(indices.shape) == 1 + assert x.shape[source_dim] == indices.shape[0] + + # unsqueeze in all dimensions except the source dimension + new_shape = [1] * x.ndimension() + new_shape[source_dim] = -1 + indices = indices.reshape(*new_shape) + + # repeat in all dimensions - but preserve shape of source dimension, + # and make sure target_dimension has singleton dimension + expand_shape = list(x.shape) + expand_shape[source_dim] = -1 + expand_shape[target_dim] = 1 + indices = indices.expand(*expand_shape) + + out = x.gather(dim=target_dim, index=indices) + return out.squeeze(target_dim) + + +def gather_along_dim_with_dim(x, target_dim, source_dim, indices): + """ + Apply @gather_along_dim_with_dim_single to all tensors in a nested + dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + target_dim (int): dimension to gather values along + source_dim (int): dimension to hold constant and use for gathering values + from the other dimensions + indices (torch.Tensor): flat index tensor with same shape as tensor @x along + @source_dim + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(x, + lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i)) + + +def gather_sequence_single(seq, indices): + """ + Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in + the batch given an index for each sequence. + + Args: + seq (torch.Tensor): tensor with leading dimensions [B, T, ...] + indices (torch.Tensor): tensor indices of shape [B] + + Return: + y (torch.Tensor): indexed tensor of shape [B, ....] + """ + return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices) + + +def gather_sequence(seq, indices): + """ + Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch + for tensors with leading dimensions [B, T, ...]. + + Args: + seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + indices (torch.Tensor): tensor indices of shape [B] + + Returns: + y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...] + """ + return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices) + + +def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None): + """ + Pad input tensor or array @seq in the time dimension (dimension 1). + + Args: + seq (np.ndarray or torch.Tensor): sequence to be padded + padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1 + batched (bool): if sequence has the batch dimension + pad_same (bool): if pad by duplicating + pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same + + Returns: + padded sequence (np.ndarray or torch.Tensor) + """ + assert isinstance(seq, (np.ndarray, torch.Tensor)) + assert pad_same or pad_values is not None + if pad_values is not None: + assert isinstance(pad_values, float) + repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave + concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat + ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like + seq_dim = 1 if batched else 0 + + begin_pad = [] + end_pad = [] + + if padding[0] > 0: + pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values + begin_pad.append(repeat_func(pad, padding[0], seq_dim)) + if padding[1] > 0: + pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values + end_pad.append(repeat_func(pad, padding[1], seq_dim)) + + return concat_func(begin_pad + [seq] + end_pad, seq_dim) + + +def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None): + """ + Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1). + + Args: + seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1 + batched (bool): if sequence has the batch dimension + pad_same (bool): if pad by duplicating + pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same + + Returns: + padded sequence (dict or list or tuple) + """ + return recursive_dict_list_tuple_apply( + seq, + { + torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: + pad_sequence_single(x, p, b, ps, pv), + np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: + pad_sequence_single(x, p, b, ps, pv), + type(None): lambda x: x, + } + ) + + +def assert_size_at_dim_single(x, size, dim, msg): + """ + Ensure that array or tensor @x has size @size in dim @dim. + + Args: + x (np.ndarray or torch.Tensor): input array or tensor + size (int): size that tensors should have at @dim + dim (int): dimension to check + msg (str): text to display if assertion fails + """ + assert x.shape[dim] == size, msg + + +def assert_size_at_dim(x, size, dim, msg): + """ + Ensure that arrays and tensors in nested dictionary or list or tuple have + size @size in dim @dim. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size that tensors should have at @dim + dim (int): dimension to check + """ + map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m)) + + +def get_shape(x): + """ + Get all shapes of arrays and tensors in nested dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple that contains each array or + tensor's shape + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.shape, + np.ndarray: lambda x: x.shape, + type(None): lambda x: x, + } + ) + + +def list_of_flat_dict_to_dict_of_list(list_of_dict): + """ + Helper function to go from a list of flat dictionaries to a dictionary of lists. + By "flat" we mean that none of the values are dictionaries, but are numpy arrays, + floats, etc. + + Args: + list_of_dict (list): list of flat dictionaries + + Returns: + dict_of_list (dict): dictionary of lists + """ + assert isinstance(list_of_dict, list) + dic = collections.OrderedDict() + for i in range(len(list_of_dict)): + for k in list_of_dict[i]: + if k not in dic: + dic[k] = [] + dic[k].append(list_of_dict[i][k]) + return dic + + +def flatten_nested_dict_list(d, parent_key='', sep='_', item_key=''): + """ + Flatten a nested dict or list to a list. + + For example, given a dict + { + a: 1 + b: { + c: 2 + } + c: 3 + } + + the function would return [(a, 1), (b_c, 2), (c, 3)] + + Args: + d (dict, list): a nested dict or list to be flattened + parent_key (str): recursion helper + sep (str): separator for nesting keys + item_key (str): recursion helper + Returns: + list: a list of (key, value) tuples + """ + items = [] + if isinstance(d, (tuple, list)): + new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key + for i, v in enumerate(d): + items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i))) + return items + elif isinstance(d, dict): + new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key + for k, v in d.items(): + assert isinstance(k, str) + items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k)) + return items + else: + new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key + return [(new_key, d)] + + +def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs): + """ + Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the + batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...]. + Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping + outputs to [B, T, ...]. + + Args: + inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + op: a layer op that accepts inputs + activation: activation to apply at the output + inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op + inputs_as_args (bool) whether to feed input as a args list to the op + kwargs (dict): other kwargs to supply to the op + + Returns: + outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T]. + """ + batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2] + inputs = join_dimensions(inputs, 0, 1) + if inputs_as_kwargs: + outputs = op(**inputs, **kwargs) + elif inputs_as_args: + outputs = op(*inputs, **kwargs) + else: + outputs = op(inputs, **kwargs) + + if activation is not None: + outputs = map_tensor(outputs, activation) + outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len)) + return outputs diff --git a/quest/utils/utils.py b/quest/utils/utils.py new file mode 100644 index 0000000..e4acf00 --- /dev/null +++ b/quest/utils/utils.py @@ -0,0 +1,147 @@ +import copy +import json +import os +import random +from pathlib import Path +import quest.utils.tensor_utils as TensorUtils +import numpy as np +import torch +import torch.nn as nn +import warnings +from natsort import natsorted + +def get_experiment_dir(cfg, evaluate=False, allow_overlap=False): + # if eval_flag: + # prefix = "evaluations" + # else: + # prefix = "experiments" + # if cfg.pretrain_model_path != "": + # prefix += "_finetune" + + prefix = cfg.output_prefix + if evaluate: + prefix = os.path.join(prefix, 'evaluate') + + experiment_dir = ( + f"{prefix}/{cfg.task.suite_name}/{cfg.task.benchmark_name}/" + + f"{cfg.algo.name}/{cfg.exp_name}" + ) + if cfg.variant_name is not None: + experiment_dir += f'/{cfg.variant_name}' + + if cfg.seed != 10000: + experiment_dir += f'/{cfg.seed}' + + if cfg.make_unique_experiment_dir: + # look for the most recent run + experiment_id = 0 + if os.path.exists(experiment_dir): + for path in Path(experiment_dir).glob("run_*"): + if not path.is_dir(): + continue + try: + folder_id = int(str(path).split("run_")[-1]) + if folder_id > experiment_id: + experiment_id = folder_id + except BaseException: + pass + experiment_id += 1 + + experiment_dir += f"/run_{experiment_id:03d}" + else: + experiment_dir += f'/stage_{cfg.stage}' + + if not allow_overlap and not cfg.training.resume: + assert not os.path.exists(experiment_dir), \ + f'cfg.make_unique_experiment_dir=false but {experiment_dir} is already occupied' + + experiment_name = "_".join(experiment_dir.split("/")[len(cfg.output_prefix.split('/')):]) + return experiment_dir, experiment_name + +def get_latest_checkpoint(checkpoint_dir): + if os.path.isfile(checkpoint_dir): + return checkpoint_dir + + onlyfiles = [f for f in os.listdir(checkpoint_dir) if os.path.isfile(os.path.join(checkpoint_dir, f))] + onlyfiles = natsorted(onlyfiles) + best_file = onlyfiles[-1] + return os.path.join(checkpoint_dir, best_file) + +def soft_load_state_dict(model, loaded_state_dict): + # loaded_state_dict['task_encoder.weight'] = loaded_state_dict['task_encodings.weight'] + + current_model_dict = model.state_dict() + new_state_dict = {} + + for k in current_model_dict.keys(): + if k in loaded_state_dict: + v = loaded_state_dict[k] + if not hasattr(v, 'size') or v.size() == current_model_dict[k].size(): + new_state_dict[k] = v + else: + warnings.warn(f'Cannot load checkpoint parameter {k} with shape {loaded_state_dict[k].shape}' + f'into model with corresponding parameter shape {current_model_dict[k].shape}. Skipping') + new_state_dict[k] = current_model_dict[k] + else: + new_state_dict[k] = current_model_dict[k] + warnings.warn(f'Model parameter {k} does not exist in checkpoint. Skipping') + for k in loaded_state_dict.keys(): + if k not in current_model_dict: + warnings.warn(f'Loaded checkpoint parameter {k} does not exist in model. Skipping') + + model.load_state_dict(new_state_dict) + +def map_tensor_to_device(data, device): + """Move data to the device specified by device.""" + return TensorUtils.map_tensor( + data, lambda x: safe_device(x, device=device) + ) + +def safe_device(x, device="cpu"): + if device == "cpu": + return x.cpu() + elif "cuda" in device: + if torch.cuda.is_available(): + return x.to(device) + else: + return x.cpu() + +def extract_state_dicts(inp): + + if not (isinstance(inp, dict) or isinstance(inp, list)): + if hasattr(inp, 'state_dict'): + return inp.state_dict() + else: + return inp + elif isinstance(inp, list): + out_list = [] + for value in inp: + out_list.append(extract_state_dicts(value)) + return out_list + else: + out_dict = {} + for key, value in inp.items(): + out_dict[key] = extract_state_dicts(value) + return out_dict + +def save_state(state_dict, path): + save_dict = extract_state_dicts(state_dict) + torch.save(save_dict, path) + +def load_state(path): + return torch.load(path) + +def torch_save_model(model, optimizer, scheduler, model_path, cfg=None): + torch.save( + { + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "cfg": cfg, + }, + model_path, + ) + +def torch_load_model(model_path): + checkpoint = torch.load(model_path) + return checkpoint["model_state_dict"], checkpoint["optimizer_state_dict"], checkpoint["scheduler_state_dict"], checkpoint["cfg"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..dbebdff --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +easydict==1.13 +einops==0.8.0 +h5py==3.11.0 +hydra-core==1.3.2 +imageio==2.34.1 +matplotlib==3.9.0 +numpy==1.26.4 +positional-encodings==6.0.3 +tqdm==4.66.4 +vector-quantize-pytorch==1.14.24 +wandb==0.17.0 +metaworld @ git+https://github.com/Farama-Foundation/Metaworld.git@master#egg=metaworld +cmake==3.29.3 +diffusers==0.28.0 +pyinstrument==4.6.2 +moviepy==1.0.3 +tokenizers==0.19.1 +torchtyping==0.1.4 +ema-pytorch==0.5.2 +natsort==8.4.0 +opencv-python==4.6.0.66 +thop==0.1.1-2209072238 +bddl==1.0.1 +future==0.18.2 +gym==0.25.2 +cloudpickle==2.1.0 +robosuite +transformers diff --git a/scripts/act/finetune.sh b/scripts/act/finetune.sh new file mode 100755 index 0000000..337bee3 --- /dev/null +++ b/scripts/act/finetune.sh @@ -0,0 +1,22 @@ + +# This script is used to finetune ACT on downstream tasks + +python train.py --config-name=train_fewshot.yaml \ + task=libero_long \ + algo=act \ + exp_name=final \ + variant_name=block_16 \ + training.use_tqdm=false \ + training.save_all_checkpoints=true \ + training.use_amp=false \ + train_dataloader.persistent_workers=true \ + train_dataloader.num_workers=6 \ + make_unique_experiment_dir=false \ + algo.skill_block_size=16 \ + training.auto_continue=true \ + rollout.num_parallel_envs=5 \ + rollout.rollouts_per_env=5 \ + seed=0 + +# Note1: training.auto_continue will automatically load the latest checkpoint from the previous training stage. +# Else you can specify the checkpoint_path to load a specific checkpoint. diff --git a/scripts/act/main.sh b/scripts/act/main.sh new file mode 100755 index 0000000..51e650c --- /dev/null +++ b/scripts/act/main.sh @@ -0,0 +1,20 @@ + +# This script is used to train the ACT model + +python train.py --config-name=train_prior.yaml \ + task=libero_90 \ + algo=act \ + exp_name=final \ + variant_name=block_16 \ + training.use_tqdm=false \ + training.save_all_checkpoints=true \ + training.use_amp=false \ + train_dataloader.persistent_workers=true \ + train_dataloader.num_workers=6 \ + make_unique_experiment_dir=false \ + algo.skill_block_size=16 \ + rollout.num_parallel_envs=5 \ + rollout.rollouts_per_env=5 \ + seed=0 + +# Note2: change rollout.num_parallel_envs to 1 if libero vectorized env is not working as expected. diff --git a/scripts/bc_trans/finetune.sh b/scripts/bc_trans/finetune.sh new file mode 100755 index 0000000..bc05390 --- /dev/null +++ b/scripts/bc_trans/finetune.sh @@ -0,0 +1,22 @@ + +# This script is used to finetune ResNet-T on downstream tasks + +python train.py --config-name=train_fewshot.yaml \ + task=libero_long \ + algo=bc_transformer \ + exp_name=final \ + variant_name=block_10 \ + training.use_tqdm=false \ + training.save_all_checkpoints=true \ + training.use_amp=false \ + train_dataloader.persistent_workers=true \ + train_dataloader.num_workers=6 \ + train_dataloader.batch_size=128 \ + make_unique_experiment_dir=false \ + training.auto_continue=true \ + rollout.num_parallel_envs=5 \ + rollout.rollouts_per_env=5 \ + seed=0 + +# Note1: training.auto_continue will automatically load the latest checkpoint from the previous training stage. +# Else you can specify the checkpoint_path to load a specific checkpoint. diff --git a/scripts/bc_trans/main.sh b/scripts/bc_trans/main.sh new file mode 100755 index 0000000..c263dae --- /dev/null +++ b/scripts/bc_trans/main.sh @@ -0,0 +1,20 @@ + +# This script is used to train the ResNet-T model + +python train.py --config-name=train_prior.yaml \ + task=libero_90 \ + algo=bc_transformer \ + exp_name=final \ + variant_name=block_10 \ + training.use_tqdm=false \ + training.save_all_checkpoints=true \ + training.use_amp=false \ + train_dataloader.persistent_workers=true \ + train_dataloader.num_workers=6 \ + train_dataloader.batch_size=128 \ + make_unique_experiment_dir=false \ + rollout.num_parallel_envs=5 \ + rollout.rollouts_per_env=5 \ + seed=0 + +# Note2: change rollout.num_parallel_envs to 1 if libero vectorized env is not working as expected. diff --git a/scripts/bet/autoencoder.sh b/scripts/bet/autoencoder.sh new file mode 100755 index 0000000..30059bc --- /dev/null +++ b/scripts/bet/autoencoder.sh @@ -0,0 +1,17 @@ + +# This script is used to train the autoencoder of VQ-BeT + +python train.py --config-name=train_autoencoder.yaml \ + task=libero_90 \ + algo=bet \ + exp_name=final \ + variant_name=block_5 \ + training.use_tqdm=false \ + training.save_all_checkpoints=true \ + training.use_amp=false \ + train_dataloader.persistent_workers=true \ + train_dataloader.num_workers=6 \ + train_dataloader.batch_size=128 \ + make_unique_experiment_dir=false \ + algo.skill_block_size=5 \ + seed=0 \ No newline at end of file diff --git a/scripts/bet/finetune.sh b/scripts/bet/finetune.sh new file mode 100755 index 0000000..b30095a --- /dev/null +++ b/scripts/bet/finetune.sh @@ -0,0 +1,23 @@ + +# This script is used to finetune VQ-BeT on downstream tasks + +python train.py --config-name=train_fewshot.yaml \ + task=libero_long \ + algo=bet \ + exp_name=final \ + variant_name=block_5 \ + training.use_tqdm=false \ + training.save_all_checkpoints=true \ + training.use_amp=false \ + train_dataloader.persistent_workers=true \ + train_dataloader.num_workers=6 \ + train_dataloader.batch_size=128 \ + make_unique_experiment_dir=false \ + algo.skill_block_size=5 \ + training.auto_continue=true \ + rollout.num_parallel_envs=5 \ + rollout.rollouts_per_env=5 \ + seed=0 + +# Note1: training.auto_continue will automatically load the latest checkpoint from the previous training stage. +# Else you can specify the checkpoint_path to load a specific checkpoint. diff --git a/scripts/bet/main.sh b/scripts/bet/main.sh new file mode 100755 index 0000000..2361fbe --- /dev/null +++ b/scripts/bet/main.sh @@ -0,0 +1,24 @@ + +# This script is used to train the VQ-BeT model + +python train.py --config-name=train_prior.yaml \ + task=libero_90 \ + algo=bet \ + exp_name=final \ + variant_name=block_5 \ + training.use_tqdm=false \ + training.save_all_checkpoints=true \ + training.use_amp=false \ + train_dataloader.persistent_workers=true \ + train_dataloader.num_workers=6 \ + train_dataloader.batch_size=128 \ + make_unique_experiment_dir=false \ + algo.skill_block_size=5 \ + training.auto_continue=true \ + rollout.num_parallel_envs=5 \ + rollout.rollouts_per_env=5 \ + seed=0 + +# Note1: training.auto_continue will automatically load the latest checkpoint from the previous training stage. +# Else you can specify the checkpoint_path to load a specific checkpoint. +# Note2: change rollout.num_parallel_envs to 1 if libero vectorized env is not working as expected. diff --git a/scripts/dp/finetune.sh b/scripts/dp/finetune.sh new file mode 100755 index 0000000..ecefc81 --- /dev/null +++ b/scripts/dp/finetune.sh @@ -0,0 +1,24 @@ + +# This script is used to finetune diffusion policy on downstream tasks + +python train.py --config-name=train_fewshot.yaml \ + task=libero_long \ + algo=diffusion_policy \ + exp_name=final \ + variant_name=block_32 \ + training.use_tqdm=false \ + training.save_all_checkpoints=true \ + training.use_amp=false \ + training.n_epochs=200 \ + train_dataloader.persistent_workers=true \ + train_dataloader.num_workers=6 \ + make_unique_experiment_dir=false \ + algo.skill_block_size=32 \ + training.auto_continue=true \ + rollout.num_parallel_envs=5 \ + rollout.rollouts_per_env=5 \ + seed=0 + +# Note1: training.auto_continue will automatically load the latest checkpoint from the previous training stage. +# Else you can specify the checkpoint_path to load a specific checkpoint. +# Note2: algo.l1_loss_scale is used to finetune the decoder of the autoencoder. \ No newline at end of file diff --git a/scripts/dp/main.sh b/scripts/dp/main.sh new file mode 100755 index 0000000..fa634ff --- /dev/null +++ b/scripts/dp/main.sh @@ -0,0 +1,22 @@ + +# This script is used to train diffusion policy + +python train.py --config-name=train_prior.yaml \ + task=libero_90 \ + algo=diffusion_policy \ + exp_name=final \ + variant_name=block_32 \ + training.use_tqdm=false \ + training.save_all_checkpoints=true \ + training.use_amp=false \ + training.n_epochs=200 \ + train_dataloader.persistent_workers=true \ + train_dataloader.num_workers=6 \ + make_unique_experiment_dir=false \ + algo.skill_block_size=32 \ + training.auto_continue=true \ + rollout.num_parallel_envs=5 \ + rollout.rollouts_per_env=5 \ + seed=0 + +# Note1: change rollout.num_parallel_envs to 1 if libero vectorized env is not working as expected. diff --git a/scripts/eval.sh b/scripts/eval.sh new file mode 100755 index 0000000..75bbd89 --- /dev/null +++ b/scripts/eval.sh @@ -0,0 +1,12 @@ + +python evaluate.py \ + task=libero_90 \ + algo=quest \ + exp_name=final \ + variant_name=block_32_ds_4 \ + stage=1 \ + training.use_tqdm=false \ + seed=0 + +# Note1: this will automatically load the latest checkpoint as per your exp_name, variant_name, algo, and stage. +# Else you can specify the checkpoint_path to load a specific checkpoint. \ No newline at end of file diff --git a/scripts/generate_metaworld_dataset.py b/scripts/generate_metaworld_dataset.py new file mode 100755 index 0000000..6901e61 --- /dev/null +++ b/scripts/generate_metaworld_dataset.py @@ -0,0 +1,97 @@ +import numpy as np +import h5py +from tqdm import tqdm +import json + +import quest.utils.metaworld_utils as mu +import os +import hydra +from hydra.utils import instantiate +import quest.utils.utils as utils +from moviepy.editor import ImageSequenceClip + +@hydra.main(config_path="../config", + config_name='collect_data', + version_base=None) +def main(cfg): + env_runner = instantiate(cfg.task.env_runner) + + data_dir = os.path.join( + cfg.data_prefix, + cfg.task.suite_name, + cfg.task.benchmark_name, + cfg.task.mode + # f"{task_names[i]}.hdf5" + ) + os.makedirs(data_dir, exist_ok=True) + experiment_dir, _ = utils.get_experiment_dir(cfg) + + success_rates, returns = {}, {} + expert = mu.get_expert() + + def noisy_expert(obs, task_id): + expert_action = expert(obs, task_id) + action = np.random.normal(expert_action, cfg.task.demo_noise) + action = np.clip(action, -1, 1) + return action + + for env_name in mu.get_env_names(cfg.task.benchmark_name, cfg.task.mode): + file_path = os.path.join(data_dir, f"{env_name}.hdf5") + if os.path.exists(file_path): + print(f'{file_path} already exists. Skipping') + continue + video_dir = os.path.join(experiment_dir, env_name) + os.makedirs(video_dir) + init_hdf5(file_path, env_name) + + completed = total_return = 0 + rollouts = env_runner.run_policy_in_env(env_name, noisy_expert) + for i, (success, ep_return, episode) in tqdm(enumerate(rollouts), total=cfg.rollout.rollouts_per_env): + + completed += success + total_return += ep_return + + save_path = os.path.join(video_dir, f'trial_{i}.mp4') + clip = ImageSequenceClip(list(episode['corner_rgb']), fps=24) + clip.write_videofile(save_path, fps=24, verbose=False, logger=None) + dump_demo(episode, file_path, i) + success_rate = completed / (i + 1) + success_rates[env_name] = success_rate + returns[env_name] = total_return / (i + 1) + print(env_name, success_rate) + + with open(os.path.join(data_dir, 'success_rates.json'), 'w') as f: + json.dump(success_rates, f) + with open(os.path.join(data_dir, 'returns.json'), 'w') as f: + json.dump(returns, f) + + +def init_hdf5(file_path, env_name): + with h5py.File(file_path, 'a') as f: + group_data = f.create_group('data') + group_data.attrs['total'] = 0 + group_data.attrs['env_args'] = json.dumps({ + 'env_name': env_name, 'env_type': 2, + 'env_kwargs':{'render_mode':'rgb_array', 'camera_name':'corner2'} + }) + +def dump_demo(demo, file_path, demo_i): + with h5py.File(file_path, 'a') as f: + group_data = f['data'] + group = group_data.create_group(f'demo_{demo_i}') + + demo_length = demo['actions'].shape[0] + group_data.attrs['total'] = group_data.attrs['total'] + demo_length + group.attrs['num_samples'] = demo_length + non_obs_keys = ('actions', 'terminated', 'truncated', 'reward', 'success') + group.create_dataset('states', data=()) + for key in demo: + if key in non_obs_keys: + continue + group.create_dataset(f'obs/{key}', data=demo[key]) + for key in non_obs_keys: + group.create_dataset(key, data=demo[key]) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/scripts/quest/autoencoder.sh b/scripts/quest/autoencoder.sh new file mode 100755 index 0000000..066180c --- /dev/null +++ b/scripts/quest/autoencoder.sh @@ -0,0 +1,17 @@ + +# This script is used to train stage 0 i.e. the autoencoder of Quest + +python train.py --config-name=train_autoencoder.yaml \ + task=libero_90 \ + algo=quest \ + exp_name=final \ + variant_name=block_32_ds_4 \ + training.use_tqdm=false \ + training.save_all_checkpoints=true \ + training.use_amp=false \ + train_dataloader.persistent_workers=true \ + train_dataloader.num_workers=6 \ + make_unique_experiment_dir=false \ + algo.skill_block_size=32 \ + algo.downsample_factor=4 \ + seed=0 \ No newline at end of file diff --git a/scripts/quest/finetune.sh b/scripts/quest/finetune.sh new file mode 100755 index 0000000..881e6c9 --- /dev/null +++ b/scripts/quest/finetune.sh @@ -0,0 +1,25 @@ + +# This script is used to train stage 2 i.e. finetuning Quest for downstream tasks + +python train.py --config-name=train_fewshot.yaml \ + task=libero_long \ + algo=quest \ + exp_name=final \ + variant_name=block_32_ds_4 \ + training.use_tqdm=false \ + training.save_all_checkpoints=true \ + training.use_amp=false \ + train_dataloader.persistent_workers=true \ + train_dataloader.num_workers=6 \ + make_unique_experiment_dir=false \ + algo.skill_block_size=32 \ + algo.downsample_factor=4 \ + algo.l1_loss_scale=10 \ + training.auto_continue=true \ + rollout.num_parallel_envs=5 \ + rollout.rollouts_per_env=5 \ + seed=0 + +# Note1: training.auto_continue will automatically load the latest checkpoint from the previous training stage. +# Else you can specify the checkpoint_path to load a specific checkpoint. +# Note2: algo.l1_loss_scale is used to finetune the decoder of the autoencoder. \ No newline at end of file diff --git a/scripts/quest/main.sh b/scripts/quest/main.sh new file mode 100755 index 0000000..6451da8 --- /dev/null +++ b/scripts/quest/main.sh @@ -0,0 +1,24 @@ + +# This script is used to train stage 1 i.e. the gpt-like transformer prior of Quest + +python train.py --config-name=train_prior.yaml \ + task=libero_90 \ + algo=quest \ + exp_name=final \ + variant_name=block_32_ds_4 \ + training.use_tqdm=false \ + training.save_all_checkpoints=true \ + training.use_amp=false \ + train_dataloader.persistent_workers=true \ + train_dataloader.num_workers=6 \ + make_unique_experiment_dir=false \ + algo.skill_block_size=32 \ + algo.downsample_factor=4 \ + training.auto_continue=true \ + rollout.num_parallel_envs=5 \ + rollout.rollouts_per_env=5 \ + seed=0 + +# Note1: training.auto_continue will automatically load the latest checkpoint from the previous training stage. +# Else you can specify the checkpoint_path to load a specific checkpoint. +# Note2: change rollout.num_parallel_envs to 1 if libero vectorized env is not working as expected. diff --git a/train.py b/train.py new file mode 100644 index 0000000..604f1d7 --- /dev/null +++ b/train.py @@ -0,0 +1,190 @@ +import os +import time +import hydra +import wandb +from hydra.utils import instantiate +from omegaconf import OmegaConf +from tqdm import tqdm +from pathlib import Path +import warnings + +import torch +import torch.nn as nn +import quest.utils.utils as utils +from pyinstrument import Profiler +from quest.utils.logger import Logger +import gc + +OmegaConf.register_new_resolver("eval", eval, replace=True) + + +@hydra.main(config_path="config", version_base=None) +def main(cfg): + device = cfg.device + seed = cfg.seed + torch.manual_seed(seed) + train_cfg = cfg.training + + # create model + model = instantiate(cfg.algo.policy, + shape_meta=cfg.task.shape_meta) + model.to(device) + model.train() + + # start training + optimizers = model.get_optimizers() + schedulers = model.get_schedulers(optimizers) + + scaler = torch.cuda.amp.GradScaler(enabled=train_cfg.use_amp) + + experiment_dir, experiment_name = utils.get_experiment_dir(cfg) + os.makedirs(experiment_dir, exist_ok=True) + + start_epoch, steps, wandb_id = 0, 0, None + if train_cfg.auto_continue: + checkpoint_path = os.path.join(experiment_dir, os.path.pardir, f'stage_{cfg.stage - 1}') + elif train_cfg.resume and len(os.listdir(experiment_dir)) > 0: + checkpoint_path = experiment_dir + else: + checkpoint_path = cfg.checkpoint_path + + if checkpoint_path is not None: + checkpoint_path = utils.get_latest_checkpoint(checkpoint_path) + print(f'loading from checkpoint {checkpoint_path}') + state_dict = utils.load_state(checkpoint_path) + loaded_state_dict = state_dict['model'] + + # Below line allows loading state dicts with some mismatched parameters + utils.soft_load_state_dict(model, loaded_state_dict) + + # resuming training since we are loading a checkpoint training the same stage + if cfg.stage == state_dict['stage']: + print('loading from checkpoint') + for optimizer, opt_state_dict in zip(optimizers, state_dict['optimizers']): + optimizer.load_state_dict(opt_state_dict) + for scheduler, sch_state_dict in zip(schedulers, state_dict['schedulers']): + scheduler.load_state_dict(sch_state_dict) + scaler.load_state_dict(state_dict['scaler']) + start_epoch = state_dict['epoch'] + steps = state_dict['steps'] + wandb_id = state_dict['wandb_id'] + else: + print('starting from scratch') + + dataset = instantiate(cfg.task.dataset) + model.preprocess_dataset(dataset, use_tqdm=train_cfg.use_tqdm) + train_dataloader = instantiate( + cfg.train_dataloader, + dataset=dataset) + + + if cfg.rollout.enabled: + env_runner = instantiate(cfg.task.env_runner) + # rollout_results = env_runner.run(model, n_video=cfg.rollout.n_video, do_tqdm=train_cfg.use_tqdm) # for debugging env runner before starting training + + print('Saving to:', experiment_dir) + print('Experiment name:', experiment_name) + + wandb.init( + dir=experiment_dir, + name=experiment_name, + config=OmegaConf.to_container(cfg, resolve=True), + id=wandb_id, + **cfg.logging + ) + + logger = Logger(train_cfg.log_interval) + + print('Training...') + + for epoch in range(start_epoch, train_cfg.n_epochs + 1): + t0 = time.time() + model.train() + training_loss = 0.0 + if train_cfg.do_profile: + profiler = Profiler() + profiler.start() + for idx, data in enumerate(tqdm(train_dataloader, disable=not train_cfg.use_tqdm)): + data = utils.map_tensor_to_device(data, device) + + for optimizer in optimizers: + optimizer.zero_grad() + + with torch.autograd.set_detect_anomaly(False): + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=train_cfg.use_amp): + loss, info = model.compute_loss(data) + + scaler.scale(loss).backward() + + for optimizer in optimizers: + scaler.unscale_(optimizer) + if train_cfg.grad_clip is not None: + grad_norm = nn.utils.clip_grad_norm_( + model.parameters(), train_cfg.grad_clip + ) + + for optimizer in optimizers: + scaler.step(optimizer) + + scaler.update() + + info.update({ + 'epoch': epoch + }) + if train_cfg.grad_clip is not None: + info.update({ + "grad_norm": grad_norm.item(), + }) + info = {cfg.logging_folder: info} + training_loss += loss.item() + steps += 1 + logger.update(info, steps) + + if train_cfg.cut and idx > train_cfg.cut: + break + + if train_cfg.do_profile: + profiler.stop() + profiler.print() + + training_loss /= len(train_dataloader) + t1 = time.time() + print( + f"[info] Epoch: {epoch:3d} | train loss: {training_loss:5.5f} | time: {(t1-t0)/60:4.2f}" + ) + + if epoch % train_cfg.save_interval == 0 and epoch > 0: + if cfg.training.save_all_checkpoints: + model_checkpoint_name_ep = os.path.join( + experiment_dir, f"multitask_model_epoch_{epoch:04d}.pth" + ) + else: + model_checkpoint_name_ep = os.path.join( + experiment_dir, f"multitask_model.pth" + ) + utils.save_state({ + 'model': model, + 'optimizers': optimizers, + 'schedulers': schedulers, + 'scaler': scaler, + 'epoch': epoch, + 'stage': cfg.stage, + 'steps': steps, + 'wandb_id': wandb.run.id, + 'experiment_dir': experiment_dir, + 'experiment_name': experiment_name, + 'config': OmegaConf.to_container(cfg, resolve=True) + }, model_checkpoint_name_ep) + + if cfg.rollout.enabled and epoch > 0 and epoch % cfg.rollout.interval == 0: + rollout_results = env_runner.run(model, n_video=cfg.rollout.n_video, do_tqdm=train_cfg.use_tqdm) + print( + f"[info] success rate: {rollout_results['rollout']['overall_success_rate']:1.3f} \ + | environments solved: {rollout_results['rollout']['environments_solved']}") + logger.log(rollout_results, step=steps) + [scheduler.step() for scheduler in schedulers] + print("[info] finished learning\n") + wandb.finish() + +if __name__ == "__main__": + main() \ No newline at end of file