diff --git a/README.md b/README.md
index 9242fb4f..88b8a19a 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,7 @@ docker run --gpus=all -it --rm --name
 | ✅ [Conservative Q-Learning for Offline Reinforcement Learning <br>(CQL)](https://arxiv.org/abs/2006.04779) | [`offline/cql.py`](algorithms/offline/cql.py) <br> [`finetune/cql.py`](algorithms/finetune/cql.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-CQL--Vmlldzo1MzM4MjY3) <br> [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-CQL--Vmlldzo0NTQ3NTMz)
 | ✅ [Accelerating Online Reinforcement Learning with Offline Datasets <br>(AWAC)](https://arxiv.org/abs/2006.09359) | [`offline/awac.py`](algorithms/offline/awac.py) <br> [`finetune/awac.py`](algorithms/finetune/awac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-AWAC--Vmlldzo1MzM4MTEy) <br> [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-AWAC--VmlldzozODAyNzQz)
 | ✅ [Offline Reinforcement Learning with Implicit Q-Learning <br>(IQL)](https://arxiv.org/abs/2110.06169) | [`offline/iql.py`](algorithms/offline/iql.py) <br> [`finetune/iql.py`](algorithms/finetune/iql.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-IQL--Vmlldzo1MzM4MzQz) <br> [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-IQL--VmlldzozNzE1MTEy)
+| ✅ [Revisiting the Minimalist Approach to Offline Reinforcement Learning <br>(ReBRAC)](https://arxiv.org/abs/2305.09836) | [`offline/rebrac.py`](algorithms/offline/rebrac.py) <br> [`finetune/rebrac.py`](algorithms/finetune/rebrac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-ReBRAC--Vmlldzo0ODkzOTQ2) <br> [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-ReBRAC--Vmlldzo1NDAyNjE5)
 | **Offline-to-Online only** | |
 | ✅ [Supported Policy Optimization for Offline Reinforcement Learning <br>(SPOT)](https://arxiv.org/abs/2202.06239) | [`finetune/spot.py`](algorithms/finetune/spot.py) | [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-SPOT--VmlldzozODk5MTgx)
 | ✅ [Cal-QL: Calibrated Offline RL Pre-Training for Efficient Online Fine-Tuning <br>(Cal-QL)](https://arxiv.org/abs/2303.05479) | [`finetune/cal_ql.py`](algorithms/finetune/cal_ql.py) | [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-Cal-QL--Vmlldzo0NTQ3NDk5)
@@ -57,7 +58,6 @@ docker run --gpus=all -it --rm --name
 | ✅ [Decision Transformer: Reinforcement Learning via Sequence Modeling <br>(DT)](https://arxiv.org/abs/2106.01345) | [`offline/dt.py`](algorithms/offline/dt.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-Decision-Transformer--Vmlldzo1MzM3OTkx)
 | ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble <br>(SAC-N)](https://arxiv.org/abs/2110.01548) | [`offline/sac_n.py`](algorithms/offline/sac_n.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-SAC-N--VmlldzoyNzA1NTY1)
 | ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble <br>(EDAC)](https://arxiv.org/abs/2110.01548) | [`offline/edac.py`](algorithms/offline/edac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-EDAC--VmlldzoyNzA5ODUw)
-| ✅ [Revisiting the Minimalist Approach to Offline Reinforcement Learning <br>(ReBRAC)](https://arxiv.org/abs/2305.09836) | [`offline/rebrac.py`](algorithms/offline/rebrac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-ReBRAC--Vmlldzo0ODkzOTQ2)
 | ✅ [Q-Ensemble for Offline RL: Don't Scale the Ensemble, Scale the Batch Size <br>
(LB-SAC)](https://arxiv.org/abs/2211.11092) | [`offline/lb_sac.py`](algorithms/offline/lb_sac.py) | [`Offline Gym-MuJoCo`](https://wandb.ai/tlab/CORL/reports/LB-SAC-D4RL-Results--VmlldzozNjIxMDY1) @@ -179,42 +179,42 @@ You can check the links above for learning curves and details. Here, we report r ### Offline-to-Online #### Scores -| **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL| -|---------------------------|------------|--------|--------|-----|-----| -|antmaze-umaze-v2|52.75 ± 8.67 → 98.75 ± 1.09|94.00 ± 1.58 → 99.50 ± 0.87|77.00 ± 0.71 → 96.50 ± 1.12|91.00 ± 2.55 → 99.50 ± 0.50|76.75 ± 7.53 → 99.75 ± 0.43| -|antmaze-umaze-diverse-v2|56.00 ± 2.74 → 0.00 ± 0.00|9.50 ± 9.91 → 99.00 ± 1.22|59.50 ± 9.55 → 63.75 ± 25.02|36.25 ± 2.17 → 95.00 ± 3.67|32.00 ± 27.79 → 98.50 ± 1.12| -|antmaze-medium-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|59.00 ± 11.18 → 97.75 ± 1.30|71.75 ± 2.95 → 89.75 ± 1.09|67.25 ± 10.47 → 97.25 ± 1.30|71.75 ± 3.27 → 98.75 ± 1.64| -|antmaze-medium-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|63.50 ± 6.84 → 97.25 ± 1.92|64.25 ± 1.92 → 92.25 ± 2.86|73.75 ± 7.29 → 94.50 ± 1.66|62.00 ± 4.30 → 98.25 ± 1.48| -|antmaze-large-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|28.75 ± 7.76 → 88.25 ± 2.28|38.50 ± 8.73 → 64.50 ± 17.04|31.50 ± 12.58 → 87.00 ± 3.24|31.75 ± 8.87 → 97.25 ± 1.79| -|antmaze-large-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|35.50 ± 3.64 → 91.75 ± 3.96|26.75 ± 3.77 → 64.25 ± 4.15|17.50 ± 7.26 → 81.00 ± 14.14|44.00 ± 8.69 → 91.50 ± 3.91| +| **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL|ReBRAC| +|---------------------------|------------|--------|--------|-----|-----|-----| +|antmaze-umaze-v2|52.75 ± 8.67 → 98.75 ± 1.09|94.00 ± 1.58 → 99.50 ± 0.87|77.00 ± 0.71 → 96.50 ± 1.12|91.00 ± 2.55 → 99.50 ± 0.50|76.75 ± 7.53 → 99.75 ± 0.43|98.00 ± 1.58 → 74.75 ± 42.59| +|antmaze-umaze-diverse-v2|56.00 ± 2.74 → 0.00 ± 0.00|9.50 ± 9.91 → 99.00 ± 1.22|59.50 ± 9.55 → 63.75 ± 25.02|36.25 ± 2.17 → 95.00 ± 3.67|32.00 ± 27.79 → 98.50 ± 1.12|73.75 ± 13.27 → 98.00 ± 2.92| +|antmaze-medium-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|59.00 ± 11.18 → 97.75 ± 1.30|71.75 ± 2.95 → 89.75 ± 1.09|67.25 ± 10.47 → 97.25 ± 1.30|71.75 ± 3.27 → 98.75 ± 1.64|87.50 ± 3.77 → 98.00 ± 1.58| +|antmaze-medium-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|63.50 ± 6.84 → 97.25 ± 1.92|64.25 ± 1.92 → 92.25 ± 2.86|73.75 ± 7.29 → 94.50 ± 1.66|62.00 ± 4.30 → 98.25 ± 1.48|85.25 ± 2.17 → 98.75 ± 0.43| +|antmaze-large-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|28.75 ± 7.76 → 88.25 ± 2.28|38.50 ± 8.73 → 64.50 ± 17.04|31.50 ± 12.58 → 87.00 ± 3.24|31.75 ± 8.87 → 97.25 ± 1.79|68.50 ± 6.18 → 31.50 ± 33.56| +|antmaze-large-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|35.50 ± 3.64 → 91.75 ± 3.96|26.75 ± 3.77 → 64.25 ± 4.15|17.50 ± 7.26 → 81.00 ± 14.14|44.00 ± 8.69 → 91.50 ± 3.91|67.00 ± 10.61 → 72.25 ± 41.73| | | | | | | | | | | | -| **antmaze average** |18.12 → 16.46|48.38 → 95.58|56.29 → 78.50|52.88 → 92.38|53.04 → 97.33| +| **antmaze average** |18.12 → 16.46|48.38 → 95.58|56.29 → 78.50|52.88 → 92.38|53.04 → 97.33|80.00 → 78.88| | | | | | | | | | | | -|pen-cloned-v1|88.66 ± 15.10 → 86.82 ± 11.12|-2.76 ± 0.08 → -1.28 ± 2.16|84.19 ± 3.96 → 102.02 ± 20.75|6.19 ± 5.21 → 43.63 ± 20.09|-2.66 ± 0.04 → -2.68 ± 0.12| -|door-cloned-v1|0.93 ± 1.66 → 0.01 ± 0.00|-0.33 ± 0.01 → -0.33 ± 0.01|1.19 ± 0.93 → 20.34 ± 9.32|-0.21 ± 0.14 → 0.02 ± 0.31|-0.33 ± 0.01 → -0.33 ± 0.01| -|hammer-cloned-v1|1.80 ± 3.01 → 0.24 ± 0.04|0.56 ± 0.55 → 2.85 ± 4.81|1.35 ± 0.32 → 57.27 ± 28.49|3.97 ± 6.39 → 3.73 ± 4.99|0.25 ± 0.04 → 0.17 ± 0.17| -|relocate-cloned-v1|-0.04 ± 0.04 → -0.04 ± 0.01|-0.33 ± 0.01 → -0.33 ± 0.01|0.04 ± 0.04 → 0.32 ± 
0.38|-0.24 ± 0.01 → -0.15 ± 0.05|-0.31 ± 0.05 → -0.31 ± 0.04| +|pen-cloned-v1|88.66 ± 15.10 → 86.82 ± 11.12|-2.76 ± 0.08 → -1.28 ± 2.16|84.19 ± 3.96 → 102.02 ± 20.75|6.19 ± 5.21 → 43.63 ± 20.09|-2.66 ± 0.04 → -2.68 ± 0.12|74.04 ± 11.97 → 138.15 ± 3.22| +|door-cloned-v1|0.93 ± 1.66 → 0.01 ± 0.00|-0.33 ± 0.01 → -0.33 ± 0.01|1.19 ± 0.93 → 20.34 ± 9.32|-0.21 ± 0.14 → 0.02 ± 0.31|-0.33 ± 0.01 → -0.33 ± 0.01|0.07 ± 0.04 → 102.39 ± 8.27| +|hammer-cloned-v1|1.80 ± 3.01 → 0.24 ± 0.04|0.56 ± 0.55 → 2.85 ± 4.81|1.35 ± 0.32 → 57.27 ± 28.49|3.97 ± 6.39 → 3.73 ± 4.99|0.25 ± 0.04 → 0.17 ± 0.17|6.54 ± 3.35 → 124.65 ± 7.37| +|relocate-cloned-v1|-0.04 ± 0.04 → -0.04 ± 0.01|-0.33 ± 0.01 → -0.33 ± 0.01|0.04 ± 0.04 → 0.32 ± 0.38|-0.24 ± 0.01 → -0.15 ± 0.05|-0.31 ± 0.05 → -0.31 ± 0.04|0.70 ± 0.62 → 6.96 ± 4.59| | | | | | | | | | | | -| **adroit average** |22.84 → 21.76|-0.72 → 0.22|21.69 → 44.99|2.43 → 11.81|-0.76 → -0.79| +| **adroit average** |22.84 → 21.76|-0.72 → 0.22|21.69 → 44.99|2.43 → 11.81|-0.76 → -0.79|20.33 → 93.04| #### Regrets -| **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL| -|---------------------------|------------|--------|--------|-----|-----| -|antmaze-umaze-v2|0.04 ± 0.01|0.02 ± 0.00|0.07 ± 0.00|0.02 ± 0.00|0.01 ± 0.00| -|antmaze-umaze-diverse-v2|0.88 ± 0.01|0.09 ± 0.01|0.43 ± 0.11|0.22 ± 0.07|0.05 ± 0.01| -|antmaze-medium-play-v2|1.00 ± 0.00|0.08 ± 0.01|0.09 ± 0.01|0.06 ± 0.00|0.04 ± 0.01| -|antmaze-medium-diverse-v2|1.00 ± 0.00|0.08 ± 0.00|0.10 ± 0.01|0.05 ± 0.01|0.04 ± 0.01| -|antmaze-large-play-v2|1.00 ± 0.00|0.21 ± 0.02|0.34 ± 0.05|0.29 ± 0.07|0.13 ± 0.02| -|antmaze-large-diverse-v2|1.00 ± 0.00|0.21 ± 0.03|0.41 ± 0.03|0.23 ± 0.08|0.13 ± 0.02| +| **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL|ReBRAC| +|---------------------------|------------|--------|--------|-----|-----|-----| +|antmaze-umaze-v2|0.04 ± 0.01|0.02 ± 0.00|0.07 ± 0.00|0.02 ± 0.00|0.01 ± 0.00|0.11 ± 0.18| +|antmaze-umaze-diverse-v2|0.88 ± 0.01|0.09 ± 0.01|0.43 ± 0.11|0.22 ± 0.07|0.05 ± 0.01|0.04 ± 0.02| +|antmaze-medium-play-v2|1.00 ± 0.00|0.08 ± 0.01|0.09 ± 0.01|0.06 ± 0.00|0.04 ± 0.01|0.03 ± 0.01| +|antmaze-medium-diverse-v2|1.00 ± 0.00|0.08 ± 0.00|0.10 ± 0.01|0.05 ± 0.01|0.04 ± 0.01|0.03 ± 0.00| +|antmaze-large-play-v2|1.00 ± 0.00|0.21 ± 0.02|0.34 ± 0.05|0.29 ± 0.07|0.13 ± 0.02|0.14 ± 0.05| +|antmaze-large-diverse-v2|1.00 ± 0.00|0.21 ± 0.03|0.41 ± 0.03|0.23 ± 0.08|0.13 ± 0.02|0.29 ± 0.39| | | | | | | | | | | | -| **antmaze average** |0.82|0.11|0.24|0.15|0.07| +| **antmaze average** |0.82|0.11|0.24|0.15|0.07|0.11| | | | | | | | | | | | -|pen-cloned-v1|0.46 ± 0.02|0.97 ± 0.00|0.37 ± 0.01|0.58 ± 0.02|0.98 ± 0.01| -|door-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|0.83 ± 0.03|0.99 ± 0.01|1.00 ± 0.00| -|hammer-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|0.65 ± 0.10|0.98 ± 0.01|1.00 ± 0.00| -|relocate-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00| +|pen-cloned-v1|0.46 ± 0.02|0.97 ± 0.00|0.37 ± 0.01|0.58 ± 0.02|0.98 ± 0.01|0.08 ± 0.01| +|door-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|0.83 ± 0.03|0.99 ± 0.01|1.00 ± 0.00|0.19 ± 0.05| +|hammer-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|0.65 ± 0.10|0.98 ± 0.01|1.00 ± 0.00|0.13 ± 0.03| +|relocate-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|0.90 ± 0.06| | | | | | | | | | | | -| **adroit average** |0.86|0.99|0.71|0.89|0.99| +| **adroit average** |0.86|0.99|0.71|0.89|0.99|0.33| ## Citing CORL diff --git a/algorithms/finetune/rebrac.py b/algorithms/finetune/rebrac.py new file mode 100644 index 00000000..26b0d661 --- /dev/null +++ b/algorithms/finetune/rebrac.py @@ -0,0 +1,1099 @@ 
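Editor's note: the new `algorithms/finetune/rebrac.py` is a JAX/Flax implementation of ReBRAC fine-tuning — a TD3-based actor-critic with separate behavioral-cloning penalties for the actor and the critic, optional LayerNorm, and Q-value normalization — which first pretrains on the D4RL dataset and then continues training online. The following is a minimal sketch of the two penalized objectives, paraphrasing `update_critic` and `update_actor` defined further below; argument names are illustrative, not part of the file.

```python
# A minimal sketch (not the full training step) of the two ReBRAC regularizers
# implemented in this file; all argument names here are illustrative.
import jax
import jax.numpy as jnp


def penalized_td_target(reward, done, gamma, next_q_min,
                        next_pi_action, next_dataset_action, critic_bc_coef):
    # The critic's BC penalty compares the (noise-perturbed) target-policy action
    # with the dataset's next action and is subtracted from the bootstrapped value.
    bc_penalty = ((next_pi_action - next_dataset_action) ** 2).sum(-1)
    return reward + (1.0 - done) * gamma * (next_q_min - critic_bc_coef * bc_penalty)


def actor_objective(q_min, pi_action, dataset_action, actor_bc_coef):
    # TD3+BC-style loss with Q normalization: lambda = 1 / E[|Q|], stop-gradient.
    bc_penalty = ((pi_action - dataset_action) ** 2).sum(-1)
    lmbda = jax.lax.stop_gradient(1.0 / jnp.abs(q_min).mean())
    return (actor_bc_coef * bc_penalty - lmbda * q_min).mean()
```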
+import os + +os.environ["TF_CUDNN_DETERMINISTIC"] = "1" + +import math +import random +import uuid +from copy import deepcopy +from dataclasses import asdict, dataclass +from functools import partial +from typing import Any, Callable, Dict, Sequence, Tuple, Union + +import chex +import d4rl # noqa +import flax.linen as nn +import gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +import pyrallis +import tqdm +import wandb +from flax.core import FrozenDict +from flax.training.train_state import TrainState +from tqdm.auto import trange + +ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate") + +default_kernel_init = nn.initializers.lecun_normal() +default_bias_init = nn.initializers.zeros + + +@dataclass +class Config: + # wandb params + project: str = "ReBRAC" + group: str = "rebrac-finetune" + name: str = "rebrac-finetune" + # model params + actor_learning_rate: float = 3e-4 + critic_learning_rate: float = 3e-4 + hidden_dim: int = 256 + actor_n_hiddens: int = 3 + critic_n_hiddens: int = 3 + replay_buffer_size: int = 2_000_000 + mixing_ratio: float = 0.5 + gamma: float = 0.99 + tau: float = 5e-3 + actor_bc_coef: float = 1.0 + critic_bc_coef: float = 1.0 + actor_ln: bool = False + critic_ln: bool = True + policy_noise: float = 0.2 + noise_clip: float = 0.5 + expl_noise: float = 0.0 + policy_freq: int = 2 + normalize_q: bool = True + min_decay_coef: float = 0.5 + use_calibration: bool = False + reset_opts: bool = False + # training params + dataset_name: str = "halfcheetah-medium-v2" + batch_size: int = 256 + num_offline_updates: int = 1_000_000 + num_online_updates: int = 1_000_000 + num_warmup_steps: int = 0 + normalize_reward: bool = False + normalize_states: bool = False + # evaluation params + eval_episodes: int = 10 + eval_every: int = 5000 + # general params + train_seed: int = 10 + eval_seed: int = 42 + + def __post_init__(self): + self.name = f"{self.name}-{self.dataset_name}-{str(uuid.uuid4())[:8]}" + + +def pytorch_init(fan_in: float) -> Callable: + """ + Default init for PyTorch Linear layer weights and biases: + https://pytorch.org/docs/stable/generated/torch.nn.Linear.html + """ + bound = math.sqrt(1 / fan_in) + + def _init(key: jax.random.PRNGKey, shape: Tuple, dtype: type) -> jax.Array: + return jax.random.uniform( + key, shape=shape, minval=-bound, maxval=bound, dtype=dtype + ) + + return _init + + +def uniform_init(bound: float) -> Callable: + def _init(key: jax.random.PRNGKey, shape: Tuple, dtype: type) -> jax.Array: + return jax.random.uniform( + key, shape=shape, minval=-bound, maxval=bound, dtype=dtype + ) + + return _init + + +def identity(x: Any) -> Any: + return x + + +class DetActor(nn.Module): + action_dim: int + hidden_dim: int = 256 + layernorm: bool = True + n_hiddens: int = 3 + + @nn.compact + def __call__(self, state: jax.Array) -> jax.Array: + s_d, h_d = state.shape[-1], self.hidden_dim + # Initialization as in the EDAC paper + layers = [ + nn.Dense( + self.hidden_dim, + kernel_init=pytorch_init(s_d), + bias_init=nn.initializers.constant(0.1), + ), + nn.relu, + nn.LayerNorm() if self.layernorm else identity, + ] + for _ in range(self.n_hiddens - 1): + layers += [ + nn.Dense( + self.hidden_dim, + kernel_init=pytorch_init(h_d), + bias_init=nn.initializers.constant(0.1), + ), + nn.relu, + nn.LayerNorm() if self.layernorm else identity, + ] + layers += [ + nn.Dense( + self.action_dim, + kernel_init=uniform_init(1e-3), + bias_init=uniform_init(1e-3), + ), + nn.tanh, + ] + net = nn.Sequential(layers) + actions = net(state) + 
return actions + + +class Critic(nn.Module): + hidden_dim: int = 256 + layernorm: bool = True + n_hiddens: int = 3 + + @nn.compact + def __call__(self, state: jax.Array, action: jax.Array) -> jax.Array: + s_d, a_d, h_d = state.shape[-1], action.shape[-1], self.hidden_dim + # Initialization as in the EDAC paper + layers = [ + nn.Dense( + self.hidden_dim, + kernel_init=pytorch_init(s_d + a_d), + bias_init=nn.initializers.constant(0.1), + ), + nn.relu, + nn.LayerNorm() if self.layernorm else identity, + ] + for _ in range(self.n_hiddens - 1): + layers += [ + nn.Dense( + self.hidden_dim, + kernel_init=pytorch_init(h_d), + bias_init=nn.initializers.constant(0.1), + ), + nn.relu, + nn.LayerNorm() if self.layernorm else identity, + ] + layers += [ + nn.Dense(1, kernel_init=uniform_init(3e-3), bias_init=uniform_init(3e-3)) + ] + network = nn.Sequential(layers) + state_action = jnp.hstack([state, action]) + out = network(state_action).squeeze(-1) + return out + + +class EnsembleCritic(nn.Module): + hidden_dim: int = 256 + num_critics: int = 10 + layernorm: bool = True + n_hiddens: int = 3 + + @nn.compact + def __call__(self, state: jax.Array, action: jax.Array) -> jax.Array: + ensemble = nn.vmap( + target=Critic, + in_axes=None, + out_axes=0, + variable_axes={"params": 0}, + split_rngs={"params": True}, + axis_size=self.num_critics, + ) + q_values = ensemble(self.hidden_dim, self.layernorm, self.n_hiddens)( + state, action + ) + return q_values + + +def calc_return_to_go(is_sparse_reward, rewards, terminals, gamma): + """ + A config dict for getting the default high/low rewrd values for each envs + This is used in calc_return_to_go func in sampler.py and replay_buffer.py + """ + if len(rewards) == 0: + return [] + reward_neg = 0 + if is_sparse_reward and np.all(np.array(rewards) == reward_neg): + """ + If the env has sparse reward and the trajectory is all negative rewards, + we use r / (1-gamma) as return to go. + For exapmle, if gamma = 0.99 and the rewards = [-1, -1, -1], + then return_to_go = [-100, -100, -100] + """ + # assuming failure reward is negative + # use r / (1-gamma) for negative trajctory + return_to_go = [float(reward_neg / (1 - gamma))] * len(rewards) + else: + return_to_go = [0] * len(rewards) + prev_return = 0 + for i in range(len(rewards)): + return_to_go[-i - 1] = \ + rewards[-i - 1] + gamma * prev_return * (1 - terminals[-i - 1]) + prev_return = return_to_go[-i - 1] + + return return_to_go + + +def qlearning_dataset( + env, dataset_name, + normalize_reward=False, dataset=None, + terminate_on_end=False, discount=0.99, **kwargs +): + """ + Returns datasets formatted for use by standard Q-learning algorithms, + with observations, actions, next_observations, next_actins, rewards, + and a terminal flag. + Args: + env: An OfflineEnv object. + dataset: An optional dataset to pass in for processing. If None, + the dataset will default to env.get_dataset() + terminate_on_end (bool): Set done=True on the last timestep + in a trajectory. Default is False, and will discard the + last timestep in each trajectory. + **kwargs: Arguments to pass to env.get_dataset(). + Returns: + A dictionary containing keys: + observations: An N x dim_obs array of observations. + actions: An N x dim_action array of actions. + next_observations: An N x dim_obs array of next observations. + next_actions: An N x dim_action array of next actions. + rewards: An N-dim float array of rewards. + terminals: An N-dim boolean array of "done" or episode termination flags. 
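Editor's note: compared to the standard D4RL Q-learning dataset, the version built below also stores `next_actions` (needed for the critic's BC penalty on the next action) and `mc_returns` (Monte-Carlo returns-to-go, used as a lower bound on the TD target when `use_calibration` is enabled, in the spirit of Cal-QL). For sparse-reward trajectories that never succeed, `calc_return_to_go` falls back to `r / (1 - gamma)` for every step; otherwise it applies the usual backward recurrence, as in this small self-contained example:

```python
# Worked example of the return-to-go recurrence computed by calc_return_to_go above:
# G_t = r_t + gamma * (1 - done_t) * G_{t+1}, evaluated right to left.
gamma = 0.99
rewards, terminals = [0.0, 0.0, 1.0], [0, 0, 1]
rtg, prev = [0.0] * len(rewards), 0.0
for t in reversed(range(len(rewards))):
    rtg[t] = rewards[t] + gamma * (1 - terminals[t]) * prev
    prev = rtg[t]
print(rtg)  # [0.9801..., 0.99, 1.0] (up to float rounding)
```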
+ """ + if dataset is None: + dataset = env.get_dataset(**kwargs) + if normalize_reward: + dataset['rewards'] = ReplayBuffer.normalize_reward( + dataset_name, dataset['rewards'] + ) + N = dataset['rewards'].shape[0] + is_sparse = "antmaze" in dataset_name + obs_ = [] + next_obs_ = [] + action_ = [] + next_action_ = [] + reward_ = [] + done_ = [] + mc_returns_ = [] + print("SIZE", N) + # The newer version of the dataset adds an explicit + # timeouts field. Keep old method for backwards compatability. + use_timeouts = 'timeouts' in dataset + + episode_step = 0 + episode_rewards = [] + episode_terminals = [] + for i in range(N - 1): + if episode_step == 0: + episode_rewards = [] + episode_terminals = [] + + obs = dataset['observations'][i].astype(np.float32) + new_obs = dataset['observations'][i + 1].astype(np.float32) + action = dataset['actions'][i].astype(np.float32) + new_action = dataset['actions'][i + 1].astype(np.float32) + reward = dataset['rewards'][i].astype(np.float32) + done_bool = bool(dataset['terminals'][i]) + + if use_timeouts: + final_timestep = dataset['timeouts'][i] + else: + final_timestep = (episode_step == env._max_episode_steps - 1) + if (not terminate_on_end) and final_timestep: + # Skip this transition and don't apply terminals on the last step of episode + episode_step = 0 + mc_returns_ += calc_return_to_go( + is_sparse, episode_rewards, episode_terminals, discount + ) + continue + if done_bool or final_timestep: + episode_step = 0 + + episode_rewards.append(reward) + episode_terminals.append(done_bool) + + obs_.append(obs) + next_obs_.append(new_obs) + action_.append(action) + next_action_.append(new_action) + reward_.append(reward) + done_.append(done_bool) + episode_step += 1 + if episode_step != 0: + mc_returns_ += calc_return_to_go( + is_sparse, episode_rewards, episode_terminals, discount + ) + assert np.array(mc_returns_).shape == np.array(reward_).shape + return { + 'observations': np.array(obs_), + 'actions': np.array(action_), + 'next_observations': np.array(next_obs_), + 'next_actions': np.array(next_action_), + 'rewards': np.array(reward_), + 'terminals': np.array(done_), + 'mc_returns': np.array(mc_returns_), + } + + +def compute_mean_std(states: jax.Array, eps: float) -> Tuple[jax.Array, jax.Array]: + mean = states.mean(0) + std = states.std(0) + eps + return mean, std + + +def normalize_states(states: jax.Array, mean: jax.Array, std: jax.Array) -> jax.Array: + return (states - mean) / std + + +@chex.dataclass +class ReplayBuffer: + data: Dict[str, jax.Array] = None + mean: float = 0 + std: float = 1 + + def create_from_d4rl( + self, + dataset_name: str, + normalize_reward: bool = False, + is_normalize: bool = False, + ): + d4rl_data = qlearning_dataset(gym.make(dataset_name), dataset_name) + buffer = { + "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32), + "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32), + "rewards": jnp.asarray(d4rl_data["rewards"], dtype=jnp.float32), + "next_states": jnp.asarray( + d4rl_data["next_observations"], dtype=jnp.float32 + ), + "next_actions": jnp.asarray(d4rl_data["next_actions"], dtype=jnp.float32), + "dones": jnp.asarray(d4rl_data["terminals"], dtype=jnp.float32), + } + if is_normalize: + self.mean, self.std = compute_mean_std(buffer["states"], eps=1e-3) + buffer["states"] = normalize_states(buffer["states"], self.mean, self.std) + buffer["next_states"] = normalize_states( + buffer["next_states"], self.mean, self.std + ) + if normalize_reward: + buffer["rewards"] = 
ReplayBuffer.normalize_reward( + dataset_name, buffer["rewards"] + ) + self.data = buffer + + @property + def size(self) -> int: + # WARN: It will use len of the dataclass, i.e. number of fields. + return self.data["states"].shape[0] + + def sample_batch( + self, key: jax.random.PRNGKey, batch_size: int + ) -> Dict[str, jax.Array]: + indices = jax.random.randint( + key, shape=(batch_size,), minval=0, maxval=self.size + ) + batch = jax.tree_map(lambda arr: arr[indices], self.data) + return batch + + def get_moments(self, modality: str) -> Tuple[jax.Array, jax.Array]: + mean = self.data[modality].mean(0) + std = self.data[modality].std(0) + return mean, std + + @staticmethod + def normalize_reward(dataset_name: str, rewards: jax.Array) -> jax.Array: + if "antmaze" in dataset_name: + return rewards * 100.0 # like in LAPO + else: + raise NotImplementedError( + "Reward normalization is implemented only for AntMaze yet!" + ) + + +class Dataset(object): + def __init__(self, observations: np.ndarray, actions: np.ndarray, + rewards: np.ndarray, masks: np.ndarray, + dones_float: np.ndarray, next_observations: np.ndarray, + next_actions: np.ndarray, + mc_returns: np.ndarray, + size: int): + self.observations = observations + self.actions = actions + self.rewards = rewards + self.masks = masks + self.dones_float = dones_float + self.next_observations = next_observations + self.next_actions = next_actions + self.mc_returns = mc_returns + self.size = size + + def sample(self, batch_size: int) -> Dict[str, np.ndarray]: + indx = np.random.randint(self.size, size=batch_size) + return { + "states": self.observations[indx], + "actions": self.actions[indx], + "rewards": self.rewards[indx], + "dones": self.dones_float[indx], + "next_states": self.next_observations[indx], + "next_actions": self.next_actions[indx], + "mc_returns": self.mc_returns[indx], + } + + +class OnlineReplayBuffer(Dataset): + def __init__(self, observation_space: gym.spaces.Box, action_dim: int, + capacity: int): + + observations = np.empty((capacity, *observation_space.shape), + dtype=observation_space.dtype) + actions = np.empty((capacity, action_dim), dtype=np.float32) + rewards = np.empty((capacity,), dtype=np.float32) + mc_returns = np.empty((capacity,), dtype=np.float32) + masks = np.empty((capacity,), dtype=np.float32) + dones_float = np.empty((capacity,), dtype=np.float32) + next_observations = np.empty((capacity, *observation_space.shape), + dtype=observation_space.dtype) + next_actions = np.empty((capacity, action_dim), dtype=np.float32) + super().__init__(observations=observations, + actions=actions, + rewards=rewards, + masks=masks, + dones_float=dones_float, + next_observations=next_observations, + next_actions=next_actions, + mc_returns=mc_returns, + size=0) + + self.size = 0 + + self.insert_index = 0 + self.capacity = capacity + + def initialize_with_dataset(self, dataset: Dataset, + num_samples=None): + assert self.insert_index == 0, \ + 'Can insert a batch online in an empty replay buffer.' + + dataset_size = len(dataset.observations) + + if num_samples is None: + num_samples = dataset_size + else: + num_samples = min(dataset_size, num_samples) + assert self.capacity >= num_samples, \ + 'Dataset cannot be larger than the replay buffer capacity.' 
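Editor's note on the fine-tuning data flow: the D4RL dataset seeds one replay buffer via `initialize_with_dataset`, freshly collected transitions go into a second `OnlineReplayBuffer`, and during the online phase each gradient step samples from both buffers and concatenates the two batches field-wise (see `concat_batches` and the `mixing_ratio` logic in `train()` below). A small sketch of that mixing, assuming the default `batch_size=256` and `mixing_ratio=0.5`; the state dimension and buffer contents are purely illustrative:

```python
import numpy as np

# Sketch of the offline/online batch mixing performed in train().
batch_size, mixing_ratio = 256, 0.5
online_bs = int(mixing_ratio * batch_size)   # 128 online samples
offline_bs = batch_size - online_bs          # 128 offline samples

offline_batch = {"states": np.zeros((offline_bs, 17)), "rewards": np.zeros(offline_bs)}
online_batch = {"states": np.ones((online_bs, 17)), "rewards": np.ones(online_bs)}

# Field-wise concatenation, mirroring concat_batches defined further below.
mixed = {k: np.concatenate((offline_batch[k], online_batch[k]), axis=0)
         for k in offline_batch}
assert mixed["states"].shape == (batch_size, 17)
```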
+ + if num_samples < dataset_size: + perm = np.random.permutation(dataset_size) + indices = perm[:num_samples] + else: + indices = np.arange(num_samples) + + self.observations[:num_samples] = dataset.observations[indices] + self.actions[:num_samples] = dataset.actions[indices] + self.rewards[:num_samples] = dataset.rewards[indices] + self.masks[:num_samples] = dataset.masks[indices] + self.dones_float[:num_samples] = dataset.dones_float[indices] + self.next_observations[:num_samples] = dataset.next_observations[ + indices] + self.next_actions[:num_samples] = dataset.next_actions[ + indices] + self.mc_returns[:num_samples] = dataset.mc_returns[indices] + + self.insert_index = num_samples + self.size = num_samples + + def insert(self, observation: np.ndarray, action: np.ndarray, + reward: float, mask: float, done_float: float, + next_observation: np.ndarray, + next_action: np.ndarray, mc_return: np.ndarray): + self.observations[self.insert_index] = observation + self.actions[self.insert_index] = action + self.rewards[self.insert_index] = reward + self.masks[self.insert_index] = mask + self.dones_float[self.insert_index] = done_float + self.next_observations[self.insert_index] = next_observation + self.next_actions[self.insert_index] = next_action + self.mc_returns[self.insert_index] = mc_return + + self.insert_index = (self.insert_index + 1) % self.capacity + self.size = min(self.size + 1, self.capacity) + + +class D4RLDataset(Dataset): + def __init__( + self, + env: gym.Env, + env_name: str, + normalize_reward: bool, + discount: float, + ): + d4rl_data = qlearning_dataset( + env, env_name, normalize_reward=normalize_reward, discount=discount + ) + dataset = { + "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32), + "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32), + "rewards": jnp.asarray(d4rl_data["rewards"], dtype=jnp.float32), + "next_states": jnp.asarray( + d4rl_data["next_observations"], dtype=jnp.float32 + ), + "next_actions": jnp.asarray(d4rl_data["next_actions"], dtype=jnp.float32), + "dones": jnp.asarray(d4rl_data["terminals"], dtype=jnp.float32), + "mc_returns": jnp.asarray(d4rl_data["mc_returns"], dtype=jnp.float32) + } + + super().__init__(dataset['states'].astype(np.float32), + actions=dataset['actions'].astype(np.float32), + rewards=dataset['rewards'].astype(np.float32), + masks=1.0 - dataset['dones'].astype(np.float32), + dones_float=dataset['dones'].astype(np.float32), + next_observations=dataset['next_states'].astype( + np.float32), + next_actions=dataset["next_actions"], + mc_returns=dataset["mc_returns"], + size=len(dataset['states'])) + + +def concat_batches(b1, b2): + new_batch = {} + for k in b1: + new_batch[k] = np.concatenate((b1[k], b2[k]), axis=0) + return new_batch + + +@chex.dataclass(frozen=True) +class Metrics: + accumulators: Dict[str, Tuple[jax.Array, jax.Array]] + + @staticmethod + def create(metrics: Sequence[str]) -> "Metrics": + init_metrics = {key: (jnp.array([0.0]), jnp.array([0.0])) for key in metrics} + return Metrics(accumulators=init_metrics) + + def update(self, updates: Dict[str, jax.Array]) -> "Metrics": + new_accumulators = deepcopy(self.accumulators) + for key, value in updates.items(): + acc, steps = new_accumulators[key] + new_accumulators[key] = (acc + value, steps + 1) + + return self.replace(accumulators=new_accumulators) + + def compute(self) -> Dict[str, np.ndarray]: + # cumulative_value / total_steps + return {k: np.array(v[0] / v[1]) for k, v in self.accumulators.items()} + + +def normalize( + arr: 
jax.Array, mean: jax.Array, std: jax.Array, eps: float = 1e-8 +) -> jax.Array: + return (arr - mean) / (std + eps) + + +def make_env(env_name: str, seed: int) -> gym.Env: + env = gym.make(env_name) + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + +def wrap_env( + env: gym.Env, + state_mean: Union[np.ndarray, float] = 0.0, + state_std: Union[np.ndarray, float] = 1.0, + reward_scale: float = 1.0, +) -> gym.Env: + # PEP 8: E731 do not assign a lambda expression, use a def + def normalize_state(state: np.ndarray) -> np.ndarray: + return ( + state - state_mean + ) / state_std # epsilon should be already added in std. + + def scale_reward(reward: float) -> float: + # Please be careful, here reward is multiplied by scale! + return reward_scale * reward + + env = gym.wrappers.TransformObservation(env, normalize_state) + if reward_scale != 1.0: + env = gym.wrappers.TransformReward(env, scale_reward) + return env + + +def make_env_and_dataset(env_name: str, + seed: int, + normalize_reward: bool, + discount: float) -> Tuple[gym.Env, D4RLDataset]: + env = gym.make(env_name) + + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + + dataset = D4RLDataset(env, env_name, normalize_reward, discount=discount) + + return env, dataset + + +def is_goal_reached(reward: float, info: Dict) -> bool: + if "goal_achieved" in info: + return info["goal_achieved"] + return reward > 0 # Assuming that reaching target is a positive reward + + +def evaluate( + env: gym.Env, params, action_fn: Callable, num_episodes: int, seed: int +) -> Tuple[np.ndarray, np.ndarray]: + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + + returns = [] + successes = [] + for _ in trange(num_episodes, desc="Eval", leave=False): + obs, done = env.reset(), False + goal_achieved = False + total_reward = 0.0 + while not done: + action = np.asarray(jax.device_get(action_fn(params, obs))) + obs, reward, done, info = env.step(action) + total_reward += reward + if not goal_achieved: + goal_achieved = is_goal_reached(reward, info) + successes.append(float(goal_achieved)) + returns.append(total_reward) + + return np.array(returns), np.mean(successes) + + +class CriticTrainState(TrainState): + target_params: FrozenDict + + +class ActorTrainState(TrainState): + target_params: FrozenDict + + +@jax.jit +def update_actor( + key: jax.random.PRNGKey, + actor: TrainState, + critic: TrainState, + batch: Dict[str, jax.Array], + beta: float, + tau: float, + normalize_q: bool, + metrics: Metrics, +) -> Tuple[jax.random.PRNGKey, TrainState, jax.Array, Metrics]: + key, random_action_key = jax.random.split(key, 2) + + def actor_loss_fn(params): + actions = actor.apply_fn(params, batch["states"]) + + bc_penalty = ((actions - batch["actions"]) ** 2).sum(-1) + q_values = critic.apply_fn(critic.params, batch["states"], actions).min(0) + # lmbda = 1 + # # if normalize_q: + lmbda = jax.lax.stop_gradient(1 / jax.numpy.abs(q_values).mean()) + + loss = (beta * bc_penalty - lmbda * q_values).mean() + + # logging stuff + random_actions = jax.random.uniform( + random_action_key, shape=batch["actions"].shape, minval=-1.0, maxval=1.0 + ) + new_metrics = metrics.update({ + "actor_loss": loss, + "bc_mse_policy": bc_penalty.mean(), + "bc_mse_random": ((random_actions - batch["actions"]) ** 2).sum(-1).mean(), + "action_mse": ((actions - batch["actions"]) ** 2).mean() + }) + return loss, new_metrics + + grads, new_metrics = jax.grad(actor_loss_fn, has_aux=True)(actor.params) + 
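Editor's note: both the actor and the critic keep Polyak-averaged target parameters; right after the gradient step below they are moved toward the online parameters with `optax.incremental_update` using step size `tau` (5e-3 by default). A tiny standalone illustration:

```python
# Soft (Polyak) target update as used just below:
#   target <- target + tau * (online - target)
import jax.numpy as jnp
import optax

online_params = {"w": jnp.ones(3)}
target_params = {"w": jnp.zeros(3)}
new_target = optax.incremental_update(online_params, target_params, step_size=5e-3)
# new_target["w"] is now 0.005 everywhere.
```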
new_actor = actor.apply_gradients(grads=grads) + + new_actor = new_actor.replace( + target_params=optax.incremental_update(actor.params, actor.target_params, tau) + ) + new_critic = critic.replace( + target_params=optax.incremental_update(critic.params, critic.target_params, tau) + ) + + return key, new_actor, new_critic, new_metrics + + +def update_critic( + key: jax.random.PRNGKey, + actor: TrainState, + critic: CriticTrainState, + batch: Dict[str, jax.Array], + gamma: float, + beta: float, + tau: float, + policy_noise: float, + noise_clip: float, + use_calibration: bool, + metrics: Metrics, +) -> Tuple[jax.random.PRNGKey, TrainState, Metrics]: + key, actions_key = jax.random.split(key) + + next_actions = actor.apply_fn(actor.target_params, batch["next_states"]) + noise = jax.numpy.clip( + (jax.random.normal(actions_key, next_actions.shape) * policy_noise), + -noise_clip, + noise_clip, + ) + next_actions = jax.numpy.clip(next_actions + noise, -1, 1) + + bc_penalty = ((next_actions - batch["next_actions"]) ** 2).sum(-1) + next_q = critic.apply_fn( + critic.target_params, batch["next_states"], next_actions + ).min(0) + next_q = next_q - beta * bc_penalty + target_q = jax.lax.cond( + use_calibration, + lambda: jax.numpy.maximum( + batch["rewards"] + (1 - batch["dones"]) * gamma * next_q, + batch['mc_returns'] + ), + lambda: batch["rewards"] + (1 - batch["dones"]) * gamma * next_q + ) + + def critic_loss_fn(critic_params): + # [N, batch_size] - [1, batch_size] + q = critic.apply_fn(critic_params, batch["states"], batch["actions"]) + q_min = q.min(0).mean() + loss = ((q - target_q[None, ...]) ** 2).mean(1).sum(0) + return loss, q_min + + (loss, q_min), grads = jax.value_and_grad( + critic_loss_fn, has_aux=True + )(critic.params) + new_critic = critic.apply_gradients(grads=grads) + new_metrics = metrics.update({ + "critic_loss": loss, + "q_min": q_min, + }) + return key, new_critic, new_metrics + + +@jax.jit +def update_td3( + key: jax.random.PRNGKey, + actor: TrainState, + critic: CriticTrainState, + batch: Dict[str, Any], + metrics: Metrics, + gamma: float, + actor_bc_coef: float, + critic_bc_coef: float, + tau: float, + policy_noise: float, + noise_clip: float, + normalize_q: bool, + use_calibration: bool, +): + key, new_critic, new_metrics = update_critic( + key, actor, critic, batch, gamma, critic_bc_coef, tau, + policy_noise, noise_clip, use_calibration, metrics + ) + key, new_actor, new_critic, new_metrics = update_actor( + key, actor, new_critic, batch, actor_bc_coef, tau, + normalize_q, new_metrics + ) + return key, new_actor, new_critic, new_metrics + + +@jax.jit +def update_td3_no_targets( + key: jax.random.PRNGKey, + actor: TrainState, + critic: CriticTrainState, + batch: Dict[str, Any], + gamma: float, + metrics: Metrics, + actor_bc_coef: float, + critic_bc_coef: float, + tau: float, + policy_noise: float, + noise_clip: float, + use_calibration: bool, +): + key, new_critic, new_metrics = update_critic( + key, actor, critic, batch, gamma, critic_bc_coef, tau, + policy_noise, noise_clip, use_calibration, metrics + ) + return key, actor, new_critic, new_metrics + + +def action_fn(actor: TrainState) -> Callable: + @jax.jit + def _action_fn(obs: jax.Array) -> jax.Array: + action = actor.apply_fn(actor.params, obs) + return action + + return _action_fn + + +@pyrallis.wrap() +def train(config: Config): + dict_config = asdict(config) + dict_config["mlc_job_name"] = os.environ.get("PLATFORM_JOB_NAME") + is_env_with_goal = config.dataset_name.startswith(ENVS_WITH_GOAL) + 
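Editor's note: `train()` runs `num_offline_updates` gradient steps on the offline data, then switches to online fine-tuning for `num_online_updates` steps, optionally resetting the Adam optimizer states (`reset_opts`, similar to SPOT) and optionally skipping gradient updates for the first `num_warmup_steps` of online interaction. During the online phase the critic's BC penalty is switched off entirely and the actor's BC coefficient is decayed linearly down to `min_decay_coef`, as in this sketch of the schedule (a paraphrase of the `lin_coef`/`decay_coef` logic inside the loop below):

```python
def bc_decay(step, num_offline, num_online, num_warmup, min_decay):
    """Return (actor_scale, critic_scale) multipliers for the configured BC coefs."""
    if step < num_offline:           # offline phase: penalties at full strength
        return 1.0, 1.0
    lin = (num_online + num_offline - step + num_warmup) / num_online
    return max(min_decay, lin), 0.0  # critic BC penalty is disabled online
```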
np.random.seed(config.train_seed) + random.seed(config.train_seed) + + wandb.init( + config=dict_config, + project=config.project, + group=config.group, + name=config.name, + id=str(uuid.uuid4()), + ) + buffer = ReplayBuffer() + buffer.create_from_d4rl( + config.dataset_name, config.normalize_reward, config.normalize_states + ) + + key = jax.random.PRNGKey(seed=config.train_seed) + key, actor_key, critic_key = jax.random.split(key, 3) + + init_state = buffer.data["states"][0][None, ...] + init_action = buffer.data["actions"][0][None, ...] + + actor_module = DetActor( + action_dim=init_action.shape[-1], hidden_dim=config.hidden_dim, + layernorm=config.actor_ln, n_hiddens=config.actor_n_hiddens + ) + actor = ActorTrainState.create( + apply_fn=actor_module.apply, + params=actor_module.init(actor_key, init_state), + target_params=actor_module.init(actor_key, init_state), + tx=optax.adam(learning_rate=config.actor_learning_rate), + ) + + critic_module = EnsembleCritic( + hidden_dim=config.hidden_dim, num_critics=2, + layernorm=config.critic_ln, n_hiddens=config.critic_n_hiddens + ) + critic = CriticTrainState.create( + apply_fn=critic_module.apply, + params=critic_module.init(critic_key, init_state, init_action), + target_params=critic_module.init(critic_key, init_state, init_action), + tx=optax.adam(learning_rate=config.critic_learning_rate), + ) + + # metrics + bc_metrics_to_log = [ + "critic_loss", "q_min", "actor_loss", "batch_entropy", + "bc_mse_policy", "bc_mse_random", "action_mse" + ] + # shared carry for update loops + carry = { + "key": key, + "actor": actor, + "critic": critic, + "buffer": buffer, + "delayed_updates": jax.numpy.equal( + jax.numpy.arange( + config.num_offline_updates + config.num_online_updates + ) % config.policy_freq, 0 + ).astype(int) + } + + # Online + offline tuning + env, dataset = make_env_and_dataset( + config.dataset_name, config.train_seed, False, discount=config.gamma + ) + eval_env, _ = make_env_and_dataset( + config.dataset_name, config.eval_seed, False, discount=config.gamma + ) + + max_steps = env._max_episode_steps + + action_dim = env.action_space.shape[0] + replay_buffer = OnlineReplayBuffer(env.observation_space, action_dim, + config.replay_buffer_size) + replay_buffer.initialize_with_dataset(dataset, None) + online_buffer = OnlineReplayBuffer( + env.observation_space, action_dim, config.replay_buffer_size + ) + + online_batch_size = 0 + offline_batch_size = config.batch_size + + observation, done = env.reset(), False + episode_step = 0 + goal_achieved = False + + @jax.jit + def actor_action_fn(params, obs): + return actor.apply_fn(params, obs) + + eval_successes = [] + train_successes = [] + print("Offline training") + for i in tqdm.tqdm( + range(config.num_online_updates + config.num_offline_updates), + smoothing=0.1 + ): + carry["metrics"] = Metrics.create(bc_metrics_to_log) + if i == config.num_offline_updates: + print("Online training") + + online_batch_size = int(config.mixing_ratio * config.batch_size) + offline_batch_size = config.batch_size - online_batch_size + # Reset optimizers similar to SPOT + if config.reset_opts: + actor = actor.replace( + opt_state=optax.adam(learning_rate=config.actor_learning_rate).init(actor.params) + ) + critic = critic.replace( + opt_state=optax.adam(learning_rate=config.critic_learning_rate).init(critic.params) + ) + + update_td3_partial = partial( + update_td3, gamma=config.gamma, + tau=config.tau, + policy_noise=config.policy_noise, + noise_clip=config.noise_clip, + normalize_q=config.normalize_q, + 
use_calibration=config.use_calibration, + ) + + update_td3_no_targets_partial = partial( + update_td3_no_targets, gamma=config.gamma, + tau=config.tau, + policy_noise=config.policy_noise, + noise_clip=config.noise_clip, + use_calibration=config.use_calibration, + ) + online_log = {} + + if i >= config.num_offline_updates: + episode_step += 1 + action = np.asarray(actor_action_fn(carry["actor"].params, observation)) + action = np.array( + [ + ( + action + + np.random.normal(0, 1 * config.expl_noise, size=action_dim) + ).clip(-1, 1) + ] + )[0] + + next_observation, reward, done, info = env.step(action) + if not goal_achieved: + goal_achieved = is_goal_reached(reward, info) + next_action = np.asarray( + actor_action_fn(carry["actor"].params, next_observation) + )[0] + next_action = np.array( + [ + ( + next_action + + np.random.normal(0, 1 * config.expl_noise, size=action_dim) + ).clip(-1, 1) + ] + )[0] + + if not done or 'TimeLimit.truncated' in info: + mask = 1.0 + else: + mask = 0.0 + real_done = False + if done and episode_step < max_steps: + real_done = True + + online_buffer.insert(observation, action, reward, mask, + float(real_done), next_observation, next_action, 0) + observation = next_observation + if done: + train_successes.append(goal_achieved) + observation, done = env.reset(), False + episode_step = 0 + goal_achieved = False + + if config.num_offline_updates <= \ + i < \ + config.num_offline_updates + config.num_warmup_steps: + continue + + offline_batch = replay_buffer.sample(offline_batch_size) + online_batch = online_buffer.sample(online_batch_size) + batch = concat_batches(offline_batch, online_batch) + + if 'antmaze' in config.dataset_name and config.normalize_reward: + batch["rewards"] *= 100 + + ### Update step + actor_bc_coef = config.actor_bc_coef + critic_bc_coef = config.critic_bc_coef + if i >= config.num_offline_updates: + lin_coef = ( + config.num_online_updates + + config.num_offline_updates - + i + config.num_warmup_steps + ) / config.num_online_updates + decay_coef = max(config.min_decay_coef, lin_coef) + actor_bc_coef *= decay_coef + critic_bc_coef *= 0 + if i % config.policy_freq == 0: + update_fn = partial(update_td3_partial, + actor_bc_coef=actor_bc_coef, + critic_bc_coef=critic_bc_coef, + key=key, + actor=carry["actor"], + critic=carry["critic"], + batch=batch, + metrics=carry["metrics"]) + else: + update_fn = partial(update_td3_no_targets_partial, + actor_bc_coef=actor_bc_coef, + critic_bc_coef=critic_bc_coef, + key=key, + actor=carry["actor"], + critic=carry["critic"], + batch=batch, + metrics=carry["metrics"]) + key, new_actor, new_critic, new_metrics = update_fn() + carry.update( + key=key, actor=new_actor, critic=new_critic, metrics=new_metrics + ) + + if i % 1000 == 0: + mean_metrics = carry["metrics"].compute() + common = {f"ReBRAC/{k}": v for k, v in mean_metrics.items()} + common["actor_bc_coef"] = actor_bc_coef + common["critic_bc_coef"] = critic_bc_coef + if i < config.num_offline_updates: + wandb.log({"offline_iter": i, **common}) + else: + wandb.log({"online_iter": i - config.num_offline_updates, **common}) + if i % config.eval_every == 0 or\ + i == config.num_offline_updates + config.num_online_updates - 1 or\ + i == config.num_offline_updates - 1: + eval_returns, success_rate = evaluate( + eval_env, carry["actor"].params, actor_action_fn, + config.eval_episodes, + seed=config.eval_seed + ) + normalized_score = eval_env.get_normalized_score(eval_returns) * 100.0 + eval_successes.append(success_rate) + if is_env_with_goal: + 
online_log["train/regret"] = np.mean(1 - np.array(train_successes)) + offline_log = { + "eval/return_mean": np.mean(eval_returns), + "eval/return_std": np.std(eval_returns), + "eval/normalized_score_mean": np.mean(normalized_score), + "eval/normalized_score_std": np.std(normalized_score), + "eval/success_rate": success_rate + } + offline_log.update(online_log) + wandb.log(offline_log) + + +if __name__ == "__main__": + train() diff --git a/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml b/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml new file mode 100644 index 00000000..cd208878 --- /dev/null +++ b/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.002 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.002 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-large-diverse-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-large-diverse-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: true +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/large_play_v2.yaml b/configs/finetune/rebrac/antmaze/large_play_v2.yaml new file mode 100644 index 00000000..77e56655 --- /dev/null +++ b/configs/finetune/rebrac/antmaze/large_play_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.002 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.001 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-large-play-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-large-play-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: true +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml b/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml new file mode 100644 index 00000000..d5d56435 --- /dev/null +++ b/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.001 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.0 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-medium-diverse-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-medium-diverse-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: true +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git 
a/configs/finetune/rebrac/antmaze/medium_play_v2.yaml b/configs/finetune/rebrac/antmaze/medium_play_v2.yaml new file mode 100644 index 00000000..e1c4966a --- /dev/null +++ b/configs/finetune/rebrac/antmaze/medium_play_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.001 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.0005 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-medium-play-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-medium-play-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: true +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml b/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml new file mode 100644 index 00000000..41fdf292 --- /dev/null +++ b/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.003 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.001 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-umaze-diverse-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-umaze-diverse-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: true +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/umaze_v2.yaml b/configs/finetune/rebrac/antmaze/umaze_v2.yaml new file mode 100644 index 00000000..042b5232 --- /dev/null +++ b/configs/finetune/rebrac/antmaze/umaze_v2.yaml @@ -0,0 +1,36 @@ +actor_bc_coef: 0.003 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.002 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-umaze-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-umaze-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: true +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false + diff --git a/configs/finetune/rebrac/door/cloned_v1.yaml b/configs/finetune/rebrac/door/cloned_v1.yaml new file mode 100644 index 00000000..06db15cd --- /dev/null +++ b/configs/finetune/rebrac/door/cloned_v1.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.01 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.1 +critic_learning_rate: 0.0003 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: door-cloned-v1 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.99 
+group: rebrac-finetune-door-cloned-v1 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/hammer/cloned_v1.yaml b/configs/finetune/rebrac/hammer/cloned_v1.yaml new file mode 100644 index 00000000..c55be6f2 --- /dev/null +++ b/configs/finetune/rebrac/hammer/cloned_v1.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.1 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.5 +critic_learning_rate: 0.0003 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: hammer-cloned-v1 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.99 +group: rebrac-finetune-hammer-cloned-v1 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/pen/cloned_v1.yaml b/configs/finetune/rebrac/pen/cloned_v1.yaml new file mode 100644 index 00000000..9144ab9c --- /dev/null +++ b/configs/finetune/rebrac/pen/cloned_v1.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.05 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.5 +critic_learning_rate: 0.0003 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: pen-cloned-v1 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.99 +group: rebrac-finetune-pen-cloned-v1 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/relocate/cloned_v1.yaml b/configs/finetune/rebrac/relocate/cloned_v1.yaml new file mode 100644 index 00000000..fefa5048 --- /dev/null +++ b/configs/finetune/rebrac/relocate/cloned_v1.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.1 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.01 +critic_learning_rate: 0.0003 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: relocate-cloned-v1 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.99 +group: rebrac-finetune-relocate-cloned-v1 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/results/bin/finetune_scores.pickle b/results/bin/finetune_scores.pickle index d60377ec..a0a682d3 100644 Binary files a/results/bin/finetune_scores.pickle and 
b/results/bin/finetune_scores.pickle differ diff --git a/results/get_finetune_scores.py b/results/get_finetune_scores.py index dd31ad19..1fb9303d 100644 --- a/results/get_finetune_scores.py +++ b/results/get_finetune_scores.py @@ -32,9 +32,14 @@ def get_run_scores(run_id, is_dt=False): break for _, row in run.history(keys=[score_key], samples=5000).iterrows(): full_scores.append(row[score_key]) + + for _, row in run.history(keys=["train/regret"], samples=5000).iterrows(): + if "train/regret" in row: + regret = row["train/regret"] for _, row in run.history(keys=["eval/regret"], samples=5000).iterrows(): if "eval/regret" in row: regret = row["eval/regret"] + offline_iters = len(full_scores) // 2 return full_scores[:offline_iters], full_scores[offline_iters:], regret diff --git a/results/get_finetune_tables_and_plots.py b/results/get_finetune_tables_and_plots.py index d8072652..997bc15a 100644 --- a/results/get_finetune_tables_and_plots.py +++ b/results/get_finetune_tables_and_plots.py @@ -76,6 +76,7 @@ def get_last_scores(avg_scores, avg_stds): for algo in full_offline_scores: for data in full_offline_scores[algo]: + # print(algo, flush=True) full_offline_scores[algo][data] = [s[0] for s in full_scores[algo][data]] full_online_scores[algo][data] = [s[1] for s in full_scores[algo][data]] regrets[algo][data] = np.mean([s[2] for s in full_scores[algo][data]]) @@ -127,7 +128,7 @@ def add_domains_avg(scores): add_domains_avg(last_online_scores) add_domains_avg(regrets) -algorithms = ["AWAC", "CQL", "IQL", "SPOT", "Cal-QL"] +algorithms = ["AWAC", "CQL", "IQL", "SPOT", "Cal-QL", "ReBRAC",] datasets = dataframe["dataset"].unique() ordered_datasets = [ "antmaze-umaze-v2", @@ -417,8 +418,9 @@ def flatten(data): "CQL", "IQL", "AWAC", + "Cal-QL", ] -for a1 in ["Cal-QL"]: +for a1 in ["ReBRAC"]: for a2 in algs: algorithm_pairs[f"{a1},{a2}"] = (flat[a1], flat[a2]) average_probabilities, average_prob_cis = rly.get_interval_estimates( diff --git a/results/get_finetune_urls.py b/results/get_finetune_urls.py index 12d7e374..07048ca3 100644 --- a/results/get_finetune_urls.py +++ b/results/get_finetune_urls.py @@ -18,6 +18,8 @@ def get_urls(sweep_id, algo_name): dataset = run.config["env"] elif "env_name" in run.config: dataset = run.config["env_name"] + elif "dataset_name" in run.config: + dataset = run.config["dataset_name"] name = algo_name if "10" in "-".join(run.name.split("-")[:-1]): name = "10% " + name @@ -41,6 +43,8 @@ def get_urls(sweep_id, algo_name): get_urls("tlab/CORL/sweeps/efvz7d68", "Cal-QL") +get_urls("tlab/CORL/sweeps/62cgqb8c", "ReBRAC") + dataframe = pd.DataFrame(collected_urls) dataframe.to_csv("runs_tables/finetune_urls.csv", index=False) diff --git a/results/runs_tables/finetune_urls.csv b/results/runs_tables/finetune_urls.csv index 3058d67a..ae637030 100644 --- a/results/runs_tables/finetune_urls.csv +++ b/results/runs_tables/finetune_urls.csv @@ -32,8 +32,8 @@ SPOT,antmaze-large-diverse-v2,tlab/CORL/runs/0kn4pl04 SPOT,antmaze-large-diverse-v2,tlab/CORL/runs/onon4kbo SPOT,antmaze-large-play-v2,tlab/CORL/runs/5ldclyhi SPOT,antmaze-large-play-v2,tlab/CORL/runs/v8uskc0k -SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/sysutdr0 SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/tnksp757 +SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/sysutdr0 SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/cg7vkg7p SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/12ivynlo SPOT,antmaze-large-diverse-v2,tlab/CORL/runs/kvxnj3cw @@ -75,20 +75,20 @@ AWAC,antmaze-large-play-v2,tlab/CORL/runs/z88yrg2k 
AWAC,antmaze-large-play-v2,tlab/CORL/runs/b1e1up19 AWAC,antmaze-large-play-v2,tlab/CORL/runs/bpq150gg AWAC,antmaze-large-play-v2,tlab/CORL/runs/rffmurq2 -AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/tiq65215 AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/uo63dgzx +AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/tiq65215 AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/mfjh57xv AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/325fy2js CQL,antmaze-umaze-v2,tlab/CORL/runs/vdh2wmw9 CQL,antmaze-umaze-v2,tlab/CORL/runs/wh27aupq CQL,antmaze-umaze-v2,tlab/CORL/runs/7r4uwutz -CQL,antmaze-large-play-v2,tlab/CORL/runs/kt7jwqcz CQL,antmaze-umaze-v2,tlab/CORL/runs/l5xvgwt4 +CQL,antmaze-large-play-v2,tlab/CORL/runs/kt7jwqcz CQL,antmaze-large-play-v2,tlab/CORL/runs/8fm40vpm CQL,antmaze-large-play-v2,tlab/CORL/runs/yeax28su -CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/gvhslqyo -CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/mowkqr6u CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/pswhm9pi +CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/mowkqr6u +CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/gvhslqyo CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/hh5vv5qc CQL,door-cloned-v1,tlab/CORL/runs/2xc0y5sd CQL,door-cloned-v1,tlab/CORL/runs/2ylul8yr @@ -188,14 +188,54 @@ Cal-QL,antmaze-umaze-diverse-v2,tlab/CORL/runs/a015fjb1 Cal-QL,antmaze-umaze-diverse-v2,tlab/CORL/runs/1pu06s2i Cal-QL,antmaze-umaze-diverse-v2,tlab/CORL/runs/iwa1o31k Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/yvqv3mxa -Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/4myjeu5g Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/6ptdr78l +Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/4myjeu5g Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/8ix0469p -Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/4chdwkua Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/fzrlcnwp +Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/4chdwkua Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/f9hz4fal Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/fpq2ob8q Cal-QL,antmaze-medium-diverse-v2,tlab/CORL/runs/zhf7tr7p Cal-QL,antmaze-medium-diverse-v2,tlab/CORL/runs/m02ew5oy Cal-QL,antmaze-medium-diverse-v2,tlab/CORL/runs/9r1a0trx Cal-QL,antmaze-medium-diverse-v2,tlab/CORL/runs/ds2dbx2u +ReBRAC,antmaze-umaze-v2,tlab/CORL/runs/jt59ttpm +ReBRAC,antmaze-umaze-v2,tlab/CORL/runs/i8v1hu1q +ReBRAC,antmaze-umaze-v2,tlab/CORL/runs/qedc68y9 +ReBRAC,antmaze-umaze-v2,tlab/CORL/runs/69ax7dcn +ReBRAC,door-cloned-v1,tlab/CORL/runs/eqfcelaj +ReBRAC,door-cloned-v1,tlab/CORL/runs/prxbzf7x +ReBRAC,door-cloned-v1,tlab/CORL/runs/tdcevdlp +ReBRAC,door-cloned-v1,tlab/CORL/runs/kl86pkrh +ReBRAC,hammer-cloned-v1,tlab/CORL/runs/gpzs7ko9 +ReBRAC,hammer-cloned-v1,tlab/CORL/runs/mpqjq5iv +ReBRAC,hammer-cloned-v1,tlab/CORL/runs/5fu56nvb +ReBRAC,hammer-cloned-v1,tlab/CORL/runs/dfzw050e +ReBRAC,pen-cloned-v1,tlab/CORL/runs/j29tzijm +ReBRAC,pen-cloned-v1,tlab/CORL/runs/4en0ayoi +ReBRAC,pen-cloned-v1,tlab/CORL/runs/4eyhtq9d +ReBRAC,pen-cloned-v1,tlab/CORL/runs/so5u47bk +ReBRAC,relocate-cloned-v1,tlab/CORL/runs/yskdxtsn +ReBRAC,relocate-cloned-v1,tlab/CORL/runs/32gy3ulp +ReBRAC,relocate-cloned-v1,tlab/CORL/runs/l4jqb7k7 +ReBRAC,relocate-cloned-v1,tlab/CORL/runs/mgd0l6wd +ReBRAC,antmaze-medium-play-v2,tlab/CORL/runs/owmozman +ReBRAC,antmaze-medium-play-v2,tlab/CORL/runs/0w7wg74n +ReBRAC,antmaze-medium-play-v2,tlab/CORL/runs/arp17u1o +ReBRAC,antmaze-medium-play-v2,tlab/CORL/runs/erip8h20 +ReBRAC,antmaze-umaze-diverse-v2,tlab/CORL/runs/5pd0cva3 +ReBRAC,antmaze-umaze-diverse-v2,tlab/CORL/runs/2beklksm +ReBRAC,antmaze-umaze-diverse-v2,tlab/CORL/runs/68m94c0v 
+ReBRAC,antmaze-umaze-diverse-v2,tlab/CORL/runs/g5cshcb2 +ReBRAC,antmaze-large-diverse-v2,tlab/CORL/runs/qyggi592 +ReBRAC,antmaze-large-diverse-v2,tlab/CORL/runs/4u324o5s +ReBRAC,antmaze-large-diverse-v2,tlab/CORL/runs/vbgrhecg +ReBRAC,antmaze-large-diverse-v2,tlab/CORL/runs/2286tpe9 +ReBRAC,antmaze-large-play-v2,tlab/CORL/runs/24zf9wms +ReBRAC,antmaze-large-play-v2,tlab/CORL/runs/wc7z2xz3 +ReBRAC,antmaze-large-play-v2,tlab/CORL/runs/wxq1x3o2 +ReBRAC,antmaze-large-play-v2,tlab/CORL/runs/p8j2dn0y +ReBRAC,antmaze-medium-diverse-v2,tlab/CORL/runs/0ig2xzi7 +ReBRAC,antmaze-medium-diverse-v2,tlab/CORL/runs/puzizafo +ReBRAC,antmaze-medium-diverse-v2,tlab/CORL/runs/m84hfj5d +ReBRAC,antmaze-medium-diverse-v2,tlab/CORL/runs/u1y0nldu
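Editor's note: with these configs in place, a fine-tuning run should be launchable the same way as the other CORL algorithms, presumably `python algorithms/finetune/rebrac.py --config_path=configs/finetune/rebrac/antmaze/umaze_v2.yaml` (assuming pyrallis's standard `--config_path` flag, with individual fields overridable on the command line, e.g. `--train_seed=1`). Metrics are logged to the `CORL` W&B project, including the `train/regret` key that the updated `results/get_finetune_scores.py` now reads for ReBRAC runs.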