From 9aa6be9aa474914b84dad36d3f961a544d48d1e3 Mon Sep 17 00:00:00 2001 From: Denis Tarasov <39963896+DT6A@users.noreply.github.com> Date: Wed, 6 Dec 2023 23:46:31 +0100 Subject: [PATCH] ReBRAC finetune (#9) * ReBRAC finetune * Fix * linter * fix linter * Fix * Fix configs * Fix umaze config * Add scoring * Change arrows in README * Update get_finetune_scores.py * Update rebrac.py * Update get_finetune_urls.py * Update get_finetune_scores.py * Update rebrac.py * Update README.md --------- Co-authored-by: Denis Tarasov --- README.md | 58 +- algorithms/finetune/rebrac.py | 1099 +++++++++++++++++ .../rebrac/antmaze/large_diverse_v2.yaml | 35 + .../rebrac/antmaze/large_play_v2.yaml | 35 + .../rebrac/antmaze/medium_diverse_v2.yaml | 35 + .../rebrac/antmaze/medium_play_v2.yaml | 35 + .../rebrac/antmaze/umaze_diverse_v2.yaml | 35 + configs/finetune/rebrac/antmaze/umaze_v2.yaml | 36 + configs/finetune/rebrac/door/cloned_v1.yaml | 35 + configs/finetune/rebrac/hammer/cloned_v1.yaml | 35 + configs/finetune/rebrac/pen/cloned_v1.yaml | 35 + .../finetune/rebrac/relocate/cloned_v1.yaml | 35 + results/bin/finetune_scores.pickle | Bin 705857 -> 741701 bytes results/get_finetune_scores.py | 5 + results/get_finetune_tables_and_plots.py | 6 +- results/get_finetune_urls.py | 4 + results/runs_tables/finetune_urls.csv | 54 +- 17 files changed, 1539 insertions(+), 38 deletions(-) create mode 100644 algorithms/finetune/rebrac.py create mode 100644 configs/finetune/rebrac/antmaze/large_diverse_v2.yaml create mode 100644 configs/finetune/rebrac/antmaze/large_play_v2.yaml create mode 100644 configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml create mode 100644 configs/finetune/rebrac/antmaze/medium_play_v2.yaml create mode 100644 configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml create mode 100644 configs/finetune/rebrac/antmaze/umaze_v2.yaml create mode 100644 configs/finetune/rebrac/door/cloned_v1.yaml create mode 100644 configs/finetune/rebrac/hammer/cloned_v1.yaml create mode 100644 configs/finetune/rebrac/pen/cloned_v1.yaml create mode 100644 configs/finetune/rebrac/relocate/cloned_v1.yaml diff --git a/README.md b/README.md index 9242fb4f..88b8a19a 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ docker run --gpus=all -it --rm --name | ✅ [Conservative Q-Learning for Offline Reinforcement Learning
(CQL)](https://arxiv.org/abs/2006.04779) | [`offline/cql.py`](algorithms/offline/cql.py)
[`finetune/cql.py`](algorithms/finetune/cql.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-CQL--Vmlldzo1MzM4MjY3)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-CQL--Vmlldzo0NTQ3NTMz) | ✅ [Accelerating Online Reinforcement Learning with Offline Datasets
(AWAC)](https://arxiv.org/abs/2006.09359) | [`offline/awac.py`](algorithms/offline/awac.py)
[`finetune/awac.py`](algorithms/finetune/awac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-AWAC--Vmlldzo1MzM4MTEy)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-AWAC--VmlldzozODAyNzQz) | ✅ [Offline Reinforcement Learning with Implicit Q-Learning
(IQL)](https://arxiv.org/abs/2110.06169) | [`offline/iql.py`](algorithms/offline/iql.py)
[`finetune/iql.py`](algorithms/finetune/iql.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-IQL--Vmlldzo1MzM4MzQz)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-IQL--VmlldzozNzE1MTEy) +| ✅ [Revisiting the Minimalist Approach to Offline Reinforcement Learning
(ReBRAC)](https://arxiv.org/abs/2305.09836) | [`offline/rebrac.py`](algorithms/offline/rebrac.py)
[`finetune/rebrac.py`](algorithms/finetune/rebrac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-ReBRAC--Vmlldzo0ODkzOTQ2)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-ReBRAC--Vmlldzo1NDAyNjE5) | **Offline-to-Online only** | | | ✅ [Supported Policy Optimization for Offline Reinforcement Learning
(SPOT)](https://arxiv.org/abs/2202.06239) | [`finetune/spot.py`](algorithms/finetune/spot.py) | [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-SPOT--VmlldzozODk5MTgx) | ✅ [Cal-QL: Calibrated Offline RL Pre-Training for Efficient Online Fine-Tuning
(Cal-QL)](https://arxiv.org/abs/2303.05479) | [`finetune/cal_ql.py`](algorithms/finetune/cal_ql.py) | [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-Cal-QL--Vmlldzo0NTQ3NDk5) @@ -57,7 +58,6 @@ docker run --gpus=all -it --rm --name | ✅ [Decision Transformer: Reinforcement Learning via Sequence Modeling
(DT)](https://arxiv.org/abs/2106.01345) | [`offline/dt.py`](algorithms/offline/dt.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-Decision-Transformer--Vmlldzo1MzM3OTkx) | ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble
(SAC-N)](https://arxiv.org/abs/2110.01548) | [`offline/sac_n.py`](algorithms/offline/sac_n.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-SAC-N--VmlldzoyNzA1NTY1) | ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble
(EDAC)](https://arxiv.org/abs/2110.01548) | [`offline/edac.py`](algorithms/offline/edac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-EDAC--VmlldzoyNzA5ODUw) -| ✅ [Revisiting the Minimalist Approach to Offline Reinforcement Learning
(ReBRAC)](https://arxiv.org/abs/2305.09836) | [`offline/rebrac.py`](algorithms/offline/rebrac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-ReBRAC--Vmlldzo0ODkzOTQ2) | ✅ [Q-Ensemble for Offline RL: Don't Scale the Ensemble, Scale the Batch Size
(LB-SAC)](https://arxiv.org/abs/2211.11092) | [`offline/lb_sac.py`](algorithms/offline/lb_sac.py) | [`Offline Gym-MuJoCo`](https://wandb.ai/tlab/CORL/reports/LB-SAC-D4RL-Results--VmlldzozNjIxMDY1) @@ -179,42 +179,42 @@ You can check the links above for learning curves and details. Here, we report r ### Offline-to-Online #### Scores -| **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL| -|---------------------------|------------|--------|--------|-----|-----| -|antmaze-umaze-v2|52.75 ± 8.67 → 98.75 ± 1.09|94.00 ± 1.58 → 99.50 ± 0.87|77.00 ± 0.71 → 96.50 ± 1.12|91.00 ± 2.55 → 99.50 ± 0.50|76.75 ± 7.53 → 99.75 ± 0.43| -|antmaze-umaze-diverse-v2|56.00 ± 2.74 → 0.00 ± 0.00|9.50 ± 9.91 → 99.00 ± 1.22|59.50 ± 9.55 → 63.75 ± 25.02|36.25 ± 2.17 → 95.00 ± 3.67|32.00 ± 27.79 → 98.50 ± 1.12| -|antmaze-medium-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|59.00 ± 11.18 → 97.75 ± 1.30|71.75 ± 2.95 → 89.75 ± 1.09|67.25 ± 10.47 → 97.25 ± 1.30|71.75 ± 3.27 → 98.75 ± 1.64| -|antmaze-medium-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|63.50 ± 6.84 → 97.25 ± 1.92|64.25 ± 1.92 → 92.25 ± 2.86|73.75 ± 7.29 → 94.50 ± 1.66|62.00 ± 4.30 → 98.25 ± 1.48| -|antmaze-large-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|28.75 ± 7.76 → 88.25 ± 2.28|38.50 ± 8.73 → 64.50 ± 17.04|31.50 ± 12.58 → 87.00 ± 3.24|31.75 ± 8.87 → 97.25 ± 1.79| -|antmaze-large-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|35.50 ± 3.64 → 91.75 ± 3.96|26.75 ± 3.77 → 64.25 ± 4.15|17.50 ± 7.26 → 81.00 ± 14.14|44.00 ± 8.69 → 91.50 ± 3.91| +| **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL|ReBRAC| +|---------------------------|------------|--------|--------|-----|-----|-----| +|antmaze-umaze-v2|52.75 ± 8.67 → 98.75 ± 1.09|94.00 ± 1.58 → 99.50 ± 0.87|77.00 ± 0.71 → 96.50 ± 1.12|91.00 ± 2.55 → 99.50 ± 0.50|76.75 ± 7.53 → 99.75 ± 0.43|98.00 ± 1.58 → 74.75 ± 42.59| +|antmaze-umaze-diverse-v2|56.00 ± 2.74 → 0.00 ± 0.00|9.50 ± 9.91 → 99.00 ± 1.22|59.50 ± 9.55 → 63.75 ± 25.02|36.25 ± 2.17 → 95.00 ± 3.67|32.00 ± 27.79 → 98.50 ± 1.12|73.75 ± 13.27 → 98.00 ± 2.92| +|antmaze-medium-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|59.00 ± 11.18 → 97.75 ± 1.30|71.75 ± 2.95 → 89.75 ± 1.09|67.25 ± 10.47 → 97.25 ± 1.30|71.75 ± 3.27 → 98.75 ± 1.64|87.50 ± 3.77 → 98.00 ± 1.58| +|antmaze-medium-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|63.50 ± 6.84 → 97.25 ± 1.92|64.25 ± 1.92 → 92.25 ± 2.86|73.75 ± 7.29 → 94.50 ± 1.66|62.00 ± 4.30 → 98.25 ± 1.48|85.25 ± 2.17 → 98.75 ± 0.43| +|antmaze-large-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|28.75 ± 7.76 → 88.25 ± 2.28|38.50 ± 8.73 → 64.50 ± 17.04|31.50 ± 12.58 → 87.00 ± 3.24|31.75 ± 8.87 → 97.25 ± 1.79|68.50 ± 6.18 → 31.50 ± 33.56| +|antmaze-large-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|35.50 ± 3.64 → 91.75 ± 3.96|26.75 ± 3.77 → 64.25 ± 4.15|17.50 ± 7.26 → 81.00 ± 14.14|44.00 ± 8.69 → 91.50 ± 3.91|67.00 ± 10.61 → 72.25 ± 41.73| | | | | | | | | | | | -| **antmaze average** |18.12 → 16.46|48.38 → 95.58|56.29 → 78.50|52.88 → 92.38|53.04 → 97.33| +| **antmaze average** |18.12 → 16.46|48.38 → 95.58|56.29 → 78.50|52.88 → 92.38|53.04 → 97.33|80.00 → 78.88| | | | | | | | | | | | -|pen-cloned-v1|88.66 ± 15.10 → 86.82 ± 11.12|-2.76 ± 0.08 → -1.28 ± 2.16|84.19 ± 3.96 → 102.02 ± 20.75|6.19 ± 5.21 → 43.63 ± 20.09|-2.66 ± 0.04 → -2.68 ± 0.12| -|door-cloned-v1|0.93 ± 1.66 → 0.01 ± 0.00|-0.33 ± 0.01 → -0.33 ± 0.01|1.19 ± 0.93 → 20.34 ± 9.32|-0.21 ± 0.14 → 0.02 ± 0.31|-0.33 ± 0.01 → -0.33 ± 0.01| -|hammer-cloned-v1|1.80 ± 3.01 → 0.24 ± 0.04|0.56 ± 0.55 → 2.85 ± 4.81|1.35 ± 0.32 → 57.27 ± 28.49|3.97 ± 6.39 → 3.73 ± 4.99|0.25 ± 0.04 → 0.17 ± 0.17| -|relocate-cloned-v1|-0.04 ± 0.04 → -0.04 ± 0.01|-0.33 ± 0.01 → -0.33 ± 0.01|0.04 ± 0.04 → 0.32 ± 
0.38|-0.24 ± 0.01 → -0.15 ± 0.05|-0.31 ± 0.05 → -0.31 ± 0.04| +|pen-cloned-v1|88.66 ± 15.10 → 86.82 ± 11.12|-2.76 ± 0.08 → -1.28 ± 2.16|84.19 ± 3.96 → 102.02 ± 20.75|6.19 ± 5.21 → 43.63 ± 20.09|-2.66 ± 0.04 → -2.68 ± 0.12|74.04 ± 11.97 → 138.15 ± 3.22| +|door-cloned-v1|0.93 ± 1.66 → 0.01 ± 0.00|-0.33 ± 0.01 → -0.33 ± 0.01|1.19 ± 0.93 → 20.34 ± 9.32|-0.21 ± 0.14 → 0.02 ± 0.31|-0.33 ± 0.01 → -0.33 ± 0.01|0.07 ± 0.04 → 102.39 ± 8.27| +|hammer-cloned-v1|1.80 ± 3.01 → 0.24 ± 0.04|0.56 ± 0.55 → 2.85 ± 4.81|1.35 ± 0.32 → 57.27 ± 28.49|3.97 ± 6.39 → 3.73 ± 4.99|0.25 ± 0.04 → 0.17 ± 0.17|6.54 ± 3.35 → 124.65 ± 7.37| +|relocate-cloned-v1|-0.04 ± 0.04 → -0.04 ± 0.01|-0.33 ± 0.01 → -0.33 ± 0.01|0.04 ± 0.04 → 0.32 ± 0.38|-0.24 ± 0.01 → -0.15 ± 0.05|-0.31 ± 0.05 → -0.31 ± 0.04|0.70 ± 0.62 → 6.96 ± 4.59| | | | | | | | | | | | -| **adroit average** |22.84 → 21.76|-0.72 → 0.22|21.69 → 44.99|2.43 → 11.81|-0.76 → -0.79| +| **adroit average** |22.84 → 21.76|-0.72 → 0.22|21.69 → 44.99|2.43 → 11.81|-0.76 → -0.79|20.33 → 93.04| #### Regrets -| **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL| -|---------------------------|------------|--------|--------|-----|-----| -|antmaze-umaze-v2|0.04 ± 0.01|0.02 ± 0.00|0.07 ± 0.00|0.02 ± 0.00|0.01 ± 0.00| -|antmaze-umaze-diverse-v2|0.88 ± 0.01|0.09 ± 0.01|0.43 ± 0.11|0.22 ± 0.07|0.05 ± 0.01| -|antmaze-medium-play-v2|1.00 ± 0.00|0.08 ± 0.01|0.09 ± 0.01|0.06 ± 0.00|0.04 ± 0.01| -|antmaze-medium-diverse-v2|1.00 ± 0.00|0.08 ± 0.00|0.10 ± 0.01|0.05 ± 0.01|0.04 ± 0.01| -|antmaze-large-play-v2|1.00 ± 0.00|0.21 ± 0.02|0.34 ± 0.05|0.29 ± 0.07|0.13 ± 0.02| -|antmaze-large-diverse-v2|1.00 ± 0.00|0.21 ± 0.03|0.41 ± 0.03|0.23 ± 0.08|0.13 ± 0.02| +| **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL|ReBRAC| +|---------------------------|------------|--------|--------|-----|-----|-----| +|antmaze-umaze-v2|0.04 ± 0.01|0.02 ± 0.00|0.07 ± 0.00|0.02 ± 0.00|0.01 ± 0.00|0.11 ± 0.18| +|antmaze-umaze-diverse-v2|0.88 ± 0.01|0.09 ± 0.01|0.43 ± 0.11|0.22 ± 0.07|0.05 ± 0.01|0.04 ± 0.02| +|antmaze-medium-play-v2|1.00 ± 0.00|0.08 ± 0.01|0.09 ± 0.01|0.06 ± 0.00|0.04 ± 0.01|0.03 ± 0.01| +|antmaze-medium-diverse-v2|1.00 ± 0.00|0.08 ± 0.00|0.10 ± 0.01|0.05 ± 0.01|0.04 ± 0.01|0.03 ± 0.00| +|antmaze-large-play-v2|1.00 ± 0.00|0.21 ± 0.02|0.34 ± 0.05|0.29 ± 0.07|0.13 ± 0.02|0.14 ± 0.05| +|antmaze-large-diverse-v2|1.00 ± 0.00|0.21 ± 0.03|0.41 ± 0.03|0.23 ± 0.08|0.13 ± 0.02|0.29 ± 0.39| | | | | | | | | | | | -| **antmaze average** |0.82|0.11|0.24|0.15|0.07| +| **antmaze average** |0.82|0.11|0.24|0.15|0.07|0.11| | | | | | | | | | | | -|pen-cloned-v1|0.46 ± 0.02|0.97 ± 0.00|0.37 ± 0.01|0.58 ± 0.02|0.98 ± 0.01| -|door-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|0.83 ± 0.03|0.99 ± 0.01|1.00 ± 0.00| -|hammer-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|0.65 ± 0.10|0.98 ± 0.01|1.00 ± 0.00| -|relocate-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00| +|pen-cloned-v1|0.46 ± 0.02|0.97 ± 0.00|0.37 ± 0.01|0.58 ± 0.02|0.98 ± 0.01|0.08 ± 0.01| +|door-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|0.83 ± 0.03|0.99 ± 0.01|1.00 ± 0.00|0.19 ± 0.05| +|hammer-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|0.65 ± 0.10|0.98 ± 0.01|1.00 ± 0.00|0.13 ± 0.03| +|relocate-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|0.90 ± 0.06| | | | | | | | | | | | -| **adroit average** |0.86|0.99|0.71|0.89|0.99| +| **adroit average** |0.86|0.99|0.71|0.89|0.99|0.33| ## Citing CORL diff --git a/algorithms/finetune/rebrac.py b/algorithms/finetune/rebrac.py new file mode 100644 index 00000000..26b0d661 --- /dev/null +++ b/algorithms/finetune/rebrac.py @@ -0,0 +1,1099 @@ 
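The new `algorithms/finetune/rebrac.py` added below runs `num_offline_updates` offline TD3-style updates followed by `num_online_updates` of online fine-tuning, mixing offline and online batches and decaying the BC penalties after the switch. As a rough orientation, here is a minimal sketch of that schedule; it is not part of the patch, the helper name `finetune_schedule` is invented here, and the coefficient values are taken from the antmaze-medium-play config added further down.

# Illustrative sketch only -- mirrors the scheduling logic of rebrac.py below, not code from the patch.
def finetune_schedule(step,
                      num_offline_updates=1_000_000, num_online_updates=1_000_000,
                      num_warmup_steps=0, batch_size=256, mixing_ratio=0.5,
                      actor_bc_coef=0.001, critic_bc_coef=0.0005, min_decay_coef=0.5):
    """Phase, batch composition and BC coefficients used at a given update step."""
    online = step >= num_offline_updates
    # After the offline phase, every batch mixes offline data with freshly collected online data.
    online_batch = int(mixing_ratio * batch_size) if online else 0
    offline_batch = batch_size - online_batch
    if online:
        # Linear decay of the actor BC penalty, clipped at min_decay_coef;
        # the critic BC penalty is switched off entirely during online tuning.
        lin_coef = (num_online_updates + num_offline_updates - step + num_warmup_steps) / num_online_updates
        actor_bc_coef *= max(min_decay_coef, lin_coef)
        critic_bc_coef = 0.0
    return ("online" if online else "offline", offline_batch, online_batch,
            actor_bc_coef, critic_bc_coef)

print(finetune_schedule(0))          # ('offline', 256, 0, 0.001, 0.0005)
print(finetune_schedule(1_500_000))  # ('online', 128, 128, 0.0005, 0.0)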
+import os + +os.environ["TF_CUDNN_DETERMINISTIC"] = "1" + +import math +import random +import uuid +from copy import deepcopy +from dataclasses import asdict, dataclass +from functools import partial +from typing import Any, Callable, Dict, Sequence, Tuple, Union + +import chex +import d4rl # noqa +import flax.linen as nn +import gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +import pyrallis +import tqdm +import wandb +from flax.core import FrozenDict +from flax.training.train_state import TrainState +from tqdm.auto import trange + +ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate") + +default_kernel_init = nn.initializers.lecun_normal() +default_bias_init = nn.initializers.zeros + + +@dataclass +class Config: + # wandb params + project: str = "ReBRAC" + group: str = "rebrac-finetune" + name: str = "rebrac-finetune" + # model params + actor_learning_rate: float = 3e-4 + critic_learning_rate: float = 3e-4 + hidden_dim: int = 256 + actor_n_hiddens: int = 3 + critic_n_hiddens: int = 3 + replay_buffer_size: int = 2_000_000 + mixing_ratio: float = 0.5 + gamma: float = 0.99 + tau: float = 5e-3 + actor_bc_coef: float = 1.0 + critic_bc_coef: float = 1.0 + actor_ln: bool = False + critic_ln: bool = True + policy_noise: float = 0.2 + noise_clip: float = 0.5 + expl_noise: float = 0.0 + policy_freq: int = 2 + normalize_q: bool = True + min_decay_coef: float = 0.5 + use_calibration: bool = False + reset_opts: bool = False + # training params + dataset_name: str = "halfcheetah-medium-v2" + batch_size: int = 256 + num_offline_updates: int = 1_000_000 + num_online_updates: int = 1_000_000 + num_warmup_steps: int = 0 + normalize_reward: bool = False + normalize_states: bool = False + # evaluation params + eval_episodes: int = 10 + eval_every: int = 5000 + # general params + train_seed: int = 10 + eval_seed: int = 42 + + def __post_init__(self): + self.name = f"{self.name}-{self.dataset_name}-{str(uuid.uuid4())[:8]}" + + +def pytorch_init(fan_in: float) -> Callable: + """ + Default init for PyTorch Linear layer weights and biases: + https://pytorch.org/docs/stable/generated/torch.nn.Linear.html + """ + bound = math.sqrt(1 / fan_in) + + def _init(key: jax.random.PRNGKey, shape: Tuple, dtype: type) -> jax.Array: + return jax.random.uniform( + key, shape=shape, minval=-bound, maxval=bound, dtype=dtype + ) + + return _init + + +def uniform_init(bound: float) -> Callable: + def _init(key: jax.random.PRNGKey, shape: Tuple, dtype: type) -> jax.Array: + return jax.random.uniform( + key, shape=shape, minval=-bound, maxval=bound, dtype=dtype + ) + + return _init + + +def identity(x: Any) -> Any: + return x + + +class DetActor(nn.Module): + action_dim: int + hidden_dim: int = 256 + layernorm: bool = True + n_hiddens: int = 3 + + @nn.compact + def __call__(self, state: jax.Array) -> jax.Array: + s_d, h_d = state.shape[-1], self.hidden_dim + # Initialization as in the EDAC paper + layers = [ + nn.Dense( + self.hidden_dim, + kernel_init=pytorch_init(s_d), + bias_init=nn.initializers.constant(0.1), + ), + nn.relu, + nn.LayerNorm() if self.layernorm else identity, + ] + for _ in range(self.n_hiddens - 1): + layers += [ + nn.Dense( + self.hidden_dim, + kernel_init=pytorch_init(h_d), + bias_init=nn.initializers.constant(0.1), + ), + nn.relu, + nn.LayerNorm() if self.layernorm else identity, + ] + layers += [ + nn.Dense( + self.action_dim, + kernel_init=uniform_init(1e-3), + bias_init=uniform_init(1e-3), + ), + nn.tanh, + ] + net = nn.Sequential(layers) + actions = net(state) + 
return actions + + +class Critic(nn.Module): + hidden_dim: int = 256 + layernorm: bool = True + n_hiddens: int = 3 + + @nn.compact + def __call__(self, state: jax.Array, action: jax.Array) -> jax.Array: + s_d, a_d, h_d = state.shape[-1], action.shape[-1], self.hidden_dim + # Initialization as in the EDAC paper + layers = [ + nn.Dense( + self.hidden_dim, + kernel_init=pytorch_init(s_d + a_d), + bias_init=nn.initializers.constant(0.1), + ), + nn.relu, + nn.LayerNorm() if self.layernorm else identity, + ] + for _ in range(self.n_hiddens - 1): + layers += [ + nn.Dense( + self.hidden_dim, + kernel_init=pytorch_init(h_d), + bias_init=nn.initializers.constant(0.1), + ), + nn.relu, + nn.LayerNorm() if self.layernorm else identity, + ] + layers += [ + nn.Dense(1, kernel_init=uniform_init(3e-3), bias_init=uniform_init(3e-3)) + ] + network = nn.Sequential(layers) + state_action = jnp.hstack([state, action]) + out = network(state_action).squeeze(-1) + return out + + +class EnsembleCritic(nn.Module): + hidden_dim: int = 256 + num_critics: int = 10 + layernorm: bool = True + n_hiddens: int = 3 + + @nn.compact + def __call__(self, state: jax.Array, action: jax.Array) -> jax.Array: + ensemble = nn.vmap( + target=Critic, + in_axes=None, + out_axes=0, + variable_axes={"params": 0}, + split_rngs={"params": True}, + axis_size=self.num_critics, + ) + q_values = ensemble(self.hidden_dim, self.layernorm, self.n_hiddens)( + state, action + ) + return q_values + + +def calc_return_to_go(is_sparse_reward, rewards, terminals, gamma): + """ + A config dict for getting the default high/low rewrd values for each envs + This is used in calc_return_to_go func in sampler.py and replay_buffer.py + """ + if len(rewards) == 0: + return [] + reward_neg = 0 + if is_sparse_reward and np.all(np.array(rewards) == reward_neg): + """ + If the env has sparse reward and the trajectory is all negative rewards, + we use r / (1-gamma) as return to go. + For exapmle, if gamma = 0.99 and the rewards = [-1, -1, -1], + then return_to_go = [-100, -100, -100] + """ + # assuming failure reward is negative + # use r / (1-gamma) for negative trajctory + return_to_go = [float(reward_neg / (1 - gamma))] * len(rewards) + else: + return_to_go = [0] * len(rewards) + prev_return = 0 + for i in range(len(rewards)): + return_to_go[-i - 1] = \ + rewards[-i - 1] + gamma * prev_return * (1 - terminals[-i - 1]) + prev_return = return_to_go[-i - 1] + + return return_to_go + + +def qlearning_dataset( + env, dataset_name, + normalize_reward=False, dataset=None, + terminate_on_end=False, discount=0.99, **kwargs +): + """ + Returns datasets formatted for use by standard Q-learning algorithms, + with observations, actions, next_observations, next_actins, rewards, + and a terminal flag. + Args: + env: An OfflineEnv object. + dataset: An optional dataset to pass in for processing. If None, + the dataset will default to env.get_dataset() + terminate_on_end (bool): Set done=True on the last timestep + in a trajectory. Default is False, and will discard the + last timestep in each trajectory. + **kwargs: Arguments to pass to env.get_dataset(). + Returns: + A dictionary containing keys: + observations: An N x dim_obs array of observations. + actions: An N x dim_action array of actions. + next_observations: An N x dim_obs array of next observations. + next_actions: An N x dim_action array of next actions. + rewards: An N-dim float array of rewards. + terminals: An N-dim boolean array of "done" or episode termination flags. 
+ """ + if dataset is None: + dataset = env.get_dataset(**kwargs) + if normalize_reward: + dataset['rewards'] = ReplayBuffer.normalize_reward( + dataset_name, dataset['rewards'] + ) + N = dataset['rewards'].shape[0] + is_sparse = "antmaze" in dataset_name + obs_ = [] + next_obs_ = [] + action_ = [] + next_action_ = [] + reward_ = [] + done_ = [] + mc_returns_ = [] + print("SIZE", N) + # The newer version of the dataset adds an explicit + # timeouts field. Keep old method for backwards compatability. + use_timeouts = 'timeouts' in dataset + + episode_step = 0 + episode_rewards = [] + episode_terminals = [] + for i in range(N - 1): + if episode_step == 0: + episode_rewards = [] + episode_terminals = [] + + obs = dataset['observations'][i].astype(np.float32) + new_obs = dataset['observations'][i + 1].astype(np.float32) + action = dataset['actions'][i].astype(np.float32) + new_action = dataset['actions'][i + 1].astype(np.float32) + reward = dataset['rewards'][i].astype(np.float32) + done_bool = bool(dataset['terminals'][i]) + + if use_timeouts: + final_timestep = dataset['timeouts'][i] + else: + final_timestep = (episode_step == env._max_episode_steps - 1) + if (not terminate_on_end) and final_timestep: + # Skip this transition and don't apply terminals on the last step of episode + episode_step = 0 + mc_returns_ += calc_return_to_go( + is_sparse, episode_rewards, episode_terminals, discount + ) + continue + if done_bool or final_timestep: + episode_step = 0 + + episode_rewards.append(reward) + episode_terminals.append(done_bool) + + obs_.append(obs) + next_obs_.append(new_obs) + action_.append(action) + next_action_.append(new_action) + reward_.append(reward) + done_.append(done_bool) + episode_step += 1 + if episode_step != 0: + mc_returns_ += calc_return_to_go( + is_sparse, episode_rewards, episode_terminals, discount + ) + assert np.array(mc_returns_).shape == np.array(reward_).shape + return { + 'observations': np.array(obs_), + 'actions': np.array(action_), + 'next_observations': np.array(next_obs_), + 'next_actions': np.array(next_action_), + 'rewards': np.array(reward_), + 'terminals': np.array(done_), + 'mc_returns': np.array(mc_returns_), + } + + +def compute_mean_std(states: jax.Array, eps: float) -> Tuple[jax.Array, jax.Array]: + mean = states.mean(0) + std = states.std(0) + eps + return mean, std + + +def normalize_states(states: jax.Array, mean: jax.Array, std: jax.Array) -> jax.Array: + return (states - mean) / std + + +@chex.dataclass +class ReplayBuffer: + data: Dict[str, jax.Array] = None + mean: float = 0 + std: float = 1 + + def create_from_d4rl( + self, + dataset_name: str, + normalize_reward: bool = False, + is_normalize: bool = False, + ): + d4rl_data = qlearning_dataset(gym.make(dataset_name), dataset_name) + buffer = { + "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32), + "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32), + "rewards": jnp.asarray(d4rl_data["rewards"], dtype=jnp.float32), + "next_states": jnp.asarray( + d4rl_data["next_observations"], dtype=jnp.float32 + ), + "next_actions": jnp.asarray(d4rl_data["next_actions"], dtype=jnp.float32), + "dones": jnp.asarray(d4rl_data["terminals"], dtype=jnp.float32), + } + if is_normalize: + self.mean, self.std = compute_mean_std(buffer["states"], eps=1e-3) + buffer["states"] = normalize_states(buffer["states"], self.mean, self.std) + buffer["next_states"] = normalize_states( + buffer["next_states"], self.mean, self.std + ) + if normalize_reward: + buffer["rewards"] = 
ReplayBuffer.normalize_reward( + dataset_name, buffer["rewards"] + ) + self.data = buffer + + @property + def size(self) -> int: + # WARN: It will use len of the dataclass, i.e. number of fields. + return self.data["states"].shape[0] + + def sample_batch( + self, key: jax.random.PRNGKey, batch_size: int + ) -> Dict[str, jax.Array]: + indices = jax.random.randint( + key, shape=(batch_size,), minval=0, maxval=self.size + ) + batch = jax.tree_map(lambda arr: arr[indices], self.data) + return batch + + def get_moments(self, modality: str) -> Tuple[jax.Array, jax.Array]: + mean = self.data[modality].mean(0) + std = self.data[modality].std(0) + return mean, std + + @staticmethod + def normalize_reward(dataset_name: str, rewards: jax.Array) -> jax.Array: + if "antmaze" in dataset_name: + return rewards * 100.0 # like in LAPO + else: + raise NotImplementedError( + "Reward normalization is implemented only for AntMaze yet!" + ) + + +class Dataset(object): + def __init__(self, observations: np.ndarray, actions: np.ndarray, + rewards: np.ndarray, masks: np.ndarray, + dones_float: np.ndarray, next_observations: np.ndarray, + next_actions: np.ndarray, + mc_returns: np.ndarray, + size: int): + self.observations = observations + self.actions = actions + self.rewards = rewards + self.masks = masks + self.dones_float = dones_float + self.next_observations = next_observations + self.next_actions = next_actions + self.mc_returns = mc_returns + self.size = size + + def sample(self, batch_size: int) -> Dict[str, np.ndarray]: + indx = np.random.randint(self.size, size=batch_size) + return { + "states": self.observations[indx], + "actions": self.actions[indx], + "rewards": self.rewards[indx], + "dones": self.dones_float[indx], + "next_states": self.next_observations[indx], + "next_actions": self.next_actions[indx], + "mc_returns": self.mc_returns[indx], + } + + +class OnlineReplayBuffer(Dataset): + def __init__(self, observation_space: gym.spaces.Box, action_dim: int, + capacity: int): + + observations = np.empty((capacity, *observation_space.shape), + dtype=observation_space.dtype) + actions = np.empty((capacity, action_dim), dtype=np.float32) + rewards = np.empty((capacity,), dtype=np.float32) + mc_returns = np.empty((capacity,), dtype=np.float32) + masks = np.empty((capacity,), dtype=np.float32) + dones_float = np.empty((capacity,), dtype=np.float32) + next_observations = np.empty((capacity, *observation_space.shape), + dtype=observation_space.dtype) + next_actions = np.empty((capacity, action_dim), dtype=np.float32) + super().__init__(observations=observations, + actions=actions, + rewards=rewards, + masks=masks, + dones_float=dones_float, + next_observations=next_observations, + next_actions=next_actions, + mc_returns=mc_returns, + size=0) + + self.size = 0 + + self.insert_index = 0 + self.capacity = capacity + + def initialize_with_dataset(self, dataset: Dataset, + num_samples=None): + assert self.insert_index == 0, \ + 'Can insert a batch online in an empty replay buffer.' + + dataset_size = len(dataset.observations) + + if num_samples is None: + num_samples = dataset_size + else: + num_samples = min(dataset_size, num_samples) + assert self.capacity >= num_samples, \ + 'Dataset cannot be larger than the replay buffer capacity.' 
+ + if num_samples < dataset_size: + perm = np.random.permutation(dataset_size) + indices = perm[:num_samples] + else: + indices = np.arange(num_samples) + + self.observations[:num_samples] = dataset.observations[indices] + self.actions[:num_samples] = dataset.actions[indices] + self.rewards[:num_samples] = dataset.rewards[indices] + self.masks[:num_samples] = dataset.masks[indices] + self.dones_float[:num_samples] = dataset.dones_float[indices] + self.next_observations[:num_samples] = dataset.next_observations[ + indices] + self.next_actions[:num_samples] = dataset.next_actions[ + indices] + self.mc_returns[:num_samples] = dataset.mc_returns[indices] + + self.insert_index = num_samples + self.size = num_samples + + def insert(self, observation: np.ndarray, action: np.ndarray, + reward: float, mask: float, done_float: float, + next_observation: np.ndarray, + next_action: np.ndarray, mc_return: np.ndarray): + self.observations[self.insert_index] = observation + self.actions[self.insert_index] = action + self.rewards[self.insert_index] = reward + self.masks[self.insert_index] = mask + self.dones_float[self.insert_index] = done_float + self.next_observations[self.insert_index] = next_observation + self.next_actions[self.insert_index] = next_action + self.mc_returns[self.insert_index] = mc_return + + self.insert_index = (self.insert_index + 1) % self.capacity + self.size = min(self.size + 1, self.capacity) + + +class D4RLDataset(Dataset): + def __init__( + self, + env: gym.Env, + env_name: str, + normalize_reward: bool, + discount: float, + ): + d4rl_data = qlearning_dataset( + env, env_name, normalize_reward=normalize_reward, discount=discount + ) + dataset = { + "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32), + "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32), + "rewards": jnp.asarray(d4rl_data["rewards"], dtype=jnp.float32), + "next_states": jnp.asarray( + d4rl_data["next_observations"], dtype=jnp.float32 + ), + "next_actions": jnp.asarray(d4rl_data["next_actions"], dtype=jnp.float32), + "dones": jnp.asarray(d4rl_data["terminals"], dtype=jnp.float32), + "mc_returns": jnp.asarray(d4rl_data["mc_returns"], dtype=jnp.float32) + } + + super().__init__(dataset['states'].astype(np.float32), + actions=dataset['actions'].astype(np.float32), + rewards=dataset['rewards'].astype(np.float32), + masks=1.0 - dataset['dones'].astype(np.float32), + dones_float=dataset['dones'].astype(np.float32), + next_observations=dataset['next_states'].astype( + np.float32), + next_actions=dataset["next_actions"], + mc_returns=dataset["mc_returns"], + size=len(dataset['states'])) + + +def concat_batches(b1, b2): + new_batch = {} + for k in b1: + new_batch[k] = np.concatenate((b1[k], b2[k]), axis=0) + return new_batch + + +@chex.dataclass(frozen=True) +class Metrics: + accumulators: Dict[str, Tuple[jax.Array, jax.Array]] + + @staticmethod + def create(metrics: Sequence[str]) -> "Metrics": + init_metrics = {key: (jnp.array([0.0]), jnp.array([0.0])) for key in metrics} + return Metrics(accumulators=init_metrics) + + def update(self, updates: Dict[str, jax.Array]) -> "Metrics": + new_accumulators = deepcopy(self.accumulators) + for key, value in updates.items(): + acc, steps = new_accumulators[key] + new_accumulators[key] = (acc + value, steps + 1) + + return self.replace(accumulators=new_accumulators) + + def compute(self) -> Dict[str, np.ndarray]: + # cumulative_value / total_steps + return {k: np.array(v[0] / v[1]) for k, v in self.accumulators.items()} + + +def normalize( + arr: 
jax.Array, mean: jax.Array, std: jax.Array, eps: float = 1e-8 +) -> jax.Array: + return (arr - mean) / (std + eps) + + +def make_env(env_name: str, seed: int) -> gym.Env: + env = gym.make(env_name) + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + +def wrap_env( + env: gym.Env, + state_mean: Union[np.ndarray, float] = 0.0, + state_std: Union[np.ndarray, float] = 1.0, + reward_scale: float = 1.0, +) -> gym.Env: + # PEP 8: E731 do not assign a lambda expression, use a def + def normalize_state(state: np.ndarray) -> np.ndarray: + return ( + state - state_mean + ) / state_std # epsilon should be already added in std. + + def scale_reward(reward: float) -> float: + # Please be careful, here reward is multiplied by scale! + return reward_scale * reward + + env = gym.wrappers.TransformObservation(env, normalize_state) + if reward_scale != 1.0: + env = gym.wrappers.TransformReward(env, scale_reward) + return env + + +def make_env_and_dataset(env_name: str, + seed: int, + normalize_reward: bool, + discount: float) -> Tuple[gym.Env, D4RLDataset]: + env = gym.make(env_name) + + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + + dataset = D4RLDataset(env, env_name, normalize_reward, discount=discount) + + return env, dataset + + +def is_goal_reached(reward: float, info: Dict) -> bool: + if "goal_achieved" in info: + return info["goal_achieved"] + return reward > 0 # Assuming that reaching target is a positive reward + + +def evaluate( + env: gym.Env, params, action_fn: Callable, num_episodes: int, seed: int +) -> Tuple[np.ndarray, np.ndarray]: + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + + returns = [] + successes = [] + for _ in trange(num_episodes, desc="Eval", leave=False): + obs, done = env.reset(), False + goal_achieved = False + total_reward = 0.0 + while not done: + action = np.asarray(jax.device_get(action_fn(params, obs))) + obs, reward, done, info = env.step(action) + total_reward += reward + if not goal_achieved: + goal_achieved = is_goal_reached(reward, info) + successes.append(float(goal_achieved)) + returns.append(total_reward) + + return np.array(returns), np.mean(successes) + + +class CriticTrainState(TrainState): + target_params: FrozenDict + + +class ActorTrainState(TrainState): + target_params: FrozenDict + + +@jax.jit +def update_actor( + key: jax.random.PRNGKey, + actor: TrainState, + critic: TrainState, + batch: Dict[str, jax.Array], + beta: float, + tau: float, + normalize_q: bool, + metrics: Metrics, +) -> Tuple[jax.random.PRNGKey, TrainState, jax.Array, Metrics]: + key, random_action_key = jax.random.split(key, 2) + + def actor_loss_fn(params): + actions = actor.apply_fn(params, batch["states"]) + + bc_penalty = ((actions - batch["actions"]) ** 2).sum(-1) + q_values = critic.apply_fn(critic.params, batch["states"], actions).min(0) + # lmbda = 1 + # # if normalize_q: + lmbda = jax.lax.stop_gradient(1 / jax.numpy.abs(q_values).mean()) + + loss = (beta * bc_penalty - lmbda * q_values).mean() + + # logging stuff + random_actions = jax.random.uniform( + random_action_key, shape=batch["actions"].shape, minval=-1.0, maxval=1.0 + ) + new_metrics = metrics.update({ + "actor_loss": loss, + "bc_mse_policy": bc_penalty.mean(), + "bc_mse_random": ((random_actions - batch["actions"]) ** 2).sum(-1).mean(), + "action_mse": ((actions - batch["actions"]) ** 2).mean() + }) + return loss, new_metrics + + grads, new_metrics = jax.grad(actor_loss_fn, has_aux=True)(actor.params) + 
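# Illustrative note, not part of the patch: the loss computed in actor_loss_fn above is the
# ReBRAC actor objective, a TD3+BC-style trade-off with the Q term rescaled by
# lmbda = 1 / mean(|Q|). For a toy batch with bc_penalty = 0.04, beta (actor_bc_coef) = 0.001
# (the antmaze-medium value in the configs below) and Q = -50 with mean |Q| = 100:
#   loss = 0.001 * 0.04 - (1 / 100) * (-50) = 0.50004
# so with a small beta the normalized Q term dominates unless actions drift far from the data.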
new_actor = actor.apply_gradients(grads=grads) + + new_actor = new_actor.replace( + target_params=optax.incremental_update(actor.params, actor.target_params, tau) + ) + new_critic = critic.replace( + target_params=optax.incremental_update(critic.params, critic.target_params, tau) + ) + + return key, new_actor, new_critic, new_metrics + + +def update_critic( + key: jax.random.PRNGKey, + actor: TrainState, + critic: CriticTrainState, + batch: Dict[str, jax.Array], + gamma: float, + beta: float, + tau: float, + policy_noise: float, + noise_clip: float, + use_calibration: bool, + metrics: Metrics, +) -> Tuple[jax.random.PRNGKey, TrainState, Metrics]: + key, actions_key = jax.random.split(key) + + next_actions = actor.apply_fn(actor.target_params, batch["next_states"]) + noise = jax.numpy.clip( + (jax.random.normal(actions_key, next_actions.shape) * policy_noise), + -noise_clip, + noise_clip, + ) + next_actions = jax.numpy.clip(next_actions + noise, -1, 1) + + bc_penalty = ((next_actions - batch["next_actions"]) ** 2).sum(-1) + next_q = critic.apply_fn( + critic.target_params, batch["next_states"], next_actions + ).min(0) + next_q = next_q - beta * bc_penalty + target_q = jax.lax.cond( + use_calibration, + lambda: jax.numpy.maximum( + batch["rewards"] + (1 - batch["dones"]) * gamma * next_q, + batch['mc_returns'] + ), + lambda: batch["rewards"] + (1 - batch["dones"]) * gamma * next_q + ) + + def critic_loss_fn(critic_params): + # [N, batch_size] - [1, batch_size] + q = critic.apply_fn(critic_params, batch["states"], batch["actions"]) + q_min = q.min(0).mean() + loss = ((q - target_q[None, ...]) ** 2).mean(1).sum(0) + return loss, q_min + + (loss, q_min), grads = jax.value_and_grad( + critic_loss_fn, has_aux=True + )(critic.params) + new_critic = critic.apply_gradients(grads=grads) + new_metrics = metrics.update({ + "critic_loss": loss, + "q_min": q_min, + }) + return key, new_critic, new_metrics + + +@jax.jit +def update_td3( + key: jax.random.PRNGKey, + actor: TrainState, + critic: CriticTrainState, + batch: Dict[str, Any], + metrics: Metrics, + gamma: float, + actor_bc_coef: float, + critic_bc_coef: float, + tau: float, + policy_noise: float, + noise_clip: float, + normalize_q: bool, + use_calibration: bool, +): + key, new_critic, new_metrics = update_critic( + key, actor, critic, batch, gamma, critic_bc_coef, tau, + policy_noise, noise_clip, use_calibration, metrics + ) + key, new_actor, new_critic, new_metrics = update_actor( + key, actor, new_critic, batch, actor_bc_coef, tau, + normalize_q, new_metrics + ) + return key, new_actor, new_critic, new_metrics + + +@jax.jit +def update_td3_no_targets( + key: jax.random.PRNGKey, + actor: TrainState, + critic: CriticTrainState, + batch: Dict[str, Any], + gamma: float, + metrics: Metrics, + actor_bc_coef: float, + critic_bc_coef: float, + tau: float, + policy_noise: float, + noise_clip: float, + use_calibration: bool, +): + key, new_critic, new_metrics = update_critic( + key, actor, critic, batch, gamma, critic_bc_coef, tau, + policy_noise, noise_clip, use_calibration, metrics + ) + return key, actor, new_critic, new_metrics + + +def action_fn(actor: TrainState) -> Callable: + @jax.jit + def _action_fn(obs: jax.Array) -> jax.Array: + action = actor.apply_fn(actor.params, obs) + return action + + return _action_fn + + +@pyrallis.wrap() +def train(config: Config): + dict_config = asdict(config) + dict_config["mlc_job_name"] = os.environ.get("PLATFORM_JOB_NAME") + is_env_with_goal = config.dataset_name.startswith(ENVS_WITH_GOAL) + 
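# Illustrative note, not part of the patch: with use_calibration=True, update_critic above clips
# the TD target from below by the stored Monte-Carlo return (a Cal-QL-style lower bound).
# Toy numbers: reward = 1.0, gamma = 0.99, done = 0, penalized next_q = -3.0, mc_return = 0.5:
#   TD target         = 1.0 + 0.99 * (-3.0) = -1.97
#   calibrated target = max(-1.97, 0.5)     =  0.5
# All configs added in this patch keep use_calibration: false, so the plain TD target is used.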
np.random.seed(config.train_seed) + random.seed(config.train_seed) + + wandb.init( + config=dict_config, + project=config.project, + group=config.group, + name=config.name, + id=str(uuid.uuid4()), + ) + buffer = ReplayBuffer() + buffer.create_from_d4rl( + config.dataset_name, config.normalize_reward, config.normalize_states + ) + + key = jax.random.PRNGKey(seed=config.train_seed) + key, actor_key, critic_key = jax.random.split(key, 3) + + init_state = buffer.data["states"][0][None, ...] + init_action = buffer.data["actions"][0][None, ...] + + actor_module = DetActor( + action_dim=init_action.shape[-1], hidden_dim=config.hidden_dim, + layernorm=config.actor_ln, n_hiddens=config.actor_n_hiddens + ) + actor = ActorTrainState.create( + apply_fn=actor_module.apply, + params=actor_module.init(actor_key, init_state), + target_params=actor_module.init(actor_key, init_state), + tx=optax.adam(learning_rate=config.actor_learning_rate), + ) + + critic_module = EnsembleCritic( + hidden_dim=config.hidden_dim, num_critics=2, + layernorm=config.critic_ln, n_hiddens=config.critic_n_hiddens + ) + critic = CriticTrainState.create( + apply_fn=critic_module.apply, + params=critic_module.init(critic_key, init_state, init_action), + target_params=critic_module.init(critic_key, init_state, init_action), + tx=optax.adam(learning_rate=config.critic_learning_rate), + ) + + # metrics + bc_metrics_to_log = [ + "critic_loss", "q_min", "actor_loss", "batch_entropy", + "bc_mse_policy", "bc_mse_random", "action_mse" + ] + # shared carry for update loops + carry = { + "key": key, + "actor": actor, + "critic": critic, + "buffer": buffer, + "delayed_updates": jax.numpy.equal( + jax.numpy.arange( + config.num_offline_updates + config.num_online_updates + ) % config.policy_freq, 0 + ).astype(int) + } + + # Online + offline tuning + env, dataset = make_env_and_dataset( + config.dataset_name, config.train_seed, False, discount=config.gamma + ) + eval_env, _ = make_env_and_dataset( + config.dataset_name, config.eval_seed, False, discount=config.gamma + ) + + max_steps = env._max_episode_steps + + action_dim = env.action_space.shape[0] + replay_buffer = OnlineReplayBuffer(env.observation_space, action_dim, + config.replay_buffer_size) + replay_buffer.initialize_with_dataset(dataset, None) + online_buffer = OnlineReplayBuffer( + env.observation_space, action_dim, config.replay_buffer_size + ) + + online_batch_size = 0 + offline_batch_size = config.batch_size + + observation, done = env.reset(), False + episode_step = 0 + goal_achieved = False + + @jax.jit + def actor_action_fn(params, obs): + return actor.apply_fn(params, obs) + + eval_successes = [] + train_successes = [] + print("Offline training") + for i in tqdm.tqdm( + range(config.num_online_updates + config.num_offline_updates), + smoothing=0.1 + ): + carry["metrics"] = Metrics.create(bc_metrics_to_log) + if i == config.num_offline_updates: + print("Online training") + + online_batch_size = int(config.mixing_ratio * config.batch_size) + offline_batch_size = config.batch_size - online_batch_size + # Reset optimizers similar to SPOT + if config.reset_opts: + actor = actor.replace( + opt_state=optax.adam(learning_rate=config.actor_learning_rate).init(actor.params) + ) + critic = critic.replace( + opt_state=optax.adam(learning_rate=config.critic_learning_rate).init(critic.params) + ) + + update_td3_partial = partial( + update_td3, gamma=config.gamma, + tau=config.tau, + policy_noise=config.policy_noise, + noise_clip=config.noise_clip, + normalize_q=config.normalize_q, + 
use_calibration=config.use_calibration, + ) + + update_td3_no_targets_partial = partial( + update_td3_no_targets, gamma=config.gamma, + tau=config.tau, + policy_noise=config.policy_noise, + noise_clip=config.noise_clip, + use_calibration=config.use_calibration, + ) + online_log = {} + + if i >= config.num_offline_updates: + episode_step += 1 + action = np.asarray(actor_action_fn(carry["actor"].params, observation)) + action = np.array( + [ + ( + action + + np.random.normal(0, 1 * config.expl_noise, size=action_dim) + ).clip(-1, 1) + ] + )[0] + + next_observation, reward, done, info = env.step(action) + if not goal_achieved: + goal_achieved = is_goal_reached(reward, info) + next_action = np.asarray( + actor_action_fn(carry["actor"].params, next_observation) + )[0] + next_action = np.array( + [ + ( + next_action + + np.random.normal(0, 1 * config.expl_noise, size=action_dim) + ).clip(-1, 1) + ] + )[0] + + if not done or 'TimeLimit.truncated' in info: + mask = 1.0 + else: + mask = 0.0 + real_done = False + if done and episode_step < max_steps: + real_done = True + + online_buffer.insert(observation, action, reward, mask, + float(real_done), next_observation, next_action, 0) + observation = next_observation + if done: + train_successes.append(goal_achieved) + observation, done = env.reset(), False + episode_step = 0 + goal_achieved = False + + if config.num_offline_updates <= \ + i < \ + config.num_offline_updates + config.num_warmup_steps: + continue + + offline_batch = replay_buffer.sample(offline_batch_size) + online_batch = online_buffer.sample(online_batch_size) + batch = concat_batches(offline_batch, online_batch) + + if 'antmaze' in config.dataset_name and config.normalize_reward: + batch["rewards"] *= 100 + + ### Update step + actor_bc_coef = config.actor_bc_coef + critic_bc_coef = config.critic_bc_coef + if i >= config.num_offline_updates: + lin_coef = ( + config.num_online_updates + + config.num_offline_updates - + i + config.num_warmup_steps + ) / config.num_online_updates + decay_coef = max(config.min_decay_coef, lin_coef) + actor_bc_coef *= decay_coef + critic_bc_coef *= 0 + if i % config.policy_freq == 0: + update_fn = partial(update_td3_partial, + actor_bc_coef=actor_bc_coef, + critic_bc_coef=critic_bc_coef, + key=key, + actor=carry["actor"], + critic=carry["critic"], + batch=batch, + metrics=carry["metrics"]) + else: + update_fn = partial(update_td3_no_targets_partial, + actor_bc_coef=actor_bc_coef, + critic_bc_coef=critic_bc_coef, + key=key, + actor=carry["actor"], + critic=carry["critic"], + batch=batch, + metrics=carry["metrics"]) + key, new_actor, new_critic, new_metrics = update_fn() + carry.update( + key=key, actor=new_actor, critic=new_critic, metrics=new_metrics + ) + + if i % 1000 == 0: + mean_metrics = carry["metrics"].compute() + common = {f"ReBRAC/{k}": v for k, v in mean_metrics.items()} + common["actor_bc_coef"] = actor_bc_coef + common["critic_bc_coef"] = critic_bc_coef + if i < config.num_offline_updates: + wandb.log({"offline_iter": i, **common}) + else: + wandb.log({"online_iter": i - config.num_offline_updates, **common}) + if i % config.eval_every == 0 or\ + i == config.num_offline_updates + config.num_online_updates - 1 or\ + i == config.num_offline_updates - 1: + eval_returns, success_rate = evaluate( + eval_env, carry["actor"].params, actor_action_fn, + config.eval_episodes, + seed=config.eval_seed + ) + normalized_score = eval_env.get_normalized_score(eval_returns) * 100.0 + eval_successes.append(success_rate) + if is_env_with_goal: + 
online_log["train/regret"] = np.mean(1 - np.array(train_successes)) + offline_log = { + "eval/return_mean": np.mean(eval_returns), + "eval/return_std": np.std(eval_returns), + "eval/normalized_score_mean": np.mean(normalized_score), + "eval/normalized_score_std": np.std(normalized_score), + "eval/success_rate": success_rate + } + offline_log.update(online_log) + wandb.log(offline_log) + + +if __name__ == "__main__": + train() diff --git a/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml b/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml new file mode 100644 index 00000000..cd208878 --- /dev/null +++ b/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.002 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.002 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-large-diverse-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-large-diverse-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: true +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/large_play_v2.yaml b/configs/finetune/rebrac/antmaze/large_play_v2.yaml new file mode 100644 index 00000000..77e56655 --- /dev/null +++ b/configs/finetune/rebrac/antmaze/large_play_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.002 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.001 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-large-play-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-large-play-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: true +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml b/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml new file mode 100644 index 00000000..d5d56435 --- /dev/null +++ b/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.001 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.0 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-medium-diverse-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-medium-diverse-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: true +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git 
a/configs/finetune/rebrac/antmaze/medium_play_v2.yaml b/configs/finetune/rebrac/antmaze/medium_play_v2.yaml new file mode 100644 index 00000000..e1c4966a --- /dev/null +++ b/configs/finetune/rebrac/antmaze/medium_play_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.001 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.0005 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-medium-play-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-medium-play-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: true +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml b/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml new file mode 100644 index 00000000..41fdf292 --- /dev/null +++ b/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.003 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.001 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-umaze-diverse-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-umaze-diverse-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: true +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/umaze_v2.yaml b/configs/finetune/rebrac/antmaze/umaze_v2.yaml new file mode 100644 index 00000000..042b5232 --- /dev/null +++ b/configs/finetune/rebrac/antmaze/umaze_v2.yaml @@ -0,0 +1,36 @@ +actor_bc_coef: 0.003 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.002 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-umaze-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-umaze-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: true +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false + diff --git a/configs/finetune/rebrac/door/cloned_v1.yaml b/configs/finetune/rebrac/door/cloned_v1.yaml new file mode 100644 index 00000000..06db15cd --- /dev/null +++ b/configs/finetune/rebrac/door/cloned_v1.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.01 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.1 +critic_learning_rate: 0.0003 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: door-cloned-v1 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.99 
+group: rebrac-finetune-door-cloned-v1 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/hammer/cloned_v1.yaml b/configs/finetune/rebrac/hammer/cloned_v1.yaml new file mode 100644 index 00000000..c55be6f2 --- /dev/null +++ b/configs/finetune/rebrac/hammer/cloned_v1.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.1 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.5 +critic_learning_rate: 0.0003 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: hammer-cloned-v1 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.99 +group: rebrac-finetune-hammer-cloned-v1 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/pen/cloned_v1.yaml b/configs/finetune/rebrac/pen/cloned_v1.yaml new file mode 100644 index 00000000..9144ab9c --- /dev/null +++ b/configs/finetune/rebrac/pen/cloned_v1.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.05 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.5 +critic_learning_rate: 0.0003 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: pen-cloned-v1 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.99 +group: rebrac-finetune-pen-cloned-v1 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/relocate/cloned_v1.yaml b/configs/finetune/rebrac/relocate/cloned_v1.yaml new file mode 100644 index 00000000..fefa5048 --- /dev/null +++ b/configs/finetune/rebrac/relocate/cloned_v1.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.1 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.01 +critic_learning_rate: 0.0003 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: relocate-cloned-v1 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.99 +group: rebrac-finetune-relocate-cloned-v1 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/results/bin/finetune_scores.pickle b/results/bin/finetune_scores.pickle index d60377ec763001d8b678a2bb4c22d82ed4dd82d5..a0a682d367051865f229727df37b46f1e4df7e9d 100644 
GIT binary patch
delta 36898
[base85-encoded binary delta for results/bin/finetune_scores.pickle omitted]
diff --git a/results/get_finetune_scores.py b/results/get_finetune_scores.py
index dd31ad19..1fb9303d 100644
--- a/results/get_finetune_scores.py
+++ b/results/get_finetune_scores.py
@@ -32,9 +32,14 @@ def get_run_scores(run_id, is_dt=False):
             break
     for _, row in run.history(keys=[score_key], samples=5000).iterrows():
         full_scores.append(row[score_key])
+
+    for _, row in run.history(keys=["train/regret"], samples=5000).iterrows():
+        if "train/regret" in row:
+            regret = row["train/regret"]
     for _, row in run.history(keys=["eval/regret"], samples=5000).iterrows():
         if "eval/regret" in row:
             regret = row["eval/regret"]
+
     offline_iters = len(full_scores) // 2
     return full_scores[:offline_iters], full_scores[offline_iters:], regret

diff --git a/results/get_finetune_tables_and_plots.py b/results/get_finetune_tables_and_plots.py
index d8072652..997bc15a 100644
--- a/results/get_finetune_tables_and_plots.py
+++ b/results/get_finetune_tables_and_plots.py
@@ -76,6 +76,7 @@ def get_last_scores(avg_scores, avg_stds):

 for algo in full_offline_scores:
     for data in full_offline_scores[algo]:
+        # print(algo, flush=True)
         full_offline_scores[algo][data] = [s[0] for s in full_scores[algo][data]]
         full_online_scores[algo][data] = [s[1] for s in full_scores[algo][data]]
         regrets[algo][data] = np.mean([s[2] for s in full_scores[algo][data]])
@@ -127,7 +128,7 @@ def add_domains_avg(scores):
 add_domains_avg(last_online_scores)
 add_domains_avg(regrets)

-algorithms = ["AWAC", "CQL", "IQL", "SPOT", "Cal-QL"]
+algorithms = ["AWAC", "CQL", "IQL", "SPOT", "Cal-QL", "ReBRAC",]
 datasets = dataframe["dataset"].unique()
 ordered_datasets = [
     "antmaze-umaze-v2",
@@ -417,8 +418,9 @@ def flatten(data):
     "CQL",
     "IQL",
     "AWAC",
+    "Cal-QL",
 ]
-for a1 in ["Cal-QL"]:
+for a1 in ["ReBRAC"]:
     for a2 in algs:
         algorithm_pairs[f"{a1},{a2}"] = (flat[a1], flat[a2])
 average_probabilities, average_prob_cis = rly.get_interval_estimates(
diff --git a/results/get_finetune_urls.py b/results/get_finetune_urls.py
index 12d7e374..07048ca3 100644
--- a/results/get_finetune_urls.py
+++ b/results/get_finetune_urls.py
@@ -18,6 +18,8 @@ def get_urls(sweep_id, algo_name):
             dataset = run.config["env"]
         elif "env_name" in run.config:
             dataset = run.config["env_name"]
+        elif "dataset_name" in run.config:
+            dataset = run.config["dataset_name"]
         name = algo_name
         if "10" in "-".join(run.name.split("-")[:-1]):
             name = "10% " + name
@@ -41,6 +43,8 @@ def get_urls(sweep_id, algo_name):

 get_urls("tlab/CORL/sweeps/efvz7d68", "Cal-QL")

+get_urls("tlab/CORL/sweeps/62cgqb8c", "ReBRAC")
+
 dataframe = pd.DataFrame(collected_urls)
 dataframe.to_csv("runs_tables/finetune_urls.csv", index=False)
diff --git a/results/runs_tables/finetune_urls.csv b/results/runs_tables/finetune_urls.csv
index 3058d67a..ae637030 100644
--- a/results/runs_tables/finetune_urls.csv
+++ b/results/runs_tables/finetune_urls.csv
@@ -32,8 +32,8 @@ SPOT,antmaze-large-diverse-v2,tlab/CORL/runs/0kn4pl04
 SPOT,antmaze-large-diverse-v2,tlab/CORL/runs/onon4kbo
 SPOT,antmaze-large-play-v2,tlab/CORL/runs/5ldclyhi
 SPOT,antmaze-large-play-v2,tlab/CORL/runs/v8uskc0k
-SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/sysutdr0
 SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/tnksp757
+SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/sysutdr0
 SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/cg7vkg7p
 SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/12ivynlo
 SPOT,antmaze-large-diverse-v2,tlab/CORL/runs/kvxnj3cw
@@ -75,20 +75,20 @@ AWAC,antmaze-large-play-v2,tlab/CORL/runs/z88yrg2k
 AWAC,antmaze-large-play-v2,tlab/CORL/runs/b1e1up19
 AWAC,antmaze-large-play-v2,tlab/CORL/runs/bpq150gg
 AWAC,antmaze-large-play-v2,tlab/CORL/runs/rffmurq2
-AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/tiq65215
 AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/uo63dgzx
+AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/tiq65215
 AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/mfjh57xv
 AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/325fy2js
 CQL,antmaze-umaze-v2,tlab/CORL/runs/vdh2wmw9
 CQL,antmaze-umaze-v2,tlab/CORL/runs/wh27aupq
 CQL,antmaze-umaze-v2,tlab/CORL/runs/7r4uwutz
-CQL,antmaze-large-play-v2,tlab/CORL/runs/kt7jwqcz
 CQL,antmaze-umaze-v2,tlab/CORL/runs/l5xvgwt4
+CQL,antmaze-large-play-v2,tlab/CORL/runs/kt7jwqcz
 CQL,antmaze-large-play-v2,tlab/CORL/runs/8fm40vpm
 CQL,antmaze-large-play-v2,tlab/CORL/runs/yeax28su
-CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/gvhslqyo
-CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/mowkqr6u
 CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/pswhm9pi
+CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/mowkqr6u
+CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/gvhslqyo
 CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/hh5vv5qc
 CQL,door-cloned-v1,tlab/CORL/runs/2xc0y5sd
 CQL,door-cloned-v1,tlab/CORL/runs/2ylul8yr
@@ -188,14 +188,54 @@ Cal-QL,antmaze-umaze-diverse-v2,tlab/CORL/runs/a015fjb1
 Cal-QL,antmaze-umaze-diverse-v2,tlab/CORL/runs/1pu06s2i
 Cal-QL,antmaze-umaze-diverse-v2,tlab/CORL/runs/iwa1o31k
 Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/yvqv3mxa
-Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/4myjeu5g
 Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/6ptdr78l
+Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/4myjeu5g
 Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/8ix0469p
-Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/4chdwkua
 Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/fzrlcnwp
+Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/4chdwkua
 Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/f9hz4fal
 Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/fpq2ob8q
 Cal-QL,antmaze-medium-diverse-v2,tlab/CORL/runs/zhf7tr7p
 Cal-QL,antmaze-medium-diverse-v2,tlab/CORL/runs/m02ew5oy
 Cal-QL,antmaze-medium-diverse-v2,tlab/CORL/runs/9r1a0trx
 Cal-QL,antmaze-medium-diverse-v2,tlab/CORL/runs/ds2dbx2u
+ReBRAC,antmaze-umaze-v2,tlab/CORL/runs/jt59ttpm
+ReBRAC,antmaze-umaze-v2,tlab/CORL/runs/i8v1hu1q
+ReBRAC,antmaze-umaze-v2,tlab/CORL/runs/qedc68y9
+ReBRAC,antmaze-umaze-v2,tlab/CORL/runs/69ax7dcn
+ReBRAC,door-cloned-v1,tlab/CORL/runs/eqfcelaj
+ReBRAC,door-cloned-v1,tlab/CORL/runs/prxbzf7x
+ReBRAC,door-cloned-v1,tlab/CORL/runs/tdcevdlp
+ReBRAC,door-cloned-v1,tlab/CORL/runs/kl86pkrh
+ReBRAC,hammer-cloned-v1,tlab/CORL/runs/gpzs7ko9
+ReBRAC,hammer-cloned-v1,tlab/CORL/runs/mpqjq5iv
+ReBRAC,hammer-cloned-v1,tlab/CORL/runs/5fu56nvb
+ReBRAC,hammer-cloned-v1,tlab/CORL/runs/dfzw050e
+ReBRAC,pen-cloned-v1,tlab/CORL/runs/j29tzijm
+ReBRAC,pen-cloned-v1,tlab/CORL/runs/4en0ayoi
+ReBRAC,pen-cloned-v1,tlab/CORL/runs/4eyhtq9d
+ReBRAC,pen-cloned-v1,tlab/CORL/runs/so5u47bk
+ReBRAC,relocate-cloned-v1,tlab/CORL/runs/yskdxtsn
+ReBRAC,relocate-cloned-v1,tlab/CORL/runs/32gy3ulp
+ReBRAC,relocate-cloned-v1,tlab/CORL/runs/l4jqb7k7
+ReBRAC,relocate-cloned-v1,tlab/CORL/runs/mgd0l6wd
+ReBRAC,antmaze-medium-play-v2,tlab/CORL/runs/owmozman
+ReBRAC,antmaze-medium-play-v2,tlab/CORL/runs/0w7wg74n
+ReBRAC,antmaze-medium-play-v2,tlab/CORL/runs/arp17u1o
+ReBRAC,antmaze-medium-play-v2,tlab/CORL/runs/erip8h20
+ReBRAC,antmaze-umaze-diverse-v2,tlab/CORL/runs/5pd0cva3
+ReBRAC,antmaze-umaze-diverse-v2,tlab/CORL/runs/2beklksm
+ReBRAC,antmaze-umaze-diverse-v2,tlab/CORL/runs/68m94c0v
+ReBRAC,antmaze-umaze-diverse-v2,tlab/CORL/runs/g5cshcb2
+ReBRAC,antmaze-large-diverse-v2,tlab/CORL/runs/qyggi592
+ReBRAC,antmaze-large-diverse-v2,tlab/CORL/runs/4u324o5s
+ReBRAC,antmaze-large-diverse-v2,tlab/CORL/runs/vbgrhecg
+ReBRAC,antmaze-large-diverse-v2,tlab/CORL/runs/2286tpe9
+ReBRAC,antmaze-large-play-v2,tlab/CORL/runs/24zf9wms
+ReBRAC,antmaze-large-play-v2,tlab/CORL/runs/wc7z2xz3
+ReBRAC,antmaze-large-play-v2,tlab/CORL/runs/wxq1x3o2
+ReBRAC,antmaze-large-play-v2,tlab/CORL/runs/p8j2dn0y
+ReBRAC,antmaze-medium-diverse-v2,tlab/CORL/runs/0ig2xzi7
+ReBRAC,antmaze-medium-diverse-v2,tlab/CORL/runs/puzizafo
+ReBRAC,antmaze-medium-diverse-v2,tlab/CORL/runs/m84hfj5d
+ReBRAC,antmaze-medium-diverse-v2,tlab/CORL/runs/u1y0nldu